├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── .gitmodules ├── .nojekyll ├── LICENSE ├── README.md ├── apex ├── __init__.py ├── _autocast_utils.py ├── contrib │ ├── __init__.py │ ├── bottleneck │ │ ├── __init__.py │ │ ├── bottleneck.py │ │ ├── halo_exchangers.py │ │ └── test.py │ ├── clip_grad │ │ ├── __init__.py │ │ └── clip_grad.py │ ├── conv_bias_relu │ │ ├── __init__.py │ │ └── conv_bias_relu.py │ ├── csrc │ │ ├── bottleneck │ │ │ └── bottleneck.cpp │ │ ├── conv_bias_relu │ │ │ └── conv_bias_relu.cpp │ │ ├── cudnn_gbn │ │ │ ├── cudnn_gbn.cpp │ │ │ ├── norm_sample.cpp │ │ │ └── norm_sample.h │ │ ├── fmha │ │ │ ├── fmha_api.cpp │ │ │ └── src │ │ │ │ ├── fmha.h │ │ │ │ ├── fmha │ │ │ │ ├── gemm.h │ │ │ │ ├── gmem_tile.h │ │ │ │ ├── kernel_traits.h │ │ │ │ ├── mask.h │ │ │ │ ├── smem_tile.h │ │ │ │ ├── softmax.h │ │ │ │ └── utils.h │ │ │ │ ├── fmha_dgrad_fp16_128_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_fp16_256_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_fp16_384_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_fp16_512_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_kernel_1xN_reload.h │ │ │ │ ├── fmha_dgrad_kernel_1xN_reload_nl.h │ │ │ │ ├── fmha_fill.cu │ │ │ │ ├── fmha_fprop_fp16_128_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_fp16_256_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_fp16_384_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_fp16_512_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_kernel_1xN.h │ │ │ │ ├── fmha_kernel.h │ │ │ │ ├── fmha_noloop_reduce.cu │ │ │ │ └── fmha_utils.h │ │ ├── focal_loss │ │ │ ├── focal_loss_cuda.cpp │ │ │ └── focal_loss_cuda_kernel.cu │ │ ├── gpu_direct_storage │ │ │ ├── gds.cpp │ │ │ ├── gds.h │ │ │ └── gds_pybind.cpp │ │ ├── group_norm │ │ │ ├── group_norm_nhwc.cpp │ │ │ ├── group_norm_nhwc.h │ │ │ ├── group_norm_nhwc_bwd_one_pass.h │ │ │ ├── group_norm_nhwc_bwd_one_pass_kernel.cuh │ │ │ ├── group_norm_nhwc_bwd_two_pass.cu │ │ │ ├── group_norm_nhwc_fwd_one_pass.h │ │ │ ├── group_norm_nhwc_fwd_one_pass_kernel.cuh │ │ │ ├── group_norm_nhwc_fwd_two_pass.cu │ │ │ ├── group_norm_nhwc_one_pass_10.cu │ │ │ ├── group_norm_nhwc_one_pass_112.cu │ │ │ ├── group_norm_nhwc_one_pass_12.cu │ │ │ ├── group_norm_nhwc_one_pass_120.cu │ │ │ ├── group_norm_nhwc_one_pass_128.cu │ │ │ ├── group_norm_nhwc_one_pass_14.cu │ │ │ ├── group_norm_nhwc_one_pass_16.cu │ │ │ ├── group_norm_nhwc_one_pass_160.cu │ │ │ ├── group_norm_nhwc_one_pass_20.cu │ │ │ ├── group_norm_nhwc_one_pass_24.cu │ │ │ ├── group_norm_nhwc_one_pass_26.cu │ │ │ ├── group_norm_nhwc_one_pass_28.cu │ │ │ ├── group_norm_nhwc_one_pass_30.cu │ │ │ ├── group_norm_nhwc_one_pass_32.cu │ │ │ ├── group_norm_nhwc_one_pass_4.cu │ │ │ ├── group_norm_nhwc_one_pass_40.cu │ │ │ ├── group_norm_nhwc_one_pass_42.cu │ │ │ ├── group_norm_nhwc_one_pass_48.cu │ │ │ ├── group_norm_nhwc_one_pass_56.cu │ │ │ ├── group_norm_nhwc_one_pass_60.cu │ │ │ ├── group_norm_nhwc_one_pass_64.cu │ │ │ ├── group_norm_nhwc_one_pass_70.cu │ │ │ ├── group_norm_nhwc_one_pass_8.cu │ │ │ ├── group_norm_nhwc_one_pass_80.cu │ │ │ ├── group_norm_nhwc_one_pass_84.cu │ │ │ ├── group_norm_nhwc_one_pass_96.cu │ │ │ ├── group_norm_nhwc_one_pass_98.cu │ │ │ ├── group_norm_nhwc_op.cpp │ │ │ ├── macros.h │ │ │ └── traits.h │ │ ├── group_norm_v2 │ │ │ ├── generate_gn_cuda_inst.py │ │ │ ├── gn.cpp │ │ │ ├── gn.hpp │ │ │ ├── gn_cuda.cu │ │ │ ├── gn_cuda_host_template.cuh │ │ │ ├── gn_cuda_inst_1024_1280.cu │ │ │ ├── gn_cuda_inst_1024_1920.cu │ │ │ ├── gn_cuda_inst_1024_320.cu │ │ │ ├── gn_cuda_inst_1024_640.cu │ │ │ ├── gn_cuda_inst_1024_960.cu │ │ │ ├── 
gn_cuda_inst_256_1280.cu │ │ │ ├── gn_cuda_inst_256_1920.cu │ │ │ ├── gn_cuda_inst_256_2560.cu │ │ │ ├── gn_cuda_inst_256_640.cu │ │ │ ├── gn_cuda_inst_4096_320.cu │ │ │ ├── gn_cuda_inst_4096_640.cu │ │ │ ├── gn_cuda_inst_4096_960.cu │ │ │ ├── gn_cuda_inst_64_1280.cu │ │ │ ├── gn_cuda_inst_64_2560.cu │ │ │ ├── gn_cuda_kernel.cuh │ │ │ ├── gn_dispatch_hw_c.hpp │ │ │ ├── gn_utils.cpp │ │ │ └── gn_utils.hpp │ │ ├── groupbn │ │ │ ├── batch_norm.cu │ │ │ ├── batch_norm.h │ │ │ ├── batch_norm_add_relu.cu │ │ │ ├── batch_norm_add_relu.h │ │ │ ├── cuda_utils.h │ │ │ ├── interface.cpp │ │ │ ├── ipc.cu │ │ │ └── nhwc_batch_norm_kernel.h │ │ ├── index_mul_2d │ │ │ ├── index_mul_2d_cuda.cpp │ │ │ └── index_mul_2d_cuda_kernel.cu │ │ ├── layer_norm │ │ │ ├── ln.h │ │ │ ├── ln_api.cpp │ │ │ ├── ln_bwd_kernels.cuh │ │ │ ├── ln_bwd_semi_cuda_kernel.cu │ │ │ ├── ln_fwd_cuda_kernel.cu │ │ │ ├── ln_fwd_kernels.cuh │ │ │ ├── ln_kernel_traits.h │ │ │ └── ln_utils.cuh │ │ ├── multihead_attn │ │ │ ├── additive_masked_softmax_dropout_cuda.cu │ │ │ ├── dropout.cuh │ │ │ ├── encdec_multihead_attn_cuda.cu │ │ │ ├── encdec_multihead_attn_norm_add_cuda.cu │ │ │ ├── layer_norm.cuh │ │ │ ├── masked_softmax_dropout_cuda.cu │ │ │ ├── multihead_attn_frontend.cpp │ │ │ ├── philox.cuh │ │ │ ├── self_multihead_attn_bias_additive_mask_cuda.cu │ │ │ ├── self_multihead_attn_bias_cuda.cu │ │ │ ├── self_multihead_attn_cuda.cu │ │ │ ├── self_multihead_attn_norm_add_cuda.cu │ │ │ ├── softmax.cuh │ │ │ └── strided_batched_gemm.cuh │ │ ├── nccl_allocator │ │ │ └── NCCLAllocator.cpp │ │ ├── nccl_p2p │ │ │ ├── nccl_p2p.cpp │ │ │ ├── nccl_p2p_cuda.cu │ │ │ ├── nccl_p2p_cuda.cuh │ │ │ ├── nccl_version.cpp │ │ │ └── nccl_version_check.cu │ │ ├── optimizers │ │ │ ├── fused_adam_cuda.cpp │ │ │ ├── fused_adam_cuda_kernel.cu │ │ │ ├── fused_lamb_cuda.cpp │ │ │ ├── fused_lamb_cuda_kernel.cu │ │ │ ├── multi_tensor_distopt_adam.cpp │ │ │ ├── multi_tensor_distopt_adam_kernel.cu │ │ │ ├── multi_tensor_distopt_lamb.cpp │ │ │ └── multi_tensor_distopt_lamb_kernel.cu │ │ ├── peer_memory │ │ │ ├── peer_memory.cpp │ │ │ ├── peer_memory_cuda.cu │ │ │ └── peer_memory_cuda.cuh │ │ ├── transducer │ │ │ ├── transducer_joint.cpp │ │ │ ├── transducer_joint_kernel.cu │ │ │ ├── transducer_loss.cpp │ │ │ └── transducer_loss_kernel.cu │ │ └── xentropy │ │ │ ├── interface.cpp │ │ │ └── xentropy_kernel.cu │ ├── cudnn_gbn │ │ ├── __init__.py │ │ └── batch_norm.py │ ├── examples │ │ ├── gpu_direct_storage │ │ │ ├── benchmark_load.py │ │ │ ├── benchmark_save.py │ │ │ ├── example_load.py │ │ │ └── example_save.py │ │ ├── multihead_attn │ │ │ ├── func_test_multihead_attn.py │ │ │ └── perf_test_multihead_attn.py │ │ └── nccl_allocator │ │ │ ├── allreduce.py │ │ │ ├── cache.py │ │ │ ├── change_cuda_allocator.py │ │ │ └── toy_ddp.py │ ├── fmha │ │ ├── __init__.py │ │ └── fmha.py │ ├── focal_loss │ │ ├── __init__.py │ │ └── focal_loss.py │ ├── gpu_direct_storage │ │ ├── README.md │ │ └── __init__.py │ ├── group_norm │ │ ├── __init__.py │ │ └── group_norm.py │ ├── groupbn │ │ ├── __init__.py │ │ └── batch_norm.py │ ├── index_mul_2d │ │ ├── __init__.py │ │ └── index_mul_2d.py │ ├── layer_norm │ │ ├── __init__.py │ │ └── layer_norm.py │ ├── multihead_attn │ │ ├── MHA_bwd.png │ │ ├── MHA_fwd.png │ │ ├── README.md │ │ ├── __init__.py │ │ ├── encdec_multihead_attn.py │ │ ├── encdec_multihead_attn_func.py │ │ ├── fast_encdec_multihead_attn_func.py │ │ ├── fast_encdec_multihead_attn_norm_add_func.py │ │ ├── fast_self_multihead_attn_func.py │ │ ├── 
fast_self_multihead_attn_norm_add_func.py │ │ ├── mask_softmax_dropout_func.py │ │ ├── self_multihead_attn.py │ │ └── self_multihead_attn_func.py │ ├── nccl_allocator │ │ ├── README.md │ │ ├── __init__.py │ │ └── nccl_allocator.py │ ├── openfold_triton │ │ ├── README.md │ │ ├── __init__.py │ │ ├── _layer_norm_backward_kernels.py │ │ ├── _layer_norm_config_ampere.py │ │ ├── _layer_norm_config_hopper.py │ │ ├── _layer_norm_forward_kernels.py │ │ ├── _mha_kernel.py │ │ ├── fused_adam_swa.py │ │ ├── layer_norm.py │ │ └── mha.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── distributed_fused_adam.py │ │ ├── distributed_fused_lamb.py │ │ ├── fp16_optimizer.py │ │ ├── fused_adam.py │ │ ├── fused_lamb.py │ │ └── fused_sgd.py │ ├── peer_memory │ │ ├── __init__.py │ │ ├── peer_halo_exchanger_1d.py │ │ └── peer_memory.py │ ├── sparsity │ │ ├── COPYRIGHT │ │ ├── README.md │ │ ├── __init__.py │ │ ├── asp.py │ │ ├── permutation_lib.py │ │ ├── permutation_search_kernels │ │ │ ├── CUDA_kernels │ │ │ │ └── permutation_search_kernels.cu │ │ │ ├── __init__.py │ │ │ ├── call_permutation_search_kernels.py │ │ │ ├── channel_swap.py │ │ │ ├── exhaustive_search.py │ │ │ └── permutation_utilities.py │ │ ├── permutation_tests │ │ │ ├── README.md │ │ │ ├── ablation_studies.sh │ │ │ ├── permutation_test.py │ │ │ ├── runtime_table.sh │ │ │ └── unstructured_study.sh │ │ ├── sparse_masklib.py │ │ └── test │ │ │ ├── checkpointing_test_part1.py │ │ │ ├── checkpointing_test_part2.py │ │ │ ├── checkpointing_test_reference.py │ │ │ ├── test_permutation_application.py │ │ │ └── toy_problem.py │ ├── test │ │ ├── __init__.py │ │ ├── bottleneck │ │ │ ├── __init__.py │ │ │ └── test_bottleneck_module.py │ │ ├── clip_grad │ │ │ ├── __init__.py │ │ │ └── test_clip_grad.py │ │ ├── conv_bias_relu │ │ │ ├── __init__.py │ │ │ └── test_conv_bias_relu.py │ │ ├── cudnn_gbn │ │ │ ├── __init__.py │ │ │ └── test_cudnn_gbn_with_two_gpus.py │ │ ├── fmha │ │ │ ├── __init__.py │ │ │ └── test_fmha.py │ │ ├── focal_loss │ │ │ ├── __init__.py │ │ │ └── test_focal_loss.py │ │ ├── fused_dense │ │ │ └── test_fused_dense.py │ │ ├── group_norm │ │ │ ├── __init__.py │ │ │ └── test_group_norm.py │ │ ├── index_mul_2d │ │ │ ├── __init__.py │ │ │ └── test_index_mul_2d.py │ │ ├── layer_norm │ │ │ ├── __init__.py │ │ │ └── test_fast_layer_norm.py │ │ ├── multihead_attn │ │ │ ├── __init__.py │ │ │ ├── test_encdec_multihead_attn.py │ │ │ ├── test_encdec_multihead_attn_norm_add.py │ │ │ ├── test_fast_self_multihead_attn_bias.py │ │ │ ├── test_mha_fused_softmax.py │ │ │ ├── test_self_multihead_attn.py │ │ │ └── test_self_multihead_attn_norm_add.py │ │ ├── openfold_triton │ │ │ ├── test_fused_adam_swa.py │ │ │ ├── test_openfold_mha.py │ │ │ └── test_sync_triton_auto_tune_cache_across_gpus.py │ │ ├── optimizers │ │ │ ├── __init__.py │ │ │ ├── test_dist_adam.py │ │ │ └── test_distributed_fused_lamb.py │ │ ├── peer_memory │ │ │ ├── __init__.py │ │ │ └── test_peer_halo_exchange_module.py │ │ ├── transducer │ │ │ ├── __init__.py │ │ │ ├── test_transducer_joint.py │ │ │ └── test_transducer_loss.py │ │ └── xentropy │ │ │ ├── __init__.py │ │ │ └── test_label_smoothing.py │ ├── torchsched │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── config.py │ │ ├── inductor │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ ├── event.py │ │ │ ├── graph.py │ │ │ ├── scheduler.py │ │ │ └── wrapper.py │ │ ├── ops │ │ │ ├── __init__.py │ │ │ └── layer_norm.py │ │ └── passes │ │ │ ├── __init__.py │ │ │ └── pre_grad_passes.py │ ├── transducer │ │ ├── __init__.py │ │ ├── _transducer_ref.py │ │ └── 
transducer.py │ └── xentropy │ │ ├── __init__.py │ │ └── softmax_xentropy.py ├── fused_dense │ ├── __init__.py │ └── fused_dense.py ├── mlp │ ├── __init__.py │ └── mlp.py ├── multi_tensor_apply │ ├── __init__.py │ └── multi_tensor_apply.py ├── normalization │ ├── __init__.py │ └── fused_layer_norm.py ├── optimizers │ ├── __init__.py │ ├── fused_adagrad.py │ ├── fused_adam.py │ ├── fused_lamb.py │ ├── fused_mixed_precision_lamb.py │ ├── fused_novograd.py │ └── fused_sgd.py └── transformer │ ├── README.md │ ├── __init__.py │ ├── _data │ ├── __init__.py │ └── _batchsampler.py │ ├── _ucc_util.py │ ├── amp │ ├── __init__.py │ └── grad_scaler.py │ ├── enums.py │ ├── functional │ ├── __init__.py │ ├── fused_rope.py │ └── fused_softmax.py │ ├── layers │ ├── __init__.py │ └── layer_norm.py │ ├── log_util.py │ ├── microbatches.py │ ├── parallel_state.py │ ├── pipeline_parallel │ ├── __init__.py │ ├── _timers.py │ ├── p2p_communication.py │ ├── schedules │ │ ├── __init__.py │ │ ├── common.py │ │ ├── fwd_bwd_no_pipelining.py │ │ ├── fwd_bwd_pipelining_with_interleaving.py │ │ └── fwd_bwd_pipelining_without_interleaving.py │ └── utils.py │ ├── tensor_parallel │ ├── __init__.py │ ├── cross_entropy.py │ ├── data.py │ ├── layers.py │ ├── mappings.py │ ├── memory.py │ ├── random.py │ └── utils.py │ ├── testing │ ├── __init__.py │ ├── arguments.py │ ├── commons.py │ ├── distributed_test_base.py │ ├── global_vars.py │ ├── standalone_bert.py │ ├── standalone_gpt.py │ └── standalone_transformer_lm.py │ └── utils.py ├── csrc ├── amp_C_frontend.cpp ├── compat.h ├── flatten_unflatten.cpp ├── fused_dense.cpp ├── fused_dense_cuda.cu ├── layer_norm_cuda.cpp ├── layer_norm_cuda_kernel.cu ├── megatron │ ├── fused_rotary_positional_embedding.cpp │ ├── fused_rotary_positional_embedding.h │ ├── fused_rotary_positional_embedding_cuda.cu │ ├── fused_weight_gradient_dense.cpp │ ├── fused_weight_gradient_dense_16bit_prec_cuda.cu │ ├── fused_weight_gradient_dense_cuda.cu │ ├── generic_scaled_masked_softmax.cpp │ ├── generic_scaled_masked_softmax.h │ ├── generic_scaled_masked_softmax_cuda.cu │ ├── scaled_masked_softmax.cpp │ ├── scaled_masked_softmax.h │ ├── scaled_masked_softmax_cuda.cu │ ├── scaled_softmax.cpp │ ├── scaled_softmax_cuda.cu │ ├── scaled_upper_triang_masked_softmax.cpp │ ├── scaled_upper_triang_masked_softmax.h │ └── scaled_upper_triang_masked_softmax_cuda.cu ├── mlp.cpp ├── mlp_cuda.cu ├── multi_tensor_adagrad.cu ├── multi_tensor_adam.cu ├── multi_tensor_apply.cuh ├── multi_tensor_axpby_kernel.cu ├── multi_tensor_l2norm_kernel.cu ├── multi_tensor_l2norm_kernel_mp.cu ├── multi_tensor_l2norm_scale_kernel.cu ├── multi_tensor_lamb.cu ├── multi_tensor_lamb_mp.cu ├── multi_tensor_lamb_stage_1.cu ├── multi_tensor_lamb_stage_2.cu ├── multi_tensor_novograd.cu ├── multi_tensor_scale_kernel.cu ├── multi_tensor_sgd_kernel.cu ├── static_switch.h ├── syncbn.cpp ├── type_shim.h ├── update_scale_hysteresis.cu └── welford.cu ├── docs ├── Makefile └── source │ ├── _static │ ├── css │ │ └── pytorch_theme.css │ └── img │ │ └── nv-pytorch2.png │ ├── _templates │ └── layout.html │ ├── conf.py │ ├── index.rst │ ├── layernorm.rst │ └── optimizers.rst ├── examples ├── README.md ├── dcgan │ ├── README.md │ └── main_amp.py ├── docker │ ├── Dockerfile │ └── README.md ├── imagenet │ ├── README.md │ └── main_amp.py └── simple │ └── distributed │ ├── README.md │ ├── distributed_data_parallel.py │ └── run.sh ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── setup.py └── tests ├── L0 ├── run_fused_layer_norm │ └── 
test_fused_layer_norm.py ├── run_mlp │ └── test_mlp.py ├── run_optimizers │ ├── __init__.py │ ├── test_adam.py │ ├── test_fused_novograd.py │ ├── test_fused_optimizer.py │ └── test_lamb.py ├── run_test.py └── run_transformer │ ├── __init__.py │ ├── gpt_scaling_test.py │ ├── test_batch_sampler.py │ ├── test_bert_minimal.py │ ├── test_cross_entropy.py │ ├── test_data.py │ ├── test_dynamic_batchsize.py │ ├── test_fused_rope.py │ ├── test_fused_softmax.py │ ├── test_gpt_minimal.py │ ├── test_layers.py │ ├── test_mapping.py │ ├── test_microbatches.py │ ├── test_p2p_comm.py │ ├── test_parallel_state.py │ ├── test_pipeline_parallel_fwd_bwd.py │ ├── test_random.py │ └── test_transformer_utils.py ├── L1 ├── common │ ├── compare.py │ ├── main_amp.py │ └── run_test.sh ├── cross_product │ └── run.sh ├── cross_product_distributed │ └── run.sh └── transformer │ └── pipeline_parallel_fwd_bwd_ucc_async.py ├── distributed ├── DDP │ ├── ddp_race_condition_test.py │ └── run_race_test.sh ├── amp_master_params │ ├── amp_master_params.py │ ├── compare.py │ └── run.sh └── synced_batchnorm │ ├── python_single_gpu_unit_test.py │ ├── single_gpu_unit_test.py │ ├── test_batchnorm1d.py │ ├── test_groups.py │ ├── two_gpu_test_different_batch_size.py │ ├── two_gpu_unit_test.py │ └── unit_test.sh └── docker_extension_builds └── run.sh /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve apex 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the Bug** 11 | 12 | **Minimal Steps/Code to Reproduce the Bug** 13 | 18 | 19 | **Expected Behavior** 20 | 21 | 22 | **Environment** 23 | 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | apex.egg-info 2 | dist 3 | build 4 | docs/build 5 | *~ 6 | __pycache__ 7 | .vscode 8 | 9 | # Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "apex/contrib/csrc/multihead_attn/cutlass"] 2 | path = apex/contrib/csrc/multihead_attn/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | branch = v1.2.0 5 | [submodule "apex/contrib/csrc/cudnn-frontend"] 6 | path = apex/contrib/csrc/cudnn-frontend 7 | url = https://github.com/NVIDIA/cudnn-frontend.git 8 | -------------------------------------------------------------------------------- /.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/.nojekyll -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /apex/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | # May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten 5 | import torch 6 | 7 | # For optimizers and normalization there is no Python fallback. 8 | # Absence of cuda backend is a hard error. 9 | # I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda 10 | # to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext 11 | # so they expect those backends to be available, but for some reason they actually aren't 12 | # available (for example because they built improperly in a way that isn't revealed until 13 | # load time) the error message is timely and visible. 14 | from . import optimizers 15 | from . import normalization 16 | from . 
import transformer 17 | 18 | 19 | __all__ = ["optimizers", "normalization", "transformer"] 20 | 21 | 22 | # Logging utilities for apex.transformer module 23 | class RankInfoFormatter(logging.Formatter): 24 | 25 | def format(self, record): 26 | from apex.transformer.parallel_state import get_rank_info 27 | record.rank_info = get_rank_info() 28 | return super().format(record) 29 | 30 | 31 | _library_root_logger = logging.getLogger(__name__) 32 | handler = logging.StreamHandler() 33 | handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S")) 34 | _library_root_logger.addHandler(handler) 35 | _library_root_logger.propagate = False 36 | 37 | 38 | def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: 39 | cudnn_available = torch.backends.cudnn.is_available() 40 | cudnn_version = torch.backends.cudnn.version() if cudnn_available else None 41 | if not (cudnn_available and (cudnn_version >= required_cudnn_version)): 42 | warnings.warn( 43 | f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, " 44 | f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" 45 | ) 46 | return False 47 | return True 48 | 49 | 50 | class DeprecatedFeatureWarning(FutureWarning): 51 | pass 52 | 53 | 54 | def deprecated_warning(msg: str) -> None: 55 | if ( 56 | not torch.distributed.is_available 57 | or not torch.distributed.is_initialized() 58 | or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) 59 | ): 60 | warnings.warn(msg, DeprecatedFeatureWarning) 61 | -------------------------------------------------------------------------------- /apex/_autocast_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import torch 4 | 5 | 6 | __all__ = ["_cast_if_autocast_enabled"] 7 | 8 | 9 | def _get_autocast_dtypes() -> Sequence[torch.dtype]: 10 | if torch.cuda.is_bf16_supported(): 11 | return [torch.half, torch.bfloat16] 12 | return [torch.half] 13 | 14 | 15 | def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype: 16 | if not torch.is_autocast_enabled(): 17 | return torch.float or dtype 18 | else: 19 | return torch.get_autocast_gpu_dtype() 20 | 21 | 22 | def _cast_if_autocast_enabled(*args): 23 | if not torch.is_autocast_enabled(): 24 | return args 25 | else: 26 | return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) 27 | -------------------------------------------------------------------------------- /apex/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/__init__.py -------------------------------------------------------------------------------- /apex/contrib/bottleneck/__init__.py: -------------------------------------------------------------------------------- 1 | from .bottleneck import Bottleneck, SpatialBottleneck 2 | from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer 3 | -------------------------------------------------------------------------------- /apex/contrib/clip_grad/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_grad import clip_grad_norm_ 2 | 
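A minimal usage sketch for the fused gradient clipping re-exported above by apex/contrib/clip_grad/__init__.py. It assumes the call signature mirrors torch.nn.utils.clip_grad_norm_ (parameters, max_norm, norm_type) and that apex was built with the contrib CUDA extensions on a machine with a GPU; treat it as an illustration, not the module's documented API.

    import torch
    from apex.contrib.clip_grad import clip_grad_norm_  # re-export from apex/contrib/clip_grad/__init__.py

    # Hypothetical toy model and backward pass so there are gradients to clip.
    model = torch.nn.Linear(1024, 1024).cuda()
    loss = model(torch.randn(8, 1024, device="cuda")).sum()
    loss.backward()

    # Clip the global gradient norm to 1.0. The return value is assumed to be the
    # total gradient norm before clipping, as with torch.nn.utils.clip_grad_norm_.
    total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0)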
-------------------------------------------------------------------------------- /apex/contrib/conv_bias_relu/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU, ConvFrozenScaleBiasReLU 2 | 3 | -------------------------------------------------------------------------------- /apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | // CUDA forward declarations 7 | 8 | std::vector focal_loss_forward_cuda( 9 | const at::Tensor &cls_output, 10 | const at::Tensor &cls_targets_at_level, 11 | const at::Tensor &num_positives_sum, 12 | const int64_t num_real_classes, 13 | const float alpha, 14 | const float gamma, 15 | const float smoothing_factor); 16 | 17 | at::Tensor focal_loss_backward_cuda( 18 | const at::Tensor &grad_output, 19 | const at::Tensor &partial_grad, 20 | const at::Tensor &num_positives_sum); 21 | 22 | // C++ interface 23 | 24 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 25 | #define CHECK_CONTIGUOUS(x) \ 26 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 27 | #define CHECK_INPUT(x) \ 28 | CHECK_CUDA(x); \ 29 | CHECK_CONTIGUOUS(x) 30 | 31 | std::vector focal_loss_forward( 32 | const at::Tensor &cls_output, 33 | const at::Tensor &cls_targets_at_level, 34 | const at::Tensor &num_positives_sum, 35 | const int64_t num_real_classes, 36 | const float alpha, 37 | const float gamma, 38 | const float smoothing_factor 39 | ) { 40 | CHECK_INPUT(cls_output); 41 | CHECK_INPUT(cls_targets_at_level); 42 | CHECK_INPUT(num_positives_sum); 43 | 44 | return focal_loss_forward_cuda( 45 | cls_output, 46 | cls_targets_at_level, 47 | num_positives_sum, 48 | num_real_classes, 49 | alpha, 50 | gamma, 51 | smoothing_factor); 52 | } 53 | 54 | at::Tensor focal_loss_backward( 55 | const at::Tensor &grad_output, 56 | const at::Tensor &partial_grad, 57 | const at::Tensor &num_positives_sum 58 | ) { 59 | CHECK_INPUT(grad_output); 60 | CHECK_INPUT(partial_grad); 61 | 62 | return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); 63 | } 64 | 65 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 66 | m.def("forward", &focal_loss_forward, 67 | "Focal loss calculation forward (CUDA)", 68 | py::call_guard()); 69 | m.def("backward", &focal_loss_backward, 70 | "Focal loss calculation backward (CUDA)", 71 | py::call_guard()); 72 | } 73 | -------------------------------------------------------------------------------- /apex/contrib/csrc/gpu_direct_storage/gds.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace apex::contrib::gds { 10 | class File { 11 | public: 12 | File(); 13 | File(const std::string& filename, const std::string& mode); 14 | ~File(); 15 | 16 | void open(const std::string& filename, const std::string& mode); 17 | void close(); 18 | 19 | void load_data(const torch::Tensor& tensor); 20 | void save_data(const torch::Tensor& tensor); 21 | void load_data_no_gds(const torch::Tensor& tensor); 22 | void save_data_no_gds(const torch::Tensor& tensor); 23 | 24 | private: 25 | std::string filename; 26 | std::string mode; 27 | 28 | CUfileDescr_t cf_descr; 29 | CUfileHandle_t cf_handle; 30 | CUfileError_t status; 31 | 32 | int fd = -1; 33 | bool is_open = false; 34 | bool maybe_register = true; 35 | }; 36 | } 37 | -------------------------------------------------------------------------------- /apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | //python bindings 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | py::class_< 11 | apex::contrib::gds::File, 12 | std::shared_ptr>( 13 | m, "_GDSFile") 14 | .def(py::init<>()) 15 | .def(py::init()) 16 | .def("open", &apex::contrib::gds::File::open) 17 | .def("close", &apex::contrib::gds::File::close) 18 | .def("load_data", &apex::contrib::gds::File::load_data) 19 | .def("save_data", &apex::contrib::gds::File::save_data) 20 | .def("load_data_no_gds", &apex::contrib::gds::File::load_data_no_gds) 21 | .def("save_data_no_gds", &apex::contrib::gds::File::save_data_no_gds); 22 | } 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 10, /* THREADS_PER_BLOCK */ 640) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 112, /* THREADS_PER_BLOCK */ 448) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2025, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 12, /* THREADS_PER_BLOCK */ 384) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 120, /* THREADS_PER_BLOCK */ 480) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 128, /* THREADS_PER_BLOCK */ 512) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 14, /* THREADS_PER_BLOCK */ 224) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 16, /* THREADS_PER_BLOCK */ 256) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 160, /* THREADS_PER_BLOCK */ 640) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 20, /* THREADS_PER_BLOCK */ 640) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 24, /* THREADS_PER_BLOCK */ 384) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 26, /* THREADS_PER_BLOCK */ 416) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 28, /* THREADS_PER_BLOCK */ 448) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 30, /* THREADS_PER_BLOCK */ 480) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 32, /* THREADS_PER_BLOCK */ 512) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | 23 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 4, /* THREADS_PER_BLOCK */ 128) 24 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 40, /* THREADS_PER_BLOCK */ 640) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 42, /* THREADS_PER_BLOCK */ 672) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 48, /* THREADS_PER_BLOCK */ 384) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 56, /* THREADS_PER_BLOCK */ 448) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 60, /* THREADS_PER_BLOCK */ 480) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 64, /* THREADS_PER_BLOCK */ 512) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 70, /* THREADS_PER_BLOCK */ 560) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 8, /* THREADS_PER_BLOCK */ 128) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 80, /* THREADS_PER_BLOCK */ 640) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 84, /* THREADS_PER_BLOCK */ 672) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 96, /* THREADS_PER_BLOCK */ 768) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without modification, are not permit- 5 | * ted. 6 | * 7 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 8 | * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 10 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 11 | * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 12 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 13 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 14 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | * 16 | **************************************************************************************************/ 17 | 18 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 19 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 20 | #include "macros.h" 21 | 22 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 98, /* THREADS_PER_BLOCK */ 392) 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/generate_gn_cuda_inst.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | 4 | hw_c_list = [ 5 | (8 * 8, 1280), 6 | (8 * 8, 2560), 7 | (16 * 16, 640), 8 | (16 * 16, 1280), 9 | (16 * 16, 1920), 10 | (16 * 16, 2560), 11 | (32 * 32, 320), 12 | (32 * 32, 640), 13 | (32 * 32, 960), 14 | (32 * 32, 1280), 15 | (32 * 32, 1920), 16 | (64 * 64, 320), 17 | (64 * 64, 640), 18 | (64 * 64, 960), 19 | ] 20 | 21 | 22 | def run(): 23 | src_path = pathlib.Path(__file__).parent.absolute() 24 | 25 | for f in src_path.glob("gn_cuda_inst_*.cu"): 26 | f.unlink() 27 | 28 | for hw, c in hw_c_list: 29 | print(f"GN_CUDA_INST_DEFINE({hw}, {c})") 30 | with open(src_path / f"gn_cuda_inst_{hw}_{c}.cu", "w") as f: 31 | f.write(f"#include \"gn_cuda_host_template.cuh\"\n") 32 | f.write(f"\n") 33 | f.write(f"\n") 34 | f.write(f"namespace group_norm_v2 {{\n") 35 | f.write(f"\n") 36 | f.write(f"GN_CUDA_INST_DEFINE({hw}, {c})\n") 37 | f.write(f"\n") 38 | f.write(f"}} // namespace group_norm_v2\n") 39 | 40 | with open(src_path / "gn_dispatch_hw_c.hpp", "w") as f: 41 | f.write(f"#pragma once\n") 42 | f.write(f"\n") 43 | f.write(f"#define DISPATCH_HW_C(hw, c, HW, C, ...) [&] {{ \\\n") 44 | for hw, c in hw_c_list: 45 | f.write(f" if (hw == {hw} && c == {c}) {{ constexpr int HW = {hw}, C = {c}; return __VA_ARGS__(); }} \\\n") 46 | f.write(f" throw std::invalid_argument(\"DISPATCH_HW_C \" + std::to_string(hw) + \" \" + std::to_string(c)); \\\n") 47 | f.write(f" }}()\n") 48 | 49 | 50 | if __name__ == "__main__": 51 | run() 52 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | 7 | namespace group_norm_v2 { 8 | 9 | struct Meta { 10 | int64_t red_buffer_size; 11 | int64_t barrier_size; 12 | int BLOCK_DIM_X; 13 | int C_PER_BLOCK; 14 | int ROWS_PER_BLOCK; 15 | int VEC_ELEMS; 16 | bool LOAD_TWICE; 17 | int BLOCKS_PER_SM; 18 | bool HARDWARE_CLUSTER; 19 | int wgrad_sync_method; 20 | }; 21 | 22 | template 23 | void gn_cuda(T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only); 24 | 25 | template 26 | void gn_bwd_cuda(T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only); 27 | 28 | } // namespace group_norm_v2 29 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "gn.hpp" 2 | 3 | #include 4 | #include 5 | 
#include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "gn_utils.hpp" 12 | #include "gn_dispatch_hw_c.hpp" 13 | 14 | 15 | #define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...) [&] { \ 16 | if (num_groups == 16 && silu == true) { constexpr int NUM_GROUPS = 16; constexpr bool SILU = true; return __VA_ARGS__(); } \ 17 | if (num_groups == 32 && silu == false) { constexpr int NUM_GROUPS = 32; constexpr bool SILU = false; return __VA_ARGS__(); } \ 18 | throw std::invalid_argument("DISPATCH_NUM_GROUPS_AND_SILU " + std::to_string(num_groups) + " " + std::to_string(silu)); \ 19 | }() 20 | 21 | namespace group_norm_v2 { 22 | 23 | template 24 | void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)); 25 | 26 | template 27 | void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)); 28 | 29 | template 30 | void gn_cuda(GN_CUDA_HOST_PARAMS(T)) { 31 | DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { 32 | DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] { 33 | return gn_cuda_single_shape(GN_CUDA_HOST_ARGS); 34 | }); 35 | }); 36 | } 37 | 38 | template 39 | void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(T)) { 40 | DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { 41 | DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] { 42 | return gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_ARGS); 43 | }); 44 | }); 45 | } 46 | 47 | template void gn_cuda(GN_CUDA_HOST_PARAMS(half)); 48 | template void gn_cuda(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); 49 | 50 | template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(half)); 51 | template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); 52 | 53 | } // namespace group_norm_v2 54 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(1024, 1280) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(1024, 1920) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(1024, 320) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(1024, 640) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(1024, 
960) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(256, 1280) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(256, 1920) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(256, 2560) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(256, 640) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(4096, 320) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(4096, 640) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(4096, 960) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(64, 1280) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | 4 | namespace group_norm_v2 { 5 | 6 | GN_CUDA_INST_DEFINE(64, 2560) 7 | 8 | } // namespace group_norm_v2 9 | -------------------------------------------------------------------------------- 
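Note on the instantiation files above: the per-shape `gn_cuda_inst_*.cu` translation units and the `DISPATCH_HW_C` macro in `gn_dispatch_hw_c.hpp` are not maintained by hand; they are emitted by `generate_gn_cuda_inst.py` shown earlier, so the set of supported `(HW, C)` shapes is exactly the `hw_c_list` table in that script, and `gn_cuda.cu` simply dispatches to whichever instantiation matches (unsupported shapes raise `std::invalid_argument`). A minimal sketch of adding one more shape, assuming the generator script is importable from its own directory — the extra `(128*128, 320)` entry below is purely hypothetical:

```python
# Hedged sketch: regenerate the group_norm_v2 instantiations with one extra
# (HW, C) shape. Assumes generate_gn_cuda_inst.py is on the import path
# (e.g. run from apex/contrib/csrc/group_norm_v2); the appended shape is
# only an illustration, not a shape the library ships with.
import generate_gn_cuda_inst as gen

gen.hw_c_list.append((128 * 128, 320))  # hypothetical new spatial size / channel count
gen.run()  # removes the old gn_cuda_inst_*.cu files, rewrites them and gn_dispatch_hw_c.hpp
```

After regeneration, the new `.cu` file still has to be picked up by the build for the dispatch to resolve at runtime.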
/apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define DISPATCH_HW_C(hw, c, HW, C, ...) [&] { \ 4 | if (hw == 64 && c == 1280) { constexpr int HW = 64, C = 1280; return __VA_ARGS__(); } \ 5 | if (hw == 64 && c == 2560) { constexpr int HW = 64, C = 2560; return __VA_ARGS__(); } \ 6 | if (hw == 256 && c == 640) { constexpr int HW = 256, C = 640; return __VA_ARGS__(); } \ 7 | if (hw == 256 && c == 1280) { constexpr int HW = 256, C = 1280; return __VA_ARGS__(); } \ 8 | if (hw == 256 && c == 1920) { constexpr int HW = 256, C = 1920; return __VA_ARGS__(); } \ 9 | if (hw == 256 && c == 2560) { constexpr int HW = 256, C = 2560; return __VA_ARGS__(); } \ 10 | if (hw == 1024 && c == 320) { constexpr int HW = 1024, C = 320; return __VA_ARGS__(); } \ 11 | if (hw == 1024 && c == 640) { constexpr int HW = 1024, C = 640; return __VA_ARGS__(); } \ 12 | if (hw == 1024 && c == 960) { constexpr int HW = 1024, C = 960; return __VA_ARGS__(); } \ 13 | if (hw == 1024 && c == 1280) { constexpr int HW = 1024, C = 1280; return __VA_ARGS__(); } \ 14 | if (hw == 1024 && c == 1920) { constexpr int HW = 1024, C = 1920; return __VA_ARGS__(); } \ 15 | if (hw == 4096 && c == 320) { constexpr int HW = 4096, C = 320; return __VA_ARGS__(); } \ 16 | if (hw == 4096 && c == 640) { constexpr int HW = 4096, C = 640; return __VA_ARGS__(); } \ 17 | if (hw == 4096 && c == 960) { constexpr int HW = 4096, C = 960; return __VA_ARGS__(); } \ 18 | throw std::invalid_argument("DISPATCH_HW_C " + std::to_string(hw) + " " + std::to_string(c)); \ 19 | }() 20 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_utils.cpp: -------------------------------------------------------------------------------- 1 | #include "gn_utils.hpp" 2 | 3 | #include 4 | #include 5 | 6 | 7 | namespace group_norm_v2 { 8 | 9 | cudaDeviceProp const &get_device_prop(int device_id) { 10 | static std::vector device_props; 11 | static std::once_flag flag; 12 | std::call_once(flag, [&] { 13 | int count; 14 | CUDA_CHECK(cudaGetDeviceCount(&count)); 15 | device_props.resize(count); 16 | for (int i = 0; i < count; i++) { 17 | CUDA_CHECK(cudaGetDeviceProperties(&device_props[i], i)); 18 | } 19 | }); 20 | return device_props.at(device_id); 21 | } 22 | 23 | } // namespace group_norm_v2 24 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "gn.hpp" 10 | 11 | 12 | // Definition of CUDA_CHECK macro 13 | #define CUDA_CHECK(call) \ 14 | do { \ 15 | cudaError_t err_ = call; \ 16 | if (err_ != cudaSuccess) { \ 17 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \ 18 | __FILE__, __LINE__, err_, cudaGetErrorString(err_), #call); \ 19 | exit(EXIT_FAILURE); \ 20 | } \ 21 | } while (0) 22 | 23 | 24 | #define GN_CUDA_HOST_PARAMS(T) T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only 25 | 26 | #define GN_BWD_CUDA_HOST_PARAMS(T) T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, bool silu, int64_t n, 
int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only 27 | 28 | #define GN_CUDA_HOST_ARGS out, x, w, b, eps, silu, n, hw, num_groups, channels_per_group, mean_var_out, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only 29 | 30 | #define GN_BWD_CUDA_HOST_ARGS grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, silu, n, hw, num_groups, channels_per_group, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only 31 | 32 | 33 | namespace group_norm_v2 { 34 | 35 | cudaDeviceProp const &get_device_prop(int device_id); 36 | 37 | #ifdef __CUDA_ARCH__ 38 | 39 | template 40 | __host__ __device__ inline int print_rank_0(char const *fmt, Ts &&...args) { 41 | if (threadIdx.x + threadIdx.y + threadIdx.z == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0) { 42 | return printf(fmt, std::forward(args)...); 43 | } 44 | return 0; 45 | } 46 | 47 | #endif 48 | 49 | } // namespace group_norm_v2 50 | -------------------------------------------------------------------------------- /apex/contrib/csrc/groupbn/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | #ifndef CUDA_UTILS_H 3 | #define CUDA_UTILS_H 4 | 5 | namespace at { 6 | namespace cuda { 7 | 8 | namespace utils { 9 | 10 | static inline int MaxSharedMemoryPerMultiprocessor(int device_id) { 11 | return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; 12 | } 13 | 14 | 15 | } 16 | } 17 | } 18 | 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #define NCCL_CHECK(cmd) \ 9 | do { \ 10 | ncclResult_t result = cmd; \ 11 | if (result != ncclSuccess) { \ 12 | std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ 13 | std::to_string(__LINE__) + ", " + \ 14 | std::string(ncclGetErrorString(result)); \ 15 | TORCH_CHECK(false, err); \ 16 | } \ 17 | } while (0) 18 | 19 | void *nccl_alloc_plug(size_t size, int device, void *stream) { 20 | void *ptr; 21 | NCCL_CHECK(ncclMemAlloc(&ptr, size)); 22 | return ptr; 23 | } 24 | 25 | void nccl_free_plug(void *ptr, std::size_t size, int device, void *stream) { 26 | NCCL_CHECK(ncclMemFree(ptr)); 27 | } 28 | 29 | std::shared_ptr nccl_allocator; 30 | 31 | void maybe_init() { 32 | if (!nccl_allocator) { 33 | nccl_allocator = std::make_shared< 34 | torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>( 35 | nccl_alloc_plug, nccl_free_plug); 36 | } 37 | } 38 | 39 | std::shared_ptr 40 | get_nccl_allocator() { 41 | maybe_init(); 42 | return nccl_allocator; 43 | } 44 | 45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 46 | m.def("get_nccl_allocator", []() { return get_nccl_allocator(); }); 47 | }; 48 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "nccl_p2p_cuda.cuh" 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", py::call_guard()); 21 | m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", py::call_guard()); 22 | m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace", py::call_guard()); 23 | m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", py::call_guard()); 24 | m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay", py::call_guard()); 25 | } 26 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #ifndef _nccl_p2p_h_ 20 | #define _nccl_p2p_h_ 21 | 22 | namespace apex { namespace contrib { namespace nccl_p2p { 23 | at::Tensor get_unique_nccl_id(int n); 24 | int init_nccl_comm( 25 | at::Tensor unique_nccl_id, 26 | int my_rank, 27 | int num_ranks 28 | ); 29 | void left_right_halo_exchange_inplace( 30 | int handle, 31 | int left_rank, 32 | int right_rank, 33 | at::Tensor left_output_halo, 34 | at::Tensor right_output_halo, 35 | at::Tensor left_input_halo, 36 | at::Tensor right_input_halo); 37 | std::vector left_right_halo_exchange( 38 | int handle, 39 | int left_rank, 40 | int right_rank, 41 | at::Tensor left_output_halo, 42 | at::Tensor right_output_halo); 43 | void add_delay(int delay); 44 | }}} 45 | #endif 46 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_version.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | // This file is used to check the version of NCCL detected. 
3 | #include 4 | 5 | #include 6 | 7 | std::tuple get_nccl_version(); 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("get_nccl_version", &get_nccl_version); 11 | } 12 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_version_check.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | // This file is used to check the version of NCCL detected. 4 | #include 5 | #include 6 | 7 | 8 | std::tuple get_nccl_version() { 9 | return { int(NCCL_MAJOR), int(NCCL_MINOR) }; 10 | } 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void multi_tensor_lamb_cuda( 4 | int chunk_size, 5 | at::Tensor noop_flag, 6 | std::vector> tensor_lists, 7 | const float lr, 8 | const float beta1, 9 | const float beta2, 10 | const float epsilon, 11 | const int step, 12 | const int bias_correction, 13 | const float weight_decay, 14 | const int grad_averaging, 15 | const int mode, 16 | const float global_grad_norm, 17 | const float max_grad_norm); 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", py::call_guard()); 21 | } 22 | -------------------------------------------------------------------------------- /apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void multi_tensor_fused_adam_cuda( 4 | int chunk_size, at::Tensor noop_flag, 5 | std::vector> tensor_lists, at::Tensor grad_scale, 6 | float lr, float beta1, float beta2, float eps, int step, int mode, 7 | int bias_correction, float weight_decay); 8 | 9 | void multi_tensor_fused_adam_capturable_cuda( 10 | int chunk_size, at::Tensor noop_flag, 11 | std::vector> tensor_lists, at::Tensor grad_scale, 12 | at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, 13 | int mode, int bias_correction, float weight_decay); 14 | 15 | void multi_tensor_fused_adam_with_param_remainders_cuda( 16 | int chunk_size, at::Tensor noop_flag, 17 | std::vector> tensor_lists, at::Tensor grad_scale, 18 | float lr, float beta1, float beta2, float eps, int step, int mode, 19 | int bias_correction, float weight_decay); 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, 23 | "CUDA kernels for multi-tensor Adam, " 24 | "with param copy", 25 | py::call_guard()); 26 | m.def("multi_tensor_fused_adam_capturable", 27 | &multi_tensor_fused_adam_capturable_cuda, 28 | "CUDA kernels for multi-tensor Adam, " 29 | "with param copy, capturable for CUDA graph", 30 | py::call_guard()); 31 | m.def("multi_tensor_fused_adam_with_param_remainders", 32 | &multi_tensor_fused_adam_with_param_remainders_cuda, 33 | "CUDA kernel for multi-tensor Adam, " 34 | "with stored param remainders and param copy", 35 | py::call_guard()); 36 | } 37 | -------------------------------------------------------------------------------- /apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void multi_tensor_lamb_compute_update_term_cuda( 4 | int chunk_size, 5 | at::Tensor noop_flag, 6 | 
std::vector> tensor_lists, 7 | at::Tensor per_tensor_beta1, 8 | at::Tensor per_tensor_beta2, 9 | at::Tensor per_tensor_beta3, 10 | at::Tensor per_tensor_bias_correction, 11 | at::Tensor step, 12 | at::Tensor per_tensor_epsilon, 13 | const int mode, 14 | at::Tensor per_tensor_decay, 15 | at::Tensor global_scale, 16 | at::Tensor global_grad_norm, 17 | const float max_grad_norm); 18 | 19 | void multi_tensor_lamb_update_weights_cuda( 20 | int chunk_size, 21 | at::Tensor noop_flag, 22 | std::vector> tensor_lists, 23 | at::Tensor per_tensor_param_norm, 24 | at::Tensor per_tensor_update_norm, 25 | at::Tensor update_norm_offset, 26 | at::Tensor learning_rate, 27 | at::Tensor per_tensor_decay, 28 | at::Tensor global_grad_norm, 29 | bool use_nvlamb); 30 | 31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 32 | m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, 33 | "Computes update term for LAMB optimizer", py::call_guard()); 34 | m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda, 35 | "Applies update term for LAMB optimizer", py::call_guard()); 36 | } 37 | -------------------------------------------------------------------------------- /apex/contrib/csrc/peer_memory/peer_memory.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "peer_memory_cuda.cuh" 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", py::call_guard()); 21 | m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", py::call_guard()); 22 | m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard()); 23 | m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address", py::call_guard()); 24 | m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", py::call_guard()); 25 | m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", py::call_guard()); 26 | m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", py::call_guard()); 27 | m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", py::call_guard()); 28 | m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d", py::call_guard()); 29 | } 30 | -------------------------------------------------------------------------------- /apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #ifndef _peer_memory_h_ 20 | #define _peer_memory_h_ 21 | 22 | namespace apex { namespace contrib { namespace peer_memory { 23 | int64_t allocate_raw(int64_t size); 24 | void free_raw(int64_t raw); 25 | void zero(int64_t raw, int64_t size); 26 | at::Tensor get_raw_ipc_address(int64_t raw); 27 | std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); 28 | at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last); 29 | at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last); 30 | at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last); 31 | void push_pull_halos_1d( 32 | bool diagnostics, 33 | bool explicit_nhwc, 34 | int numSM, // number of SMs to use 35 | int peer_rank, // rank in spatial parallel group 36 | bool top_zero, // if top halo should be zeroed 37 | at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) 38 | at::Tensor top_inp_transfer, // top input transfer buffer (in local peer memory) 39 | at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) 40 | at::Tensor top_inp_halo, // top input halo buffer (in local device memory, sent to top neighbor) 41 | bool btm_zero, // if btm halo should be zeroed 42 | at::Tensor btm_out_halo, // btm output halo buffer (in local device memory, received from btm neighbor) 43 | at::Tensor btm_inp_transfer, // btm input transfer buffer (in local peer memory) 44 | at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) 45 | at::Tensor btm_inp_halo // btm input halo buffer (in local device memory, sent to btm neighbor) 46 | ); 47 | } } } 48 | #endif 49 | -------------------------------------------------------------------------------- /apex/contrib/csrc/transducer/transducer_joint.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | std::vector transducer_joint_cuda_forward( 9 | torch::Tensor f, 10 | torch::Tensor g, 11 | torch::Tensor fLen, 12 | torch::Tensor gLen, 13 | torch::Tensor batchOffset, 14 | int64_t packedBatch, 15 | int opt, 16 | bool packOutput, 17 | bool relu, 18 | bool dropout, 19 | float dropoutProb, 20 | int tileSize); 21 | 22 | 23 | std::vector transducer_joint_cuda_backward( 24 | std::vector in, 25 | torch::Tensor fLen, 26 | torch::Tensor gLen, 27 | torch::Tensor batchOffset, 28 | int maxFLen, 29 | int maxGLen, 30 | bool packOutput, 31 | float scale); 32 | 33 | std::vector transducer_joint_forward( 34 | torch::Tensor f, 35 | torch::Tensor g, 36 | torch::Tensor fLen, 37 | torch::Tensor gLen, 38 | torch::Tensor batchOffset, 39 | int64_t packedBatch, 40 | int opt, 41 | bool packOutput, 42 | bool relu, 43 | bool dropout, 44 | float dropoutProb, 45 | int 
tileSize) { 46 | CHECK_INPUT(f); 47 | CHECK_INPUT(g); 48 | CHECK_INPUT(fLen); 49 | CHECK_INPUT(gLen); 50 | if (packOutput) 51 | CHECK_INPUT(batchOffset); 52 | return transducer_joint_cuda_forward( 53 | f, 54 | g, 55 | fLen, 56 | gLen, 57 | batchOffset, 58 | packedBatch, 59 | opt, 60 | packOutput, 61 | relu, 62 | dropout, 63 | dropoutProb, 64 | tileSize); 65 | } 66 | 67 | std::vector transducer_joint_backward( 68 | std::vector in, 69 | torch::Tensor fLen, 70 | torch::Tensor gLen, 71 | torch::Tensor batchOffset, 72 | int maxFLen, 73 | int maxGLen, 74 | bool packOutput, 75 | float scale) { 76 | for (auto t : in){ 77 | CHECK_INPUT(t); 78 | } 79 | CHECK_INPUT(fLen); 80 | CHECK_INPUT(gLen); 81 | if (packOutput) 82 | CHECK_INPUT(batchOffset); 83 | return transducer_joint_cuda_backward( 84 | in, 85 | fLen, 86 | gLen, 87 | batchOffset, 88 | maxFLen, 89 | maxGLen, 90 | packOutput, 91 | scale); 92 | } 93 | 94 | 95 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 96 | m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", py::call_guard()); 97 | m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", py::call_guard()); 98 | } 99 | -------------------------------------------------------------------------------- /apex/contrib/csrc/xentropy/interface.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | // CUDA forward declarations 6 | 7 | std::vector softmax_xentropy_cuda( 8 | const at::Tensor &input, 9 | const at::Tensor &labels, 10 | const float smoothing, 11 | const bool half_to_float); 12 | 13 | at::Tensor softmax_xentropy_backward_cuda( 14 | const at::Tensor &grad_loss, 15 | const at::Tensor &logits, 16 | const at::Tensor &max_log_sum_exp, 17 | const at::Tensor &labels, 18 | const float smoothing); 19 | 20 | // C++ interface 21 | 22 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 23 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 24 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 25 | 26 | std::vector softmax_xentropy_forward( 27 | const at::Tensor &input, 28 | const at::Tensor &labels, 29 | const float smoothing, 30 | const bool half_to_float) { 31 | CHECK_CUDA(input); 32 | CHECK_INPUT(labels); 33 | 34 | return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); 35 | } 36 | 37 | at::Tensor softmax_xentropy_backward( 38 | const at::Tensor &grad_loss, 39 | const at::Tensor &logits, 40 | const at::Tensor &max_log_sum_exp, 41 | const at::Tensor &labels, 42 | const float smoothing) { 43 | CHECK_CUDA(grad_loss); 44 | CHECK_CUDA(logits); 45 | CHECK_INPUT(max_log_sum_exp); 46 | CHECK_INPUT(labels); 47 | 48 | return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); 49 | } 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::call_guard()); 53 | m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::call_guard()); 54 | // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables 55 | py::object version = py::cast( 56 | #ifdef XENTROPY_VER 57 | XENTROPY_VER 58 | #else 59 | std::string{} 60 | #endif 61 | ); 62 | m.attr("__version__") = version; 63 | } 64 | -------------------------------------------------------------------------------- 
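The xentropy interface above exposes a two-call contract: `forward` returns the per-example losses together with a `max_log_sum_exp` reduction tensor, and that tensor must be passed back into `backward` along with the original logits and labels. A minimal usage sketch, assuming the extension was built under the module name `xentropy_cuda` (the actual name is set by the build, not by this file) and that `labels` holds int64 class indices; in practice this is normally reached through apex's Python wrapper rather than called directly:

```python
# Hedged sketch of driving the fused softmax cross entropy kernels directly.
# The module name `xentropy_cuda` and the int64 label layout are assumptions;
# the argument order follows the declarations in interface.cpp above.
import torch
import xentropy_cuda

logits = torch.randn(32, 1000, device="cuda", dtype=torch.float16)
labels = torch.randint(0, 1000, (32,), device="cuda")

smoothing, half_to_float = 0.1, True
losses, max_log_sum_exp = xentropy_cuda.forward(logits, labels, smoothing, half_to_float)

grad_loss = torch.ones_like(losses)  # upstream gradient, one value per example
grad_logits = xentropy_cuda.backward(grad_loss, logits, max_log_sum_exp, labels, smoothing)
```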
/apex/contrib/cudnn_gbn/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_norm import GroupBatchNorm2d -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/benchmark_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timeit 3 | import torch 4 | import apex.contrib.gpu_direct_storage as gds 5 | 6 | def run_benchmark_torch_load(): 7 | sizes = [2 ** i for i in range(16, 28)] 8 | for size in sizes: 9 | torch.cuda.empty_cache() 10 | s = torch.cuda.Stream() 11 | x = torch.empty(size, device = "cuda") 12 | y = torch.linspace(0, 1, size, device = "cuda") 13 | torch.save(y, f"{size}.data") 14 | 15 | # warmup 16 | torch.cuda.synchronize() 17 | for _ in range(10): 18 | x = torch.load(f"{size}.data") 19 | 20 | torch.cuda.synchronize() 21 | start_time = timeit.default_timer() 22 | for _ in range(10): 23 | x = torch.load(f"{size}.data") 24 | torch.cuda.synchronize() 25 | end_time = timeit.default_timer() 26 | print(f"torch.load: size = {size}, {end_time - start_time}") 27 | assert(torch.allclose(x, y)) 28 | 29 | def run_benchmark(func): 30 | sizes = [2 ** i for i in range(16, 28)] 31 | for size in sizes: 32 | torch.cuda.empty_cache() 33 | s = torch.cuda.Stream() 34 | x = torch.empty(size, device = "cuda") 35 | y = torch.linspace(0, 1, size, device = "cuda") 36 | 37 | with gds.GDSFile(f"{size}.data", "w") as f: 38 | f.save_data(y) 39 | 40 | # warmup 41 | torch.cuda.synchronize() 42 | for _ in range(10): 43 | func(x, f"{size}.data") 44 | 45 | torch.cuda.synchronize() 46 | start_time = timeit.default_timer() 47 | for _ in range(10): 48 | func(x, f"{size}.data") 49 | torch.cuda.synchronize() 50 | end_time = timeit.default_timer() 51 | print(f"{func.__name__}: size = {size}, {end_time - start_time}") 52 | assert(torch.allclose(x, y)) 53 | 54 | def load_data_yes_gds(tensor, filename): 55 | with gds.GDSFile(filename, "r") as f: 56 | f.load_data(tensor) 57 | 58 | def load_data_no_gds(tensor, filename): 59 | with gds.GDSFile(filename, "rn") as f: 60 | f.load_data_no_gds(tensor) 61 | 62 | if __name__ == '__main__': 63 | run_benchmark_torch_load() 64 | run_benchmark(load_data_yes_gds) 65 | run_benchmark(load_data_no_gds) 66 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/benchmark_save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timeit 3 | import torch 4 | import apex.contrib.gpu_direct_storage as gds 5 | 6 | def run_benchmark(func): 7 | sizes = [2 ** i for i in range(16, 28)] 8 | for size in sizes: 9 | torch.cuda.empty_cache() 10 | s = torch.cuda.Stream() 11 | x = torch.linspace(0, 1, size, device = "cuda") 12 | 13 | # warmup 14 | torch.cuda.synchronize() 15 | for _ in range(10): 16 | func(x, f"{size}.data") 17 | os.remove(f"{size}.data") 18 | 19 | torch.cuda.synchronize() 20 | start_time = timeit.default_timer() 21 | for _ in range(10): 22 | func(x, f"{size}.data") 23 | os.remove(f"{size}.data") 24 | torch.cuda.synchronize() 25 | end_time = timeit.default_timer() 26 | print(f"{func.__name__}: size = {size}, {end_time - start_time}") 27 | 28 | def save_data_yes_gds(tensor, filename): 29 | with gds.GDSFile(filename, "w") as f: 30 | f.save_data(tensor) 31 | 32 | def save_data_no_gds(tensor, filename): 33 | with gds.GDSFile(filename, "wn") as f: 34 | f.save_data_no_gds(tensor) 35 | 36 
| if __name__ == '__main__': 37 | run_benchmark(torch.save) 38 | run_benchmark(save_data_yes_gds) 39 | run_benchmark(save_data_no_gds) 40 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/example_load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.gpu_direct_storage as gds 3 | 4 | for size in [128, 1024, 8192]: 5 | x = torch.empty(size, device = "cuda") 6 | with gds.GDSFile(f"{size}.data", "r") as f: 7 | f.load_data(x) 8 | xx = torch.linspace(0, 1, size, device = "cuda") 9 | assert(torch.allclose(x, xx)) 10 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/example_save.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.gpu_direct_storage as gds 3 | 4 | for size in [128, 1024, 8192]: 5 | x = torch.linspace(0, 1, size, device = "cuda") 6 | with gds.GDSFile(f"{size}.data", "w") as f: 7 | f.save_data(x) 8 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/allreduce.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import apex.contrib.nccl_allocator as nccl_allocator 5 | 6 | assert os.getenv("WORLD_SIZE") is not None, "Please use: torchrun --nproc-per-node=8 allreduce.py" 7 | 8 | rank = int(os.getenv("RANK")) 9 | local_rank = int(os.getenv("LOCAL_RANK")) 10 | world_size = int(os.getenv("WORLD_SIZE")) 11 | 12 | nccl_allocator.init() 13 | 14 | torch.cuda.set_device(local_rank) 15 | dist.init_process_group(backend="nccl") 16 | pool = nccl_allocator.create_nccl_mem_pool() 17 | with nccl_allocator.nccl_mem(pool): 18 | a = torch.ones(1024 * 1024 * 2, device="cuda") 19 | dist.all_reduce(a) 20 | 21 | torch.cuda.synchronize() 22 | 23 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import apex.contrib.nccl_allocator as nccl_allocator 4 | from pynvml.smi import nvidia_smi 5 | 6 | def set_device(dev): 7 | import ctypes 8 | handle = ctypes.CDLL("libcudart.so") 9 | result = handle.cudaSetDevice(ctypes.c_int(dev)) 10 | assert result == 0 11 | 12 | def print_used_mem(string, nvsmi, device_id = 0): 13 | print(f"{string}:", nvsmi.DeviceQuery('memory.used')['gpu'][device_id]) 14 | 15 | nccl_allocator.init() 16 | nrep = 6 17 | nccl_mem = [] 18 | 19 | set_device(0) 20 | nvsmi = nvidia_smi.getInstance() 21 | 22 | print_used_mem("", nvsmi) 23 | 24 | pool = nccl_allocator.create_nccl_mem_pool() 25 | with nccl_allocator.nccl_mem(pool): 26 | for i in range(nrep): 27 | out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB 28 | nccl_mem.append(out) 29 | 30 | print_used_mem("after nccl alloc (+>=2400)", nvsmi) # + 2400+ MB 31 | 32 | cudart_mem = [] 33 | for i in range(nrep): 34 | out = torch.randn(1024 * 1024 * 50 ).cuda() # == 200 MB 35 | cudart_mem.append(out) 36 | 37 | print_used_mem("after cudart alloc (+1200)", nvsmi) 38 | 39 | del cudart_mem 40 | torch.cuda.empty_cache() 41 | torch.cuda.empty_cache() 42 | print_used_mem("release cudart mem (-1200)", nvsmi) # - 1200 MB 43 | 44 | del nccl_mem 45 | nccl_mem2 = [] 46 | with 
nccl_allocator.nccl_mem(pool): 47 | for i in range(nrep): 48 | out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB 49 | nccl_mem2.append(out) 50 | print_used_mem("reuse nccl cache (same)", nvsmi) # + 0 MB 51 | del nccl_mem2 52 | torch.cuda.empty_cache() 53 | print_used_mem("release nccl_mem (-2400)", nvsmi) # - 2400 MB 54 | 55 | torch.cuda.empty_cache() 56 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/change_cuda_allocator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.nccl_allocator as nccl_allocator 3 | 4 | nccl_allocator.init() 5 | nrep = 6 6 | pool = nccl_allocator.create_nccl_mem_pool() 7 | with nccl_allocator.nccl_mem(pool): 8 | for i in range(nrep): 9 | out = torch.randn(1024).cuda() 10 | 11 | for i in range(nrep): 12 | out = torch.randn(1024).cuda() 13 | 14 | torch.cuda.empty_cache() 15 | torch.cuda.empty_cache() 16 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/toy_ddp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.distributed as dist 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | 8 | import apex.contrib.nccl_allocator as nccl_allocator 9 | 10 | assert os.getenv("WORLD_SIZE") is not None, "Please use: torchrun --nproc-per-node=8 toy_ddp.py" 11 | 12 | class ToyModel(nn.Module): 13 | def __init__(self): 14 | super(ToyModel, self).__init__() 15 | self.net1 = nn.Linear(10, 10) 16 | self.relu = nn.ReLU() 17 | self.net2 = nn.Linear(10, 5) 18 | 19 | def forward(self, x): 20 | return self.net2(self.relu(self.net1(x))) 21 | 22 | 23 | rank = int(os.getenv("RANK")) 24 | local_rank = int(os.getenv("LOCAL_RANK")) 25 | world_size = int(os.getenv("WORLD_SIZE")) 26 | 27 | nccl_allocator.init() 28 | 29 | torch.cuda.set_device(local_rank) 30 | dist.init_process_group(backend="nccl") 31 | 32 | device = torch.device("cuda", local_rank) 33 | model = ToyModel().to(device) 34 | ddp_model = DDP(model, device_ids=[rank]) 35 | loss_fn = nn.MSELoss() 36 | optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) 37 | 38 | data_ptrs = [] 39 | pool = nccl_allocator.create_nccl_mem_pool() 40 | with nccl_allocator.nccl_mem(pool): 41 | for param in ddp_model.parameters(): 42 | param.grad = torch.empty_like(param) 43 | data_ptrs.append(param.grad.data_ptr()) 44 | 45 | for _ in range(10): 46 | optimizer.zero_grad(set_to_none=False) 47 | outputs = ddp_model(torch.randn(20, 10)) 48 | labels = torch.randn(20, 5).to(rank) 49 | loss_fn(outputs, labels).backward() 50 | optimizer.step() 51 | 52 | for data_ptr, param in zip(data_ptrs, ddp_model.parameters()): 53 | assert(data_ptr == param.grad.data_ptr()) 54 | dist.destroy_process_group() 55 | -------------------------------------------------------------------------------- /apex/contrib/fmha/__init__.py: -------------------------------------------------------------------------------- 1 | from .fmha import FMHAFun 2 | -------------------------------------------------------------------------------- /apex/contrib/focal_loss/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | import focal_loss_cuda 4 | from .focal_loss import focal_loss 5 | del torch 6 | del focal_loss_cuda 7 | del focal_loss 8 | except ImportError 
as err: 9 | print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available") 10 | -------------------------------------------------------------------------------- /apex/contrib/focal_loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import focal_loss_cuda 4 | 5 | 6 | class FocalLoss(torch.autograd.Function): 7 | @staticmethod 8 | def forward( 9 | ctx, 10 | cls_output, 11 | cls_targets_at_level, 12 | num_positives_sum, 13 | num_real_classes, 14 | alpha, 15 | gamma, 16 | label_smoothing=0.0, 17 | ): 18 | loss, partial_grad = focal_loss_cuda.forward( 19 | cls_output, 20 | cls_targets_at_level, 21 | num_positives_sum, 22 | num_real_classes, 23 | alpha, 24 | gamma, 25 | label_smoothing, 26 | ) 27 | 28 | ctx.save_for_backward(partial_grad, num_positives_sum) 29 | return loss 30 | 31 | @staticmethod 32 | def backward(ctx, grad_loss): 33 | partial_grad, num_positives_sum = ctx.saved_tensors 34 | 35 | # The backward kernel is actually in-place to save memory space, 36 | # partial_grad and grad_input are the same tensor. 37 | grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum) 38 | 39 | return grad_input, None, None, None, None, None, None 40 | 41 | 42 | def focal_loss( 43 | cls_output: torch.Tensor, 44 | cls_targets_at_level: torch.Tensor, 45 | num_positive_sum: torch.Tensor, 46 | num_real_classes: int, 47 | alpha: float, 48 | gamma: float, 49 | label_smoothing: float = 0.0, 50 | ) -> torch.Tensor: 51 | """Fused focal loss function.""" 52 | return FocalLoss.apply( 53 | cls_output, 54 | cls_targets_at_level, 55 | num_positive_sum, 56 | num_real_classes, 57 | alpha, 58 | gamma, 59 | label_smoothing, 60 | ) 61 | -------------------------------------------------------------------------------- /apex/contrib/gpu_direct_storage/README.md: -------------------------------------------------------------------------------- 1 | # APEX GPUDirect Storage 2 | 3 | This module aims to add a PyTorch extension for [GPUDirect Storage](https://developer.nvidia.com/blog/gpudirect-storage/) (GDS) support through utilizing the [cuFile](https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html) library. 4 | 5 | # Build command 6 | ``` 7 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--gpu_direct_storage" ./ 8 | ``` 9 | 10 | Alternatively: 11 | ``` 12 | python setup.py install --gpu_direct_storage 13 | ``` 14 | 15 | Check installation: 16 | ``` 17 | python -c "import torch; import apex.contrib.gpu_direct_storage" 18 | ``` 19 | -------------------------------------------------------------------------------- /apex/contrib/gpu_direct_storage/__init__.py: -------------------------------------------------------------------------------- 1 | from _apex_gpu_direct_storage import _GDSFile 2 | from contextlib import contextmanager 3 | 4 | @contextmanager 5 | def GDSFile(filename, mode): 6 | assert type(filename) == str 7 | assert type(mode) == str 8 | try: 9 | from apex import deprecated_warning 10 | 11 | deprecated_warning( 12 | "`gpu_direct_storage.GDSFile` is deprecated and will be removed in September 2025. " 13 | "We encourage you to use `torch.cuda.gds` module of PyTorch as a replacement. 
" 14 | "Its documentation is available at https://docs.pytorch.org/docs/stable/cuda.html#gpudirect-storage-prototype" 15 | ) 16 | file_handle = _GDSFile(filename, mode) 17 | yield file_handle 18 | finally: 19 | file_handle.close() 20 | del file_handle 21 | -------------------------------------------------------------------------------- /apex/contrib/group_norm/__init__.py: -------------------------------------------------------------------------------- 1 | from .group_norm import * 2 | -------------------------------------------------------------------------------- /apex/contrib/groupbn/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | import bnp 4 | from .batch_norm import BatchNorm2d_NHWC 5 | del torch 6 | del bnp 7 | del batch_norm 8 | except ImportError as err: 9 | print("apex was installed without --bnp flag, contrib.groupbn is not available") 10 | -------------------------------------------------------------------------------- /apex/contrib/index_mul_2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .index_mul_2d import index_mul_2d 2 | -------------------------------------------------------------------------------- /apex/contrib/layer_norm/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer_norm import FastLayerNorm 2 | -------------------------------------------------------------------------------- /apex/contrib/layer_norm/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import init 3 | 4 | from apex._autocast_utils import _cast_if_autocast_enabled 5 | import fast_layer_norm 6 | 7 | 8 | class FastLayerNormFN(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, x, gamma, beta, epsilon, memory_efficient=False): 11 | ctx.x_shape = x.shape 12 | ctx.memory_efficient = memory_efficient 13 | 14 | x = x.contiguous() 15 | gamma = gamma.contiguous() 16 | beta = beta.contiguous() 17 | hidden_size = gamma.numel() 18 | xmat = x.view((-1, hidden_size)) 19 | ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) 20 | if ctx.memory_efficient: 21 | ctx.save_for_backward(ymat, gamma, None, rsigma, beta) 22 | else: 23 | ctx.save_for_backward(xmat, gamma, mu, rsigma, None) 24 | return ymat.view(x.shape) 25 | 26 | @staticmethod 27 | def backward(ctx, dy): 28 | # assert dy.is_contiguous() 29 | dy = dy.contiguous() # this happens! 
30 | x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors 31 | dymat = dy.view(x_or_y_mat.shape) 32 | dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient) 33 | dx = dxmat.view(ctx.x_shape) 34 | return dx, dgamma, dbeta, None, None 35 | 36 | 37 | def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient): 38 | args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient) 39 | with torch.amp.autocast('cuda', enabled=False): 40 | return FastLayerNormFN.apply(*args) 41 | 42 | 43 | class FastLayerNorm(torch.nn.Module): 44 | def __init__(self, hidden_size, eps=1e-5, memory_efficient=False): 45 | super().__init__() 46 | self.epsilon = eps 47 | self.memory_efficient = memory_efficient 48 | self.weight = torch.nn.Parameter(torch.empty(hidden_size)) 49 | self.bias = torch.nn.Parameter(torch.empty(hidden_size)) 50 | self.reset_parameters() 51 | 52 | def reset_parameters(self): 53 | init.ones_(self.weight) 54 | init.zeros_(self.bias) 55 | 56 | def forward(self, x): 57 | return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient) 58 | -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/MHA_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/multihead_attn/MHA_bwd.png -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/MHA_fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/multihead_attn/MHA_fwd.png -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/README.md: -------------------------------------------------------------------------------- 1 | # Fast Multihead Attention 2 | 3 | This implementation has two main features : 4 | * A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes. 5 | * The removal of all copies and transposes found in standard implementations of Multihead Attention. 6 | 7 | | | Python Version | C++ Version | 8 | | :----------------------------------------- | :------------: | :---------: | 9 | | Layer Norm and Residual Add Variant | X | X | 10 | | Includes Linear Biases | X | | 11 | | Reduces CPU Overheads | | X | 12 | | Fuses masking with Softmax | | X | 13 | | Removes Transposes and Copies | X | X | 14 | | Includes Self and Encoder/Decoder Variants | X | X | 15 | 16 | ## How to Instantiate 17 | 18 | `SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)` 19 | `EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)` 20 | 21 | `impl` has two options: 22 | * `fast` uses C++ Version 23 | * `default` uses Python Version 24 | 25 | ## Instructions to build on Linux 26 | 27 | ``` 28 | $ git clone https://github.com/NVIDIA/apex 29 | $ cd apex 30 | $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./ 31 | ``` 32 | ## Try Performance Tests Yourself! 33 | Perf test script is found here! 
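As a quick illustration of the constructor and forward signatures described above, a minimal sketch (hidden size, head count, and shapes are illustrative; `impl='fast'` assumes apex was built with `--fast_multihead_attn` and a CUDA device is available):

```
import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn

# 1024-dim model, 16 heads, fused C++/CUDA path
attn = SelfMultiheadAttn(1024, 16, dropout=0.1, bias=False, include_norm_add=False, impl='fast')
attn.cuda().half()
attn.reset_parameters()

# inputs are (seq_len, batch, hidden), fp16 on GPU
q = torch.randn(80, 10, 1024, dtype=torch.float16, device='cuda')
out, _ = attn(q, q, q, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True)
```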
34 | ``` 35 | cd contrib/examples/multihead_attn 36 | ``` 37 | #### Fast Multihead Attention 38 | ``` 39 | python perf_test_multihead_attn.py --ref 40 | ``` 41 | #### Fast Multihead Attention with C++ Implementation 42 | ``` 43 | python perf_test_multihead_attn.py 44 | ``` 45 | #### Compare with `torch.nn.MultiheadAttn` 46 | ``` 47 | python perf_test_multihead_attn.py --native 48 | ``` 49 | #### Test your own range! 50 | ``` 51 | python perf_test_multihead_attn.py --seq-length 64 --num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5 52 | ``` 53 | 54 | ## Performance Comparisons 55 | 56 | * Performance was measured with 64 token sequence lengths on an NVIDIA TitanV card. 57 | * Time is measured across multiple layers to simulate an in model scenario. 58 | 59 | ![Multihead Attention Forward](MHA_fwd.png) 60 | ![Multihead Attention Backward](MHA_bwd.png) 61 | -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/__init__.py: -------------------------------------------------------------------------------- 1 | from .self_multihead_attn import SelfMultiheadAttn 2 | from .encdec_multihead_attn import EncdecMultiheadAttn 3 | from .mask_softmax_dropout_func import fast_mask_softmax_dropout_func 4 | -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/mask_softmax_dropout_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import fast_multihead_attn 4 | 5 | 6 | class MaskSoftmaxDropout(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob): 9 | heads_t = torch.tensor([heads]) 10 | dropout_prob_t = torch.tensor([dropout_prob]) 11 | null_tensor = torch.tensor([]) 12 | use_mask = pad_mask is not None 13 | use_mask_t = torch.tensor([use_mask]) 14 | mask_additive_t = torch.tensor([mask_additive]) 15 | 16 | if mask_additive: 17 | dropout_results, dropout_mask, softmax_results = fast_multihead_attn.additive_mask_softmax_dropout_forward( 18 | use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob 19 | ) 20 | # fast_additive_mask_softmax_dropout.forward( \ 21 | else: 22 | dropout_results, dropout_mask, softmax_results = fast_multihead_attn.mask_softmax_dropout_forward( 23 | use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob 24 | ) 25 | # fast_mask_softmax_dropout.forward( \ 26 | 27 | ctx.save_for_backward( 28 | use_mask_t, 29 | heads_t, 30 | softmax_results, 31 | dropout_mask, 32 | pad_mask if use_mask else null_tensor, 33 | mask_additive_t, 34 | dropout_prob_t, 35 | ) 36 | 37 | return dropout_results.detach() 38 | 39 | @staticmethod 40 | def backward(ctx, output_grads): 41 | ( 42 | use_mask_t, 43 | heads_t, 44 | softmax_results, 45 | dropout_mask, 46 | pad_mask, 47 | mask_additive_t, 48 | dropout_prob_t, 49 | ) = ctx.saved_tensors 50 | 51 | if mask_additive_t[0]: 52 | input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward( 53 | use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, dropout_prob_t[0] 54 | ) 55 | # fast_additive_mask_softmax_dropout.backward( \ 56 | else: 57 | input_grads = fast_multihead_attn.mask_softmax_dropout_backward( 58 | use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, pad_mask, dropout_prob_t[0] 59 | ) 60 | # fast_mask_softmax_dropout.backward( \ 61 | return None, None, input_grads, None, None, None 62 
| 63 | 64 | fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply 65 | -------------------------------------------------------------------------------- /apex/contrib/nccl_allocator/README.md: -------------------------------------------------------------------------------- 1 | ## General information 2 | 3 | `nccl_allocator` is a module that enables `ncclMemAlloc`[^1] to be used within PyTorch for faster NCCL NVLS collective communications. 4 | It is mainly based on `CUDAPluggableAllocator`. 5 | The context manager `nccl_allocator.nccl_mem(enabled=True)` is used as a switch between `cudaMalloc` and `ncclMemAlloc` (with `enabled=True` allocations go through `ncclMemAlloc`; with `enabled=False` it falls back to `cudaMalloc`). 6 | 7 | [^1]: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html 8 | 9 | ### Example usage: 10 | 11 | Here is a minimalistic example: 12 | 13 | ``` 14 | import os 15 | import torch 16 | import torch.distributed as dist 17 | import apex.contrib.nccl_allocator as nccl_allocator 18 | 19 | rank = int(os.getenv("RANK")) 20 | local_rank = int(os.getenv("LOCAL_RANK")) 21 | world_size = int(os.getenv("WORLD_SIZE")) 22 | 23 | nccl_allocator.init() 24 | 25 | torch.cuda.set_device(local_rank) 26 | dist.init_process_group(backend="nccl") 27 | pool = nccl_allocator.create_nccl_mem_pool() 28 | with nccl_allocator.nccl_mem(pool): 29 | a = torch.ones(1024 * 1024 * 2, device="cuda") 30 | dist.all_reduce(a) 31 | 32 | torch.cuda.synchronize() 33 | ``` 34 | 35 | Please visit `apex/contrib/examples/nccl_allocator` for more examples. 36 | 37 | 38 | ### IMPORTANT 39 | 40 | There are several strict requirements: 41 | - PyTorch must include PR [#112850](https://github.com/pytorch/pytorch/pull/112850) 42 | - NCCL v2.19.4 and newer 43 | - NCCL NVLS requires CUDA Driver 530 and newer (tested on 535) 44 | 45 | -------------------------------------------------------------------------------- /apex/contrib/nccl_allocator/__init__.py: -------------------------------------------------------------------------------- 1 | from .nccl_allocator import * 2 | -------------------------------------------------------------------------------- /apex/contrib/nccl_allocator/nccl_allocator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import _apex_nccl_allocator 4 | 5 | from contextlib import nullcontext 6 | 7 | 8 | __all__ = ["init", "nccl_mem", "create_nccl_mem_pool"] 9 | 10 | 11 | def create_nccl_mem_pool(): 12 | _allocator = _apex_nccl_allocator.get_nccl_allocator() 13 | _pool = torch.cuda.MemPool(_allocator) 14 | return _pool 15 | 16 | 17 | def init() -> None: 18 | os.environ["NCCL_NVLS_ENABLE"] = "1" 19 | os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "0" 20 | 21 | 22 | class nccl_mem: 23 | def __init__(self, pool, enabled = True, device = None, group = None): 24 | self.device = None 25 | self.group = None 26 | self.mem_context = None 27 | self.pool = pool 28 | 29 | if enabled: 30 | if device is None: 31 | self.device = torch.device("cuda", torch.cuda.current_device()) 32 | elif isinstance(device, int): 33 | self.device = torch.device("cuda", device) 34 | elif isinstance(device, str): 35 | assert "cuda" in device, "only cuda devices are supported" 36 | self.device = torch.device(device) 37 | 38 | if group is None: 39 | self.group = torch.distributed.distributed_c10d._get_default_group() 40 | else: 41 | self.group = group 42 | 43 | self.mem_context = torch.cuda.use_mem_pool(self.pool) 44 | else: 45 | self.mem_context = nullcontext() 46 | 47 | def __enter__(self): 48 | self.mem_context.__enter__() 49 | if
self.group is not None: 50 | backend = self.group._get_backend(self.device) 51 | try: 52 | backend.deregister_mem_pool(self.pool) 53 | except RuntimeError: 54 | pass 55 | 56 | def __exit__(self, *args): 57 | if self.group is not None: 58 | backend = self.group._get_backend(self.device) 59 | try: 60 | backend.register_mem_pool(self.pool) 61 | except RuntimeError: 62 | pass 63 | self.mem_context.__exit__(*args) 64 | -------------------------------------------------------------------------------- /apex/contrib/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .fp16_optimizer import FP16_Optimizer 2 | from .fused_adam import FusedAdam 3 | from .fused_lamb import FusedLAMB 4 | -------------------------------------------------------------------------------- /apex/contrib/peer_memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .peer_memory import PeerMemoryPool 2 | from .peer_halo_exchanger_1d import PeerHaloExchanger1d 3 | 4 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 2 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparse_masklib import create_mask 2 | from .asp import ASP 3 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/permutation_search_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .call_permutation_search_kernels import accelerated_search_for_good_permutation 2 | from .permutation_utilities import sum_after_2_to_4 -------------------------------------------------------------------------------- /apex/contrib/sparsity/permutation_tests/runtime_table.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUTDIR="results/runtime_logs" 4 | mkdir -p $OUTDIR 5 | 6 | R1000=random,1000 7 | CS=channel_swap,0 8 | CS_100=channel_swap,100 9 | OSG2=optimize_stripe_groups,8,0 10 | OSG2_100=optimize_stripe_groups,8,100 11 | OSG2_1000=optimize_stripe_groups,8,1000 12 | OSG3=optimize_stripe_groups,12,0 13 | OSG3_100=optimize_stripe_groups,12,100 14 | OSG3_1000=optimize_stripe_groups,12,1000 15 | 16 | for cols in "32" "64" "128" "256"; do 17 | echo "$cols x $cols" 18 | python3 permutation_test.py --channels $cols --filters $cols --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee "${OUTDIR}/runtime_${cols}x${cols}.log" 19 | let "rows = $cols * 2" 20 | echo "$cols x $rows" 21 | python3 permutation_test.py --channels $cols --filters $rows --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee "${OUTDIR}/runtime_${cols}x${rows}.log" 22 | done 23 | 24 | # 2048x2048 is too large for OSG3 25 | echo "2048 x 2048" 26 | python3 permutation_test.py --channels 2048 --filters 2048 --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 | tee "${OUTDIR}/runtime_2048x2048.log" 27 | 28 | 29 | ############### collect results into a .csv file 30 | echo "Gathering results ..." 
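# The helper below scrapes the per-size logs and assembles results/runtimes.csv:
# one row per permutation-search strategy, with an (Efficacy, Runtime) column pair
# for every tested matrix size.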
31 | 32 | # efficacy and runtime from one strategy and size 33 | get_results() { 34 | local strategy=$1 35 | local cols=$2 36 | local rows=$3 37 | local OUTFILE=$4 38 | 39 | grep "$strategy," "$OUTDIR/runtime_${cols}x${rows}.log" | awk -F "," '{printf "%s,%s,",$3,$4}' >> $OUTFILE 40 | } 41 | 42 | # prepare output file headers 43 | OUTFILE="results/runtimes.csv" 44 | printf "Columns," > $OUTFILE 45 | for cols in "32" "64" "128" "256"; do 46 | printf "$cols,$cols,$cols,$cols," >> $OUTFILE 47 | done 48 | printf "2048,2048\n" >> $OUTFILE 49 | 50 | printf "Rows," >> $OUTFILE 51 | for cols in "32" "64" "128" "256"; do 52 | let "rows = $cols * 2" 53 | printf "$cols,$cols,$rows,$rows," >> $OUTFILE 54 | done 55 | printf "2048,2048\n" >> $OUTFILE 56 | 57 | printf "Metric," >> $OUTFILE 58 | for cols in "32" "64" "128" "256"; do 59 | printf "Efficacy,Runtime,Efficay,Runtime," >> $OUTFILE 60 | done 61 | printf "Efficacy,Runtime\n" >> $OUTFILE 62 | 63 | # gather data in a reasonable order 64 | for strategy in "$R1000" "$CS" "$CS_100" "$OSG2" "$OSG2_100" "$OSG2_1000" "$OSG3" "$OSG3_100" "$OSG3_1000"; do 65 | strategy=$(echo "$strategy" | sed 's/,/_/g') # replace commas with underscores, as they'll appear in the results logs 66 | printf "$strategy," >> $OUTFILE 67 | for cols in "32" "64" "128" "256"; do 68 | get_results "$strategy" "$cols" "$cols" "$OUTFILE" 69 | let "rows = $cols * 2" 70 | get_results "$strategy" "$cols" "$rows" "$OUTFILE" 71 | done 72 | 73 | get_results "$strategy" "2048" "2048" "$OUTFILE" 74 | 75 | printf "\n" >> $OUTFILE 76 | done 77 | 78 | echo "Done! $OUTFILE" 79 | -------------------------------------------------------------------------------- /apex/contrib/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/bottleneck/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/bottleneck/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/clip_grad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/clip_grad/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/conv_bias_relu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/conv_bias_relu/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/cudnn_gbn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/cudnn_gbn/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/fmha/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/fmha/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/focal_loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/focal_loss/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/focal_loss/test_focal_loss.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | reference_available = True 7 | try: 8 | from torchvision.ops.focal_loss import sigmoid_focal_loss 9 | except ImportError: 10 | reference_available = False 11 | 12 | SKIP_TEST = None 13 | try: 14 | from apex.contrib.focal_loss import focal_loss 15 | except ImportError as e: 16 | SKIP_TEST = e 17 | 18 | 19 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 20 | @unittest.skipIf(not reference_available, "Reference implementation `torchvision.ops.focal_loss.sigmoid_focal_loss` is not available.") 21 | class FocalLossTest(unittest.TestCase): 22 | 23 | N_SAMPLES = 12 24 | N_CLASSES = 8 25 | ALPHA = 0.24 26 | GAMMA = 2.0 27 | REDUCTION = "sum" 28 | 29 | def test_focal_loss(self) -> None: 30 | if not reference_available: 31 | self.skipTest("This test needs `torchvision` for `torchvision.ops.focal_loss.sigmoid_focal_loss`.") 32 | else: 33 | x = torch.randn(FocalLossTest.N_SAMPLES, FocalLossTest.N_CLASSES).cuda() 34 | with torch.no_grad(): 35 | x_expected = x.clone() 36 | x_actual = x.clone() 37 | x_expected.requires_grad_() 38 | x_actual.requires_grad_() 39 | 40 | classes = torch.randint(0, FocalLossTest.N_CLASSES, (FocalLossTest.N_SAMPLES,)).cuda() 41 | with torch.no_grad(): 42 | y = F.one_hot(classes, FocalLossTest.N_CLASSES).float() 43 | 44 | expected = sigmoid_focal_loss( 45 | x_expected, 46 | y, 47 | alpha=FocalLossTest.ALPHA, 48 | gamma=FocalLossTest.GAMMA, 49 | reduction=FocalLossTest.REDUCTION, 50 | ) 51 | 52 | actual = sum([focal_loss.FocalLoss.apply( 53 | x_actual[i:i+1], 54 | classes[i:i+1].long(), 55 | torch.ones([], device="cuda"), 56 | FocalLossTest.N_CLASSES, 57 | FocalLossTest.ALPHA, 58 | FocalLossTest.GAMMA, 59 | 0.0, 60 | ) for i in range(FocalLossTest.N_SAMPLES)]) 61 | 62 | # forward parity 63 | torch.testing.assert_close(expected, actual) 64 | 65 | expected.backward() 66 | actual.backward() 67 | 68 | # grad parity 69 | torch.testing.assert_close(x_expected.grad, x_actual.grad) 70 | 71 | 72 | if __name__ == "__main__": 73 | torch.manual_seed(42) 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /apex/contrib/test/fused_dense/test_fused_dense.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | import torch 5 | from torch.testing._internal import common_utils 6 | from torch.testing._internal.common_device_type import instantiate_device_type_tests 7 | 8 | SKIP_TEST = None 9 | try: 10 | from apex import fused_dense 11 | except ImportError as e: 12 | SKIP_TEST = e 13 | 14 | 15 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 16 | class FusedDenseTest(common_utils.TestCase): 17 | 18 | def _test_fused_dense(self, dtype, seed=0): 19 | 20 | os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] = "0" 21 | torch.manual_seed(seed) 22 | 23 
| seq_length = 512 24 | sequences = 3 25 | hidden_dim = 1024 26 | 27 | ref_inputs = torch.randn(sequences*seq_length, hidden_dim, 28 | dtype=dtype, device=torch.device("cuda")).requires_grad_(True) 29 | 30 | tst_inputs = ref_inputs.clone().detach().requires_grad_(True) 31 | dense = fused_dense.FusedDense(1024, 3072) 32 | dense.to(dtype=dtype) 33 | dense.cuda() 34 | 35 | y_tst = dense(tst_inputs) 36 | y_ref = torch.matmul(ref_inputs, dense.weight.t())+dense.bias 37 | dy = torch.randn_like(y_tst).to(dtype=dtype) 38 | y_tst.backward(dy) 39 | dw_ref = torch.matmul(dy.t(), ref_inputs) 40 | dx_ref = torch.matmul(dy, dense.weight.clone()) 41 | db_ref = dy.sum(0, False) 42 | 43 | torch.testing.assert_close( 44 | ref_inputs, tst_inputs, atol=1e-5, rtol=1e-5) 45 | torch.testing.assert_close( 46 | y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True) 47 | torch.testing.assert_close( 48 | dw_ref, dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True) 49 | torch.testing.assert_close( 50 | dx_ref, tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True) 51 | torch.testing.assert_close( 52 | db_ref, dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True) 53 | 54 | @common_utils.parametrize("dtype", [torch.half, torch.float, torch.bfloat16]) 55 | def test_fused_dense(self, dtype): 56 | self._test_fused_dense(dtype) 57 | 58 | 59 | instantiate_device_type_tests(FusedDenseTest, globals(), only_for=("cuda",)) 60 | 61 | if __name__ == "__main__": 62 | common_utils.run_tests() 63 | -------------------------------------------------------------------------------- /apex/contrib/test/group_norm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/group_norm/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/index_mul_2d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/index_mul_2d/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/layer_norm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/layer_norm/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/multihead_attn/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/test_mha_fused_softmax.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | SKIP_TEST = None 7 | try: 8 | from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func 9 | except ImportError as e: 10 | SKIP_TEST = e 11 | 12 | 13 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 14 | class FusedSoftmaxTest(unittest.TestCase): 15 | def setUp(self, seed=1234): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | self.seq_length = 80 20 | self.sequences = 10 21 | 
self.hidden_dim = 1024 22 | self.heads = 16 23 | self.dropout_prob = 0.0 24 | 25 | self.mask = (torch.randn(self.sequences, self.seq_length) > 0).cuda() 26 | self.mask = self.mask.half() * -10000 27 | self.ref_inputs = torch.randn( 28 | self.heads * self.sequences, 29 | self.seq_length, 30 | self.seq_length, 31 | dtype=torch.float16, 32 | device=torch.device("cuda"), 33 | ).requires_grad_(True) 34 | 35 | self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True) 36 | 37 | def test_fused_softmax(self): 38 | grads = torch.randn_like(self.tst_inputs) 39 | y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length) 40 | y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2) 41 | y_ref = y_ref.view(self.sequences * self.heads, self.seq_length, self.seq_length) 42 | y_ref = F.softmax(y_ref, dim=-1) 43 | y_ref = torch._fused_dropout(y_ref, 1.0) 44 | 45 | y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0) 46 | y_ref[0].backward(grads) 47 | y_tst.backward(grads) 48 | 49 | torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5) 50 | torch.testing.assert_close(y_ref[0], y_tst, atol=1e-3, rtol=1e-3) 51 | torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3) 52 | 53 | 54 | if __name__ == "__main__": 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | SKIP_TEST = None 6 | try: 7 | from apex.contrib.multihead_attn import SelfMultiheadAttn 8 | except ImportError as e: 9 | SKIP_TEST = e 10 | 11 | 12 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 13 | class SelfMultiheadAttnNormAddTest(unittest.TestCase): 14 | def setUp(self, seed=1234): 15 | torch.manual_seed(seed) 16 | 17 | self.seq_length = 80 18 | self.sequences = 10 19 | self.hidden_dim = 1024 20 | self.heads = 16 21 | self.dropout_prob = 0.0 22 | 23 | self.ref_layer = SelfMultiheadAttn( 24 | self.hidden_dim, self.heads, dropout=self.dropout_prob, bias=False, include_norm_add=True, impl="default" 25 | ) 26 | self.ref_layer.cuda().half() 27 | self.ref_layer.reset_parameters() 28 | self.ref_inputs = torch.randn( 29 | self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda") 30 | ).requires_grad_(True) 31 | 32 | # Reset seed so parameters are identical 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | self.tst_layer = SelfMultiheadAttn( 37 | self.hidden_dim, self.heads, dropout=self.dropout_prob, bias=False, include_norm_add=True, impl="fast" 38 | ) 39 | self.tst_layer.cuda().half() 40 | self.tst_layer.reset_parameters() 41 | 42 | self.tst_inputs = torch.randn( 43 | self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda") 44 | ).requires_grad_(True) 45 | 46 | def test_self_multihead_attn_norm_add(self): 47 | grads = torch.randn_like(self.tst_inputs) 48 | 49 | for _ in range(0, 5): 50 | ref_outputs, _ = self.ref_layer.forward( 51 | self.ref_inputs, 52 | self.ref_inputs, 53 | self.ref_inputs, 54 | key_padding_mask=None, 55 | need_weights=False, 56 | attn_mask=None, 57 | is_training=True, 58 | ) 59 | 60 | tst_outputs, _ = self.tst_layer.forward( 61 | self.tst_inputs, 62 | self.tst_inputs, 63 | self.tst_inputs, 64 | key_padding_mask=None, 65 | 
need_weights=False, 66 | attn_mask=None, 67 | is_training=True, 68 | ) 69 | 70 | self.ref_inputs.backward(grads) 71 | self.tst_inputs.backward(grads) 72 | 73 | torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5) 74 | torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3) 75 | torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3) 76 | 77 | 78 | if __name__ == "__main__": 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /apex/contrib/test/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/optimizers/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/peer_memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/peer_memory/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/transducer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/transducer/__init__.py -------------------------------------------------------------------------------- /apex/contrib/test/xentropy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/contrib/test/xentropy/__init__.py -------------------------------------------------------------------------------- /apex/contrib/torchsched/__init__.py: -------------------------------------------------------------------------------- 1 | """Graph scheduler package.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import torch 8 | import torch._inductor 9 | from torch._dynamo import list_backends 10 | from torch._dynamo import register_backend 11 | from torch._inductor.compile_fx import compile_fx_inner 12 | 13 | from .backend import get_backend 14 | 15 | if TYPE_CHECKING: 16 | from typing import Any 17 | from typing import Callable 18 | 19 | from torch._ops import OpOverload 20 | 21 | __all__ = ["get_backend", "set_default_backend"] 22 | 23 | # Register custom operators 24 | torch.ops.import_module("apex.contrib.torchsched.ops") 25 | 26 | 27 | # Register torch-sched backend 28 | # Same API as torch._inductor.compile_fx 29 | @register_backend 30 | def torchsched( 31 | model_: torch.fx.GraphModule, 32 | example_inputs_: list[torch.Tensor], 33 | inner_compile: Callable[..., Any] = compile_fx_inner, 34 | config_patches: dict[str, Any] | None = None, 35 | decompositions: dict[OpOverload, Callable[..., Any]] | None = None, 36 | ) -> Callable: 37 | backend = get_backend(backend="torchsched", scheme="dwb") 38 | return backend(model_, example_inputs_, inner_compile, config_patches, decompositions) 39 | 40 | 41 | _SUPPORTED_BACKENDS = list_backends() 42 | _DEFAULT_BACKEND = "inductor" 43 | __torch_compile__ = torch.compile 44 | 45 | 46 | def set_default_backend(backend: str) -> None: 47 | """ 48 | Set the default backend for torch.compile. 
49 | 50 | Parameters: 51 | backend (str): The backend to use as the default for torch.compile. 52 | """ 53 | global _SUPPORTED_BACKENDS, _DEFAULT_BACKEND 54 | assert backend in _SUPPORTED_BACKENDS, f"Unknown backend {backend}" 55 | _DEFAULT_BACKEND = backend 56 | 57 | 58 | def torchsched_compile( 59 | *args: object, 60 | backend: str | Callable | None = None, 61 | **kwargs: object, 62 | ) -> object: 63 | """ 64 | Wrap around the original torch.compile to support default backend. 65 | 66 | Parameters: 67 | *args (object): Positional arguments for torch.compile. 68 | backend (Union[str, Callable, None]): The backend to use. 69 | If None, the default backend is used. 70 | **kwargs (object): Additional keyword arguments for torch.compile. 71 | 72 | Returns: 73 | object: Compiler or compiled model. 74 | """ 75 | if backend is None: 76 | backend = _DEFAULT_BACKEND 77 | return __torch_compile__(*args, backend=backend, **kwargs) 78 | 79 | 80 | # Monkey patch torch.compile to set default backend 81 | torch.compile = torchsched_compile 82 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/config.py: -------------------------------------------------------------------------------- 1 | """Configurations for graph scheduler.""" 2 | 3 | import os 4 | import sys 5 | 6 | # Debug info and dump grpahs 7 | debug = os.getenv("TORCH_SCHED_DEBUG", "0") == "1" 8 | 9 | # Toggle pre_grad_pass for various pattern matches 10 | enable_pre_grad_pass = False 11 | 12 | # Pre grad pass patterns 13 | pre_grad_pass_options: list[str] = ["cudnn_layer_norm"] 14 | 15 | # Number of CUDA streams used for multi-stream scheduling. 16 | # The first stream will be critical path stream, operators on non-critical path will be 17 | # scheduled to other streams in a round-robin way. 
18 | num_streams = int(os.getenv("TORCH_SCHED_NUM_STREAMS", "8")) 19 | 20 | from torch.utils._config_module import install_config_module # noqa: E402 21 | 22 | # adds patch, save_config, etc 23 | install_config_module(sys.modules[__name__]) 24 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/inductor/__init__.py: -------------------------------------------------------------------------------- 1 | """Scheduling abstractions on PyTorch Inductor level.""" 2 | 3 | from apex.contrib.torchsched.inductor.graph import patch_graph_lowering 4 | 5 | __all__ = ["patch_graph_lowering"] 6 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/inductor/graph.py: -------------------------------------------------------------------------------- 1 | """Scheduling abstractions on PyTorch Inductor GraphLowering level.""" 2 | 3 | from __future__ import annotations 4 | 5 | import functools 6 | from typing import TYPE_CHECKING 7 | 8 | from torch._inductor.graph import GraphLowering 9 | from torch._inductor.scheduler import Scheduler 10 | from torch._inductor.virtualized import V 11 | 12 | if TYPE_CHECKING: 13 | from torch.fx.node import Node 14 | 15 | from apex.contrib.torchsched.inductor.scheduler import MultiCudaStreamScheduler 16 | 17 | 18 | @functools.wraps(GraphLowering.codegen) 19 | def _codegen(graph: GraphLowering) -> tuple[str, list[tuple[int, Node]]]: 20 | graph.init_wrapper_code() 21 | 22 | if graph.device_type == "cuda": 23 | graph.scheduler = MultiCudaStreamScheduler(graph.operations) 24 | else: 25 | graph.scheduler = Scheduler(graph.operations) 26 | V.debug.draw_orig_fx_graph(graph.orig_gm, graph.scheduler.nodes) 27 | 28 | graph.wrapper_code.push_codegened_graph(graph) 29 | graph.scheduler.codegen() 30 | result = graph.wrapper_code.generate(graph.is_inference) 31 | graph.wrapper_code.pop_codegened_graph() 32 | 33 | return result 34 | 35 | 36 | _origin_codegen = GraphLowering.codegen 37 | 38 | 39 | def patch_graph_lowering(patch: bool = True) -> None: 40 | """Patch PyTorch Inductor lowerings with multi-stream scheduling. 41 | 42 | This function patches the `torch.compile` stack on the GraphLowering level, 43 | i.e., the compute graph has been captured by Dynamo and it has undergone 44 | post-auto-gradient passes, including pattern-matching optimizations and 45 | preliminary operator fusions. At that point, most nodes in the graph are 46 | either fused Triton templates, or function calls to external libraries. The 47 | multi-stream scheduler then finds the longest critical path in this graph, 48 | and schedule other nodes to side streams to exploit the inherent parallelism 49 | of the given compute graph. 50 | 51 | Args: 52 | patch: Whether to patch Inductor `GraphLowering` with multi-stream 53 | scheduler. Set to `False` to restore the default `torch.compile` 54 | behavior. 
(default: `True`) 55 | """ 56 | if patch: 57 | GraphLowering.codegen = _codegen 58 | else: 59 | GraphLowering.codegen = _origin_codegen 60 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/ops/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom PyTorch operators.""" 2 | 3 | import torch 4 | 5 | __all__: list[str] = [] 6 | 7 | # Register custom operators 8 | torch.ops.import_module("apex.contrib.torchsched.ops.layer_norm") 9 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/passes/__init__.py: -------------------------------------------------------------------------------- 1 | """Customized compiler passes.""" 2 | 3 | from __future__ import annotations 4 | 5 | from apex.contrib.torchsched.passes.pre_grad_passes import pre_grad_custom_pass 6 | 7 | __all__ = ["pre_grad_custom_pass"] 8 | -------------------------------------------------------------------------------- /apex/contrib/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transducer import TransducerJoint 2 | from .transducer import TransducerLoss 3 | from . import _transducer_ref 4 | -------------------------------------------------------------------------------- /apex/contrib/xentropy/__init__.py: -------------------------------------------------------------------------------- 1 | from .softmax_xentropy import SoftmaxCrossEntropyLoss 2 | 3 | 4 | __all__ = [ 5 | "SoftmaxCrossEntropyLoss", 6 | ] 7 | -------------------------------------------------------------------------------- /apex/contrib/xentropy/softmax_xentropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import xentropy_cuda 4 | 5 | 6 | class SoftmaxCrossEntropyLoss(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False): 9 | losses, max_log_sum_exp = xentropy_cuda.forward( 10 | logits, labels, smoothing, half_to_float) 11 | losses.masked_fill_(labels==padding_idx, 0) 12 | 13 | ctx.save_for_backward(logits, max_log_sum_exp, labels, 14 | torch.FloatTensor([smoothing]), 15 | torch.LongTensor([padding_idx])) 16 | 17 | return losses 18 | 19 | @staticmethod 20 | def backward(ctx, grad_loss): 21 | logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors 22 | 23 | if not grad_loss.is_contiguous(): 24 | grad_loss = grad_loss.contiguous() 25 | grad_loss.masked_fill_(labels==padding_idx.item(), 0) 26 | grad_logits = xentropy_cuda.backward( 27 | grad_loss.contiguous(), logits, max_log_sum_exp, 28 | labels, smoothing.item()) 29 | 30 | return grad_logits, None, None, None, None 31 | -------------------------------------------------------------------------------- /apex/fused_dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_dense import * 2 | -------------------------------------------------------------------------------- /apex/mlp/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import * 2 | -------------------------------------------------------------------------------- /apex/mlp/mlp.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from 
apex._autocast_utils import _cast_if_autocast_enabled 8 | import mlp_cuda 9 | 10 | 11 | class MlpFunction(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, bias, activation, *args): 14 | output = mlp_cuda.forward(bias, activation, args) 15 | ctx.save_for_backward(*args) 16 | ctx.outputs = output 17 | ctx.bias = bias 18 | ctx.activation = activation 19 | return output[0] 20 | 21 | @staticmethod 22 | def backward(ctx, grad_o): 23 | grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors) 24 | del ctx.outputs 25 | return (None, None, *grads) 26 | 27 | 28 | def mlp_function(bias, activation, *args): 29 | autocast_args = _cast_if_autocast_enabled(bias, activation, *args) 30 | return MlpFunction.apply(*autocast_args) 31 | 32 | 33 | class MLP(torch.nn.Module): 34 | """Launch MLP in C++ 35 | 36 | Args: 37 | mlp_sizes (list of int): MLP sizes. Example: [1024,1024,1024] will create 2 MLP layers with shape 1024x1024 38 | bias (bool): Default True: 39 | relu (bool): Default True 40 | """ 41 | def __init__(self, mlp_sizes, bias=True, activation='relu'): 42 | super().__init__() 43 | self.num_layers = len(mlp_sizes) - 1 44 | self.mlp_sizes = copy(mlp_sizes) 45 | self.bias = 1 if bias else 0 46 | 47 | if activation == 'none': 48 | self.activation = 0 49 | elif activation == 'relu': 50 | self.activation = 1 51 | elif activation == 'sigmoid': 52 | self.activation = 2 53 | else: 54 | raise TypeError("activation must be relu or none.") 55 | 56 | self.weights = [] 57 | self.biases = [] 58 | for i in range(self.num_layers): 59 | w = torch.nn.Parameter(torch.empty(mlp_sizes[i+1], mlp_sizes[i])) 60 | self.weights.append(w) 61 | name = 'weight_{}'.format(i) 62 | setattr(self, name, w) 63 | if self.bias: 64 | b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1])) 65 | self.biases.append(b) 66 | name = 'bias_{}'.format(i) 67 | setattr(self, name, b) 68 | 69 | self.reset_parameters() 70 | 71 | def reset_parameters(self): 72 | for weight in self.weights: 73 | dimsum = weight.size(0) + weight.size(1) 74 | std = math.sqrt(2. / float(dimsum)) 75 | nn.init.normal_(weight, 0., std) 76 | if self.bias: 77 | for bias in self.biases: 78 | std = math.sqrt(1. 
/ float(bias.size(0))) 79 | nn.init.normal_(bias, 0., std) 80 | 81 | def forward(self, input): 82 | return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases) 83 | 84 | def extra_repr(self): 85 | s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}" 86 | return s 87 | -------------------------------------------------------------------------------- /apex/multi_tensor_apply/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_tensor_apply import MultiTensorApply 2 | 3 | multi_tensor_applier = MultiTensorApply(2048*32) 4 | 5 | -------------------------------------------------------------------------------- /apex/multi_tensor_apply/multi_tensor_apply.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MultiTensorApply(object): 4 | available = False 5 | warned = False 6 | 7 | def __init__(self, chunk_size): 8 | try: 9 | import amp_C 10 | MultiTensorApply.available = True 11 | self.chunk_size = chunk_size 12 | except ImportError as err: 13 | MultiTensorApply.available = False 14 | MultiTensorApply.import_err = err 15 | 16 | def check_avail(self): 17 | if MultiTensorApply.available == False: 18 | raise RuntimeError( 19 | "Attempted to call MultiTensorApply method, but MultiTensorApply " 20 | "is not available, possibly because Apex was installed without " 21 | "--cpp_ext --cuda_ext. Original import error message:", 22 | MultiTensorApply.import_err) 23 | 24 | def __call__(self, op, noop_flag_buffer, tensor_lists, *args): 25 | self.check_avail() 26 | 27 | return op(self.chunk_size, 28 | noop_flag_buffer, 29 | tensor_lists, 30 | *args) 31 | -------------------------------------------------------------------------------- /apex/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm 2 | -------------------------------------------------------------------------------- /apex/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_sgd import FusedSGD 2 | from .fused_adam import FusedAdam 3 | from .fused_novograd import FusedNovoGrad 4 | from .fused_lamb import FusedLAMB 5 | from .fused_adagrad import FusedAdagrad 6 | from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb 7 | -------------------------------------------------------------------------------- /apex/transformer/README.md: -------------------------------------------------------------------------------- 1 | # apex.transformer 2 | 3 | `apex.transformer` is a module which enables efficient large Transformer models at scale. 4 | 5 | `apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module. 6 | The former is based on `megatron.mpu` and the latter is on `megatron.schedules` and `megatron.p2p_communication`. 7 | 8 | ## Tensor Model Parallel (TP) 9 | 10 | APEX's tensor model parallel utilities provides some `torch.nn.Module`'s, custom fused kernels, and PRNG state handling. 11 | See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of 12 | PRNG state handling. 
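A rough, hedged sketch of the TP utilities follows (the layer sizes, the TP world size of 2, and the exact constructor keywords are illustrative and may differ between apex versions; the script is assumed to be launched with `torchrun`, one rank per GPU):

```python
import torch
import torch.distributed as dist

from apex.transformer import parallel_state, tensor_parallel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Carve the world into tensor-model-parallel groups of size 2.
parallel_state.initialize_model_parallel(2)
# TP-aware PRNG state handling mentioned above.
tensor_parallel.model_parallel_cuda_manual_seed(1234)

# Megatron-style column-parallel then row-parallel linears, as in a Transformer MLP block.
fc1 = tensor_parallel.ColumnParallelLinear(1024, 4096, gather_output=False)
fc2 = tensor_parallel.RowParallelLinear(4096, 1024, input_is_parallel=True)

x = torch.randn(8, 1024, device="cuda")
h, _ = fc1(x)              # each rank holds a slice of the 4096-wide hidden dimension
y, _ = fc2(torch.relu(h))  # output is all-reduced across the TP group
```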
13 | 14 | ## Pipeline Model Parallel (PP) 15 | APEX's pipeline model parallel functions require models to have `.set_input_tensor` because 16 | the input tensor for `.forward` method can be `None`. 17 | 18 | The following is a really casual sketch of training script with apex pp. 19 | 20 | ```python 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | from apex.transformer import parallel_state 26 | from apex.transformer.pipeline_parallel import get_forward_backward_func 27 | 28 | 29 | class Model(nn.Module): 30 | 31 | ... 32 | 33 | def __init__(self, *args, **kwargs): 34 | super().__init__() 35 | pre_process = kwargs.pop("pre_process") 36 | post_process = kwargs.pop("post_process") 37 | 38 | def set_input_tensor(self, tensor): 39 | self.input_tensor = tensor 40 | 41 | def forward(self, x, ...): 42 | if parallel_state.is_pipeline_first_stage(): 43 | input = x 44 | else: 45 | input = self.input_tensor 46 | ... 47 | 48 | 49 | def model_provider_func(*args, **kwargs): 50 | return Model(*args, **kwargs) 51 | 52 | 53 | def loss_func(pred, label): 54 | loss = ... 55 | averaged_loss = average_losses_across_data_parallel_group([loss]) 56 | return loss, {'nice_loss': averaged_loss} 57 | 58 | 59 | def forward_step_func(batch, model): 60 | input, label = process_batch(batch) 61 | out = model(input) 62 | return out, partial(loss_func, label) 63 | 64 | 65 | forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size) 66 | 67 | 68 | parallel_state.initialize_model_parallel( 69 | tensor_model_parallel_size, 70 | pipeline_model_parallel_size, 71 | virtual_pipeline_model_parallel_size, 72 | ) 73 | # The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)` 74 | model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs) 75 | optimizer = ... 76 | data_loader = ... 
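# NB: `tensor_shape` below must be passed as a keyword argument
# (a positional argument cannot follow `forward_only=False` in Python).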
77 | for epoch in range(num_epochs): 78 | for batch in data_loader: 79 | forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape) 80 | optimizer.step() 81 | ``` 82 | -------------------------------------------------------------------------------- /apex/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer import amp 2 | from apex.transformer import functional 3 | from apex.transformer import parallel_state 4 | from apex.transformer import pipeline_parallel 5 | from apex.transformer import tensor_parallel 6 | from apex.transformer import utils 7 | from apex.transformer.enums import LayerType 8 | from apex.transformer.enums import AttnType 9 | from apex.transformer.enums import AttnMaskType 10 | 11 | 12 | __all__ = [ 13 | "amp", 14 | "functional", 15 | "parallel_state", 16 | "pipeline_parallel", 17 | "tensor_parallel", 18 | "utils", 19 | # enums.py 20 | "LayerType", 21 | "AttnType", 22 | "AttnMaskType", 23 | ] 24 | -------------------------------------------------------------------------------- /apex/transformer/_data/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler 2 | from apex.transformer._data._batchsampler import MegatronPretrainingSampler 3 | 4 | 5 | __all__ = [ 6 | "MegatronPretrainingRandomSampler", 7 | "MegatronPretrainingSampler", 8 | ] 9 | -------------------------------------------------------------------------------- /apex/transformer/_ucc_util.py: -------------------------------------------------------------------------------- 1 | from torch import distributed as dist 2 | 3 | HAS_UCC = hasattr(dist, "is_ucc_available") and dist.is_ucc_available() 4 | if not HAS_UCC: 5 | try: 6 | import torch_ucc 7 | HAS_UCC = True 8 | except ImportError: 9 | HAS_UCC = False 10 | -------------------------------------------------------------------------------- /apex/transformer/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.amp.grad_scaler import GradScaler 2 | 3 | 4 | __all__ = [ 5 | "GradScaler", 6 | ] 7 | -------------------------------------------------------------------------------- /apex/transformer/enums.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | import enum 16 | 17 | 18 | class LayerType(enum.Enum): 19 | encoder = 1 20 | decoder = 2 21 | 22 | 23 | class AttnType(enum.Enum): 24 | self_attn = 1 25 | cross_attn = 2 26 | 27 | 28 | class AttnMaskType(enum.Enum): 29 | padding = 1 30 | causal = 2 31 | 32 | 33 | class ModelType(enum.Enum): 34 | encoder_or_decoder = 1 35 | encoder_and_decoder = 2 36 | -------------------------------------------------------------------------------- /apex/transformer/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.functional.fused_rope import ( 2 | fused_apply_rotary_pos_emb, 3 | fused_apply_rotary_pos_emb_cached, 4 | fused_apply_rotary_pos_emb_thd, 5 | fused_apply_rotary_pos_emb_2d, 6 | ) 7 | from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax 8 | 9 | __all__ = [ 10 | "FusedScaleMaskSoftmax", 11 | "fused_apply_rotary_pos_emb", 12 | "fused_apply_rotary_pos_emb_cached", 13 | "fused_apply_rotary_pos_emb_thd", 14 | "fused_apply_rotary_pos_emb_2d", 15 | ] 16 | -------------------------------------------------------------------------------- /apex/transformer/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | from apex.transformer.layers.layer_norm import FastLayerNorm 3 | from apex.transformer.layers.layer_norm import FusedLayerNorm 4 | from apex.transformer.layers.layer_norm import MixedFusedLayerNorm 5 | 6 | 7 | __all__ = [ 8 | "FastLayerNorm", 9 | "FusedLayerNorm", 10 | "MixedFusedLayerNorm", 11 | ] 12 | -------------------------------------------------------------------------------- /apex/transformer/log_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def get_transformer_logger(name: str) -> logging.Logger: 6 | name_wo_ext = os.path.splitext(name)[0] 7 | return logging.getLogger(name_wo_ext) 8 | 9 | 10 | def set_logging_level(verbosity) -> None: 11 | """Change logging severity. 
12 | 13 | Args: 14 | verbosity 15 | """ 16 | from apex import _library_root_logger 17 | 18 | _library_root_logger.setLevel(verbosity) 19 | -------------------------------------------------------------------------------- /apex/transformer/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func 2 | from apex.transformer.pipeline_parallel.schedules.common import build_model 3 | 4 | 5 | __all__ = [ 6 | "get_forward_backward_func", 7 | "build_model", 8 | ] 9 | -------------------------------------------------------------------------------- /apex/transformer/pipeline_parallel/_timers.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | 6 | class _Timer: 7 | """Timer.""" 8 | 9 | def __init__(self, name): 10 | self.name_ = name 11 | self.elapsed_ = 0.0 12 | self.started_ = False 13 | self.start_time = time.time() 14 | 15 | def start(self): 16 | """Start the timer.""" 17 | assert not self.started_, "timer has already been started" 18 | torch.cuda.synchronize() 19 | self.start_time = time.time() 20 | self.started_ = True 21 | 22 | def stop(self): 23 | """Stop the timer.""" 24 | assert self.started_, "timer is not started" 25 | torch.cuda.synchronize() 26 | self.elapsed_ += time.time() - self.start_time 27 | self.started_ = False 28 | 29 | def reset(self): 30 | """Reset timer.""" 31 | self.elapsed_ = 0.0 32 | self.started_ = False 33 | 34 | def elapsed(self, reset=True): 35 | """Calculate the elapsed time.""" 36 | started_ = self.started_ 37 | # If the timing in progress, end it first. 38 | if self.started_: 39 | self.stop() 40 | # Get the elapsed time. 41 | elapsed_ = self.elapsed_ 42 | # Reset the elapsed time 43 | if reset: 44 | self.reset() 45 | # If timing was in progress, set it back. 
46 | if started_: 47 | self.start() 48 | return elapsed_ 49 | 50 | 51 | class _Timers: 52 | """Group of timers.""" 53 | 54 | def __init__(self): 55 | self.timers = {} 56 | 57 | def __call__(self, name): 58 | if name not in self.timers: 59 | self.timers[name] = _Timer(name) 60 | return self.timers[name] 61 | 62 | def write(self, names, writer, iteration, normalizer=1.0, reset=False): 63 | """Write timers to a tensorboard writer""" 64 | # currently when using add_scalars, 65 | # torch.utils.add_scalars makes each timer its own run, which 66 | # polutes the runs list, so we just add each as a scalar 67 | assert normalizer > 0.0 68 | for name in names: 69 | value = self.timers[name].elapsed(reset=reset) / normalizer 70 | writer.add_scalar(name + "-time", value, iteration) 71 | 72 | def log(self, names, normalizer=1.0, reset=True): 73 | """Log a group of timers.""" 74 | assert normalizer > 0.0 75 | string = "time (ms)" 76 | for name in names: 77 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer 78 | string += " | {}: {:.2f}".format(name, elapsed_time) 79 | if torch.distributed.is_initialized(): 80 | if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): 81 | print(string, flush=True) 82 | else: 83 | print(string, flush=True) 84 | -------------------------------------------------------------------------------- /apex/transformer/pipeline_parallel/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer import parallel_state 2 | from apex.transformer.pipeline_parallel.utils import get_num_microbatches 3 | from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( 4 | forward_backward_no_pipelining, 5 | ) 6 | from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( 7 | _forward_backward_pipelining_with_interleaving, 8 | ) 9 | from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( 10 | forward_backward_pipelining_without_interleaving, 11 | ) 12 | 13 | __all__ = [ 14 | "get_forward_backward_func", 15 | ] 16 | 17 | 18 | class ExperimentalWarning(Warning): 19 | pass 20 | 21 | 22 | def get_forward_backward_func( 23 | virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, 24 | ): 25 | if parallel_state.get_pipeline_model_parallel_world_size() > 1: 26 | if virtual_pipeline_model_parallel_size is not None: 27 | if get_num_microbatches() % pipeline_model_parallel_size != 0: 28 | msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule" 29 | raise RuntimeError(msg) 30 | forward_backward_func = _forward_backward_pipelining_with_interleaving 31 | else: 32 | forward_backward_func = forward_backward_pipelining_without_interleaving 33 | else: 34 | forward_backward_func = forward_backward_no_pipelining 35 | return forward_backward_func 36 | -------------------------------------------------------------------------------- /apex/transformer/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Model parallel utility interface.""" 16 | 17 | from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy 18 | 19 | from apex.transformer.tensor_parallel.data import broadcast_data 20 | 21 | from apex.transformer.tensor_parallel.layers import ( 22 | ColumnParallelLinear, 23 | RowParallelLinear, 24 | VocabParallelEmbedding, 25 | set_tensor_model_parallel_attributes, 26 | set_defaults_if_not_set_tensor_model_parallel_attributes, 27 | copy_tensor_model_parallel_attributes, 28 | ) 29 | 30 | from apex.transformer.tensor_parallel.mappings import ( 31 | copy_to_tensor_model_parallel_region, 32 | gather_from_tensor_model_parallel_region, 33 | reduce_from_tensor_model_parallel_region, 34 | scatter_to_tensor_model_parallel_region, 35 | scatter_to_sequence_parallel_region, 36 | ) 37 | 38 | from .random import ( 39 | checkpoint, 40 | get_cuda_rng_tracker, 41 | init_checkpointed_activations_memory_buffer, 42 | model_parallel_cuda_manual_seed, 43 | reset_checkpointed_activations_memory_buffer, 44 | ) 45 | 46 | from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim 47 | 48 | 49 | __all__ = [ 50 | # cross_entropy.py 51 | "vocab_parallel_cross_entropy", 52 | # data.py 53 | "broadcast_data", 54 | # layers.py 55 | "ColumnParallelLinear", 56 | "RowParallelLinear", 57 | "VocabParallelEmbedding", 58 | "set_tensor_model_parallel_attributes", 59 | "set_defaults_if_not_set_tensor_model_parallel_attributes", 60 | "copy_tensor_model_parallel_attributes", 61 | # mappings.py 62 | "copy_to_tensor_model_parallel_region", 63 | "gather_from_tensor_model_parallel_region", 64 | "reduce_from_tensor_model_parallel_region", 65 | "scatter_to_tensor_model_parallel_region", 66 | "scatter_to_sequence_parallel_region", 67 | # random.py 68 | "checkpoint", 69 | "get_cuda_rng_tracker", 70 | "init_checkpointed_activations_memory_buffer", 71 | "model_parallel_cuda_manual_seed", 72 | "reset_checkpointed_activations_memory_buffer", 73 | # utils.py 74 | "split_tensor_along_last_dim", 75 | ] 76 | -------------------------------------------------------------------------------- /apex/transformer/tensor_parallel/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from typing import List, Sequence 16 | 17 | import torch 18 | 19 | from apex.transformer.utils import divide 20 | 21 | 22 | def split_tensor_along_last_dim( 23 | tensor: torch.Tensor, 24 | num_partitions: int, 25 | contiguous_split_chunks: bool = False, 26 | ) -> List[torch.Tensor]: 27 | """Split a tensor along its last dimension. 28 | Arguments: 29 | tensor: input tensor. 30 | num_partitions: number of partitions to split the tensor into. 31 | contiguous_split_chunks: If True, make each chunk contiguous 32 | in memory. 33 | """ 34 | # Get the size and dimension. 35 | last_dim = tensor.dim() - 1 36 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 37 | # Split. 38 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 39 | # Note: torch.split does not create contiguous tensors by default. 40 | if contiguous_split_chunks: 41 | return tuple(chunk.contiguous() for chunk in tensor_list) 42 | 43 | return tensor_list 44 | 45 | 46 | class VocabUtility: 47 | """Split the vocabulary into `world_size` chunks and return the 48 | first and last index of the vocabulary belonging to the `rank` 49 | partition. Note that indices are in [first, last).""" 50 | 51 | @staticmethod 52 | def vocab_range_from_per_partition_vocab_size( 53 | per_partition_vocab_size: int, rank, world_size: int 54 | ) -> Sequence[int]: 55 | index_f = rank * per_partition_vocab_size 56 | index_l = index_f + per_partition_vocab_size 57 | return index_f, index_l 58 | 59 | @staticmethod 60 | def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: 61 | per_partition_vocab_size = divide(global_vocab_size, world_size) 62 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 63 | per_partition_vocab_size, rank, world_size 64 | ) 65 | -------------------------------------------------------------------------------- /apex/transformer/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/apex/transformer/testing/__init__.py -------------------------------------------------------------------------------- /apex/transformer/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used by both `pipeline_parallel` and `tensor_parallel`""" 2 | import torch 3 | 4 | from apex.transformer import parallel_state 5 | 6 | # `all_gather_into_tensor` is the new name for `_all_gather_base`. 7 | # It requires a recent version of PyTorch. 8 | # The following lines are for backward compatibility with 9 | # older PyTorch.
10 | if "all_gather_into_tensor" not in dir(torch.distributed): 11 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 12 | 13 | def ensure_divisibility(numerator, denominator): 14 | """Ensure that numerator is divisible by the denominator.""" 15 | assert numerator % denominator == 0, "{} is not divisible by {}".format( 16 | numerator, denominator 17 | ) 18 | 19 | 20 | def divide(numerator, denominator): 21 | """Ensure that numerator is divisible by the denominator and return 22 | the division value.""" 23 | ensure_divisibility(numerator, denominator) 24 | return numerator // denominator 25 | 26 | 27 | def split_tensor_into_1d_equal_chunks(tensor): 28 | """Break a tensor into equal 1D chunks.""" 29 | data = tensor.view(-1) 30 | partition_size = ( 31 | torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size() 32 | ) 33 | start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() 34 | end_index = start_index + partition_size 35 | return data[start_index:end_index] 36 | 37 | 38 | def gather_split_1d_tensor(tensor): 39 | """Opposite of above function, gather values from model parallel ranks.""" 40 | world_size = parallel_state.get_tensor_model_parallel_world_size() 41 | numel = torch.numel(tensor) 42 | numel_gathered = world_size * numel 43 | gathered = torch.empty( 44 | numel_gathered, 45 | dtype=tensor.dtype, 46 | device=torch.cuda.current_device(), 47 | requires_grad=False, 48 | ) 49 | torch.distributed.all_gather_into_tensor( 50 | gathered, 51 | tensor, 52 | group=parallel_state.get_tensor_model_parallel_group() 53 | ) 54 | return gathered 55 | -------------------------------------------------------------------------------- /csrc/compat.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCH_CHECK 2 | #define TORCH_CHECK AT_CHECK 3 | #endif 4 | 5 | #ifdef VERSION_GE_1_3 6 | #define DATA_PTR data_ptr 7 | #else 8 | #define DATA_PTR data 9 | #endif 10 | -------------------------------------------------------------------------------- /csrc/flatten_unflatten.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h 4 | 5 | at::Tensor flatten(std::vector tensors) 6 | { 7 | return torch::utils::flatten_dense_tensors(tensors); 8 | } 9 | 10 | std::vector unflatten(at::Tensor flat, std::vector tensors) 11 | { 12 | return torch::utils::unflatten_dense_tensors(flat, tensors); 13 | } 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("flatten", &flatten, "Flatten dense tensors", py::call_guard()); 17 | m.def("unflatten", &unflatten, "Unflatten dense tensors", py::call_guard()); 18 | } 19 | -------------------------------------------------------------------------------- /csrc/megatron/fused_weight_gradient_dense.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | void wgrad_gemm_accum_fp32_cuda_stub( 7 | at::Tensor &input_2d, 8 | at::Tensor &d_output_2d, 9 | at::Tensor &d_weight 10 | ); 11 | 12 | void wgrad_gemm_accum_fp16_cuda_stub( 13 | at::Tensor &input_2d, 14 | at::Tensor &d_output_2d, 15 | at::Tensor &d_weight 16 | ); 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32", py::call_guard()); 20 | m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, 
"wgrad gemm accum in fp16", py::call_guard()); 21 | } 22 | -------------------------------------------------------------------------------- /csrc/megatron/scaled_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace multihead_attn { 22 | namespace fused_softmax { 23 | namespace scaled_softmax { 24 | 25 | torch::Tensor fwd_cuda( 26 | torch::Tensor const& input, 27 | float scale_factor); 28 | 29 | torch::Tensor bwd_cuda( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor); 33 | 34 | torch::Tensor fwd( 35 | torch::Tensor const& input, 36 | float scale_factor) { 37 | TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); 38 | TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || 39 | (input.scalar_type() == at::ScalarType::BFloat16), 40 | "Only fp16 and bf16 are supported"); 41 | 42 | return fwd_cuda(input, scale_factor); 43 | } 44 | 45 | torch::Tensor bwd( 46 | torch::Tensor const& output_grads, 47 | torch::Tensor const& softmax_results, 48 | float scale_factor) { 49 | 50 | TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); 51 | TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); 52 | 53 | TORCH_CHECK((output_grads.scalar_type() == at::ScalarType::Half) || 54 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 55 | "Only fp16 and bf16 are supported"); 56 | TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || 57 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 58 | "Only fp16 and bf16 are supported"); 59 | 60 | return bwd_cuda(output_grads, softmax_results, scale_factor); 61 | } 62 | 63 | } // end namespace scaled_softmax 64 | } // end namespace fused_softmax 65 | } // end namespace multihead_attn 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("forward", 69 | &multihead_attn::fused_softmax::scaled_softmax::fwd, 70 | "Self Multihead Attention scaled, softmax -- Forward.", py::call_guard()); 71 | m.def("backward", 72 | &multihead_attn::fused_softmax::scaled_softmax::bwd, 73 | "Self Multihead Attention scaled, softmax -- Backward.", py::call_guard()); 74 | } 75 | 76 | -------------------------------------------------------------------------------- /csrc/megatron/scaled_upper_triang_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace multihead_attn { 22 | namespace fused_softmax { 23 | namespace scaled_upper_triang_masked_softmax { 24 | 25 | torch::Tensor fwd_cuda( 26 | torch::Tensor const& input, 27 | float scale_factor); 28 | 29 | torch::Tensor bwd_cuda( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor); 33 | 34 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 35 | TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); 36 | TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || 37 | (input.scalar_type() == at::ScalarType::BFloat16), 38 | "Only fp16 and bf16 are supported"); 39 | 40 | return fwd_cuda(input, scale_factor); 41 | } 42 | 43 | torch::Tensor bwd( 44 | torch::Tensor const& output_grads, 45 | torch::Tensor const& softmax_results, 46 | float scale_factor) { 47 | 48 | TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); 49 | TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); 50 | 51 | TORCH_CHECK((output_grads.scalar_type() == at::ScalarType::Half) || 52 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 53 | "Only fp16 and bf16 are supported"); 54 | TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || 55 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 56 | "Only fp16 and bf16 are supported"); 57 | 58 | return bwd_cuda(output_grads, softmax_results, scale_factor); 59 | } 60 | 61 | } // end namespace scaled_upper_triang_masked_softmax 62 | } // end namespace fused_softmax 63 | } // end namespace multihead_attn 64 | 65 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 66 | m.def("forward", 67 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, 68 | "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); 69 | m.def("backward", 70 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, 71 | "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); 72 | } 73 | -------------------------------------------------------------------------------- /csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // From 2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr static bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr static bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /csrc/update_scale_hysteresis.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | __global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, 6 | int* growth_tracker, 7 | int* hysteresis_tracker, 8 | const float* found_inf, 9 | double growth_factor, 10 | double backoff_factor, 11 | int growth_interval, 12 | int hysteresis) 13 | { 14 | if (*found_inf > 0) { 15 | *hysteresis_tracker -= 1; 16 | 17 | // Only reset the growth tracker when hysteresis is larger than zero 18 | if (*hysteresis_tracker > 0) { 19 | *growth_tracker = 0; 20 | return; 21 | } 22 | } 23 | 24 | if (*found_inf) { 25 | *current_scale = (*current_scale)*backoff_factor; 26 | *growth_tracker = 0; 27 | } else { 28 | // Entering this branch means we just carried out a successful step, 29 | // so growth_tracker is incremented before comparing to growth_interval. 30 | auto successful = (*growth_tracker) + 1; 31 | if (successful == growth_interval) { 32 | auto new_scale = static_cast((*current_scale)*growth_factor); 33 | // Do not grow the scale past fp32 bounds to inf. 34 | if (isfinite(new_scale)) { 35 | *current_scale = new_scale; 36 | } 37 | *growth_tracker = 0; 38 | } else { 39 | *growth_tracker = successful; 40 | } 41 | } 42 | 43 | // Reset the hysteresis tracker if no infs are found 44 | if (*found_inf <= 0) { 45 | *hysteresis_tracker = hysteresis; 46 | } 47 | } 48 | 49 | at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, 50 | at::Tensor growth_tracker, 51 | at::Tensor hysteresis_tracker, 52 | at::Tensor found_inf, 53 | const double growth_factor, 54 | const double backoff_factor, 55 | const int64_t growth_interval, 56 | const int hysteresis) 57 | { 58 | update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( 59 | current_scale.mutable_data_ptr(), 60 | growth_tracker.mutable_data_ptr(), 61 | hysteresis_tracker.mutable_data_ptr(), 62 | found_inf.const_data_ptr(), 63 | growth_factor, 64 | backoff_factor, 65 | growth_interval, 66 | hysteresis); 67 | 68 | AT_CUDA_CHECK(cudaGetLastError()); 69 | 70 | return current_scale; 71 | } 72 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = NVIDIAAPEX 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | gh-pages: 16 | git checkout gh-pages 17 | rm -rf build 18 | rm -rf source 19 | git checkout master -- . 
20 | make html 21 | rm -rf ../_modules ../_sources ../_static 22 | mv -fv build/html/* ../ 23 | rm -rf build 24 | git add -A 25 | git commit -m "Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`" && git push origin gh-pages ; git checkout master 26 | 27 | .PHONY: help Makefile 28 | 29 | # Catch-all target: route all unknown targets to Sphinx using the new 30 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 31 | %: Makefile 32 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 33 | -------------------------------------------------------------------------------- /docs/source/_static/css/pytorch_theme.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 3 | } 4 | 5 | /* Default header fonts are ugly */ 6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 8 | } 9 | 10 | /* Use white for docs background */ 11 | .wy-side-nav-search { 12 | background-color: #fff; 13 | } 14 | 15 | .wy-nav-content-wrap, .wy-menu li.current > a { 16 | background-color: #fff; 17 | } 18 | 19 | @media screen and (min-width: 1400px) { 20 | .wy-nav-content-wrap { 21 | background-color: rgba(0, 0, 0, 0.0470588); 22 | } 23 | 24 | .wy-nav-content { 25 | background-color: #fff; 26 | } 27 | } 28 | 29 | /* Fixes for mobile */ 30 | .wy-nav-top { 31 | background-color: #fff; 32 | background-image: url('../img/apex.jpg'); 33 | background-repeat: no-repeat; 34 | background-position: center; 35 | padding: 0; 36 | margin: 0.4045em 0.809em; 37 | color: #333; 38 | } 39 | 40 | .wy-nav-top > a { 41 | display: none; 42 | } 43 | 44 | @media screen and (max-width: 768px) { 45 | .wy-side-nav-search>a img.logo { 46 | height: 60px; 47 | } 48 | } 49 | 50 | /* This is needed to ensure that logo above search scales properly */ 51 | .wy-side-nav-search a { 52 | display: block; 53 | } 54 | 55 | /* This ensures that multiple constructors will remain in separate lines. 
*/ 56 | .rst-content dl:not(.docutils) dt { 57 | display: table; 58 | } 59 | 60 | /* Use our red for literals (it's very similar to the original color) */ 61 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 62 | color: #F05732; 63 | } 64 | 65 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, 66 | .rst-content code.xref, a .rst-content tt, a .rst-content code { 67 | color: #404040; 68 | } 69 | 70 | /* Change link colors (except for the menu) */ 71 | 72 | a { 73 | color: #F05732; 74 | } 75 | 76 | a:hover { 77 | color: #F05732; 78 | } 79 | 80 | 81 | a:visited { 82 | color: #D44D2C; 83 | } 84 | 85 | .wy-menu a { 86 | color: #b3b3b3; 87 | } 88 | 89 | .wy-menu a:hover { 90 | color: #b3b3b3; 91 | } 92 | 93 | /* Default footer text is quite big */ 94 | footer { 95 | font-size: 80%; 96 | } 97 | 98 | footer .rst-footer-buttons { 99 | font-size: 125%; /* revert footer settings - 1/80% = 125% */ 100 | } 101 | 102 | footer p { 103 | font-size: 100%; 104 | } 105 | 106 | /* For hidden headers that appear in TOC tree */ 107 | /* see http://stackoverflow.com/a/32363545/3343043 */ 108 | .rst-content .hidden-section { 109 | display: none; 110 | } 111 | 112 | nav .hidden-section { 113 | display: inherit; 114 | } 115 | 116 | .wy-side-nav-search>div.version { 117 | color: #000; 118 | } 119 | -------------------------------------------------------------------------------- /docs/source/_static/img/nv-pytorch2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/docs/source/_static/img/nv-pytorch2.png -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block sidebartitle %} {{ super() }} 3 | 4 | 32 | {% endblock %} 33 | 34 | {% block footer %} {{ super() }} 35 | 36 | 51 | {% endblock %} 52 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. PyTorch documentation master file, created by 2 | sphinx-quickstart on Fri Dec 23 13:31:47 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/nvidia/apex 7 | 8 | Apex (A PyTorch Extension) 9 | =================================== 10 | 11 | This site contains the API documentation for Apex (https://github.com/nvidia/apex), 12 | a Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible. 13 | 14 | Installation instructions can be found here: https://github.com/NVIDIA/apex#quick-start. 15 | 16 | Some other useful material, including GTC 2019 and Pytorch DevCon 2019 Slides, can be found here: https://github.com/mcarilli/mixed_precision_references. 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Fused Optimizers 21 | 22 | optimizers 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Fused Layer Norm 27 | 28 | layernorm 29 | 30 | .. .. toctree:: 31 | :maxdepth: 1 32 | :caption: Deprecated mixed precision API 33 | fp16_util 34 | 35 | .. 
RNN 36 | 37 | Indices and tables 38 | ================== 39 | 40 | * :ref:`genindex` 41 | * :ref:`modindex` 42 | -------------------------------------------------------------------------------- /docs/source/layernorm.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.normalization.fused_layer_norm 5 | =================================== 6 | 7 | .. automodule:: apex.normalization 8 | .. currentmodule:: apex.normalization 9 | 10 | .. FusedAdam 11 | ---------- 12 | 13 | .. autoclass:: FusedLayerNorm 14 | :members: 15 | 16 | .. autoclass:: FusedRMSNorm 17 | :members: 18 | -------------------------------------------------------------------------------- /docs/source/optimizers.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.optimizers 5 | =================================== 6 | 7 | .. automodule:: apex.optimizers 8 | .. currentmodule:: apex.optimizers 9 | 10 | .. FusedAdam 11 | ---------- 12 | 13 | .. autoclass:: FusedAdam 14 | :members: 15 | 16 | .. autoclass:: FusedLAMB 17 | :members: 18 | 19 | .. autoclass:: FusedNovoGrad 20 | :members: 21 | 22 | .. autoclass:: FusedSGD 23 | :members: 24 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | This directory contains examples illustrating Apex mixed precision and distributed tools. 2 | 3 | **Note for users of the pre-unification API**: 4 | `deprecated_api` contains examples illustrating the old (pre-unified) APIs. These APIs will be removed soon, and users are strongly encouraged to switch. The separate mixed precision tools called `Amp` and `FP16_Optimizer` in the old API are exposed via different flags/optimization levels in the new API. 5 | -------------------------------------------------------------------------------- /examples/dcgan/README.md: -------------------------------------------------------------------------------- 1 | # Mixed Precision DCGAN Training in PyTorch 2 | 3 | `main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan). 4 | It implements Automatic Mixed Precision (Amp) training of the DCGAN example for different datasets. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html). 5 | 6 | We introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation:: 7 | ``` 8 | # Added after models and optimizers construction 9 | [netD, netG], [optimizerD, optimizerG] = amp.initialize( 10 | [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3) 11 | ... 12 | # loss.backward() changed to: 13 | with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled: 14 | errD_real_scaled.backward() 15 | ... 16 | with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled: 17 | errD_fake_scaled.backward() 18 | ... 
19 | with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled: 20 | errG_scaled.backward() 21 | ``` 22 | 23 | Note that we use different `loss_scalers` for each computed loss. 24 | Using a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss). 25 | 26 | To improve the numerical stability, we swapped `nn.Sigmoid() + nn.BCELoss()` to `nn.BCEWithLogitsLoss()`. 27 | 28 | With the new Amp API **you never need to explicitly convert your model, or the input data, to half().** 29 | 30 | "Pure FP32" training: 31 | ``` 32 | $ python main_amp.py --opt_level O0 33 | ``` 34 | Recommended mixed precision training: 35 | ``` 36 | $ python main_amp.py --opt_level O1 37 | ``` 38 | 39 | Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the used arguments. 40 | 41 | To enable mixed precision training, we introduce the `--opt_level` argument. 42 | -------------------------------------------------------------------------------- /examples/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image must at least have pytorch and CUDA installed. 2 | ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.03-py3 3 | FROM $BASE_IMAGE 4 | ARG BASE_IMAGE 5 | RUN echo "Installing Apex on top of ${BASE_IMAGE}" 6 | # make sure we don't overwrite some existing directory called "apex" 7 | WORKDIR /tmp/unique_for_apex 8 | # uninstall Apex if present, twice to make absolutely sure :) 9 | RUN pip uninstall -y apex || : 10 | RUN pip uninstall -y apex || : 11 | # SHA is something the user can touch to force recreation of this Docker layer, 12 | # and therefore force cloning of the latest version of Apex 13 | RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git 14 | WORKDIR /tmp/unique_for_apex/apex 15 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 16 | WORKDIR /workspace 17 | -------------------------------------------------------------------------------- /examples/docker/README.md: -------------------------------------------------------------------------------- 1 | ## Option 1: Create a new container with Apex 2 | 3 | **Dockerfile** installs the latest Apex on top of an existing image. Run 4 | ``` 5 | docker build -t new_image_with_apex . 6 | ``` 7 | By default, **Dockerfile** uses NVIDIA's Pytorch container as the base image, 8 | which requires an NVIDIA GPU Cloud (NGC) account. If you don't have an NGC account, you can sign up for free by following the instructions [here](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html#generating-api-key). 9 | 10 | Alternatively, you can supply your own base image via the `BASE_IMAGE` build-arg. 11 | `BASE_IMAGE` must have Pytorch and Cuda installed. For example, any 12 | `-devel` image for Pytorch 1.0 and later from the 13 | [official Pytorch Dockerhub](https://hub.docker.com/r/pytorch/pytorch) may be used: 14 | ``` 15 | docker build --build-arg BASE_IMAGE=1.3-cuda10.1-cudnn7-devel -t new_image_with_apex . 16 | ``` 17 | 18 | If you want to rebuild your image, and force the latest Apex to be cloned and installed, make any small change to the `SHA` variable in **Dockerfile**. 19 | 20 | **Warning:** 21 | Currently, the non-`-devel` images on Pytorch Dockerhub do not contain the Cuda compiler `nvcc`. 
Therefore, 22 | images whose name does not contain `-devel` are not eligible candidates for `BASE_IMAGE`. 23 | 24 | ### Running your Apex container 25 | 26 | Like any Cuda-enabled Pytorch container, a container with Apex should be run via [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), for example: 27 | ``` 28 | docker run --runtime=nvidia -it --rm --ipc=host new_image_with_apex 29 | ``` 30 | 31 | ## Option 2: Install Apex in a running container 32 | 33 | Instead of building a new container, it is also a viable option to `git clone https://github.com/NVIDIA/apex.git` on bare metal, mount the Apex repo into your container at launch by running, for example, 34 | ``` 35 | docker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container 36 | ``` 37 | then go to /apex/in/container within the running container and 38 | ``` 39 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 40 | ``` 41 | -------------------------------------------------------------------------------- /examples/simple/distributed/README.md: -------------------------------------------------------------------------------- 1 | **distributed_data_parallel.py** and **run.sh** show an example using Amp with 2 | [apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or 3 | [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) 4 | and the Pytorch multiprocess launcher script, 5 | [torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility). 6 | The use of `Amp` with DistributedDataParallel does not need to change from ordinary 7 | single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must 8 | come after the call to `amp.initialize`. Test via 9 | ```bash 10 | bash run.sh 11 | ``` 12 | 13 | **This is intended purely as an instructional example, not a performance showcase.** 14 | -------------------------------------------------------------------------------- /examples/simple/distributed/distributed_data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from apex import amp 5 | # FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) 6 | from apex.parallel import DistributedDataParallel 7 | 8 | parser = argparse.ArgumentParser() 9 | # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied 10 | # automatically by torch.distributed.launch. 11 | parser.add_argument("--local_rank", default=0, type=int) 12 | args = parser.parse_args() 13 | 14 | # FOR DISTRIBUTED: If we are running under torch.distributed.launch, 15 | # the 'WORLD_SIZE' environment variable will also be set automatically. 16 | args.distributed = False 17 | if 'WORLD_SIZE' in os.environ: 18 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 19 | 20 | if args.distributed: 21 | # FOR DISTRIBUTED: Set the device according to local_rank. 22 | torch.cuda.set_device(args.local_rank) 23 | 24 | # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide 25 | # environment variables, and requires that you use init_method=`env://`. 26 | torch.distributed.init_process_group(backend='nccl', 27 | init_method='env://') 28 | 29 | torch.backends.cudnn.benchmark = True 30 | 31 | N, D_in, D_out = 64, 1024, 16 32 | 33 | # Each process receives its own batch of "fake input data" and "fake target data." 
34 | # The "training loop" in each process just uses this fake batch over and over. 35 | # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic 36 | # example of distributed data sampling for both training and validation. 37 | x = torch.randn(N, D_in, device='cuda') 38 | y = torch.randn(N, D_out, device='cuda') 39 | 40 | model = torch.nn.Linear(D_in, D_out).cuda() 41 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 42 | 43 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 44 | 45 | if args.distributed: 46 | # FOR DISTRIBUTED: After amp.initialize, wrap the model with 47 | # apex.parallel.DistributedDataParallel. 48 | model = DistributedDataParallel(model) 49 | # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: 50 | # model = torch.nn.parallel.DistributedDataParallel(model, 51 | # device_ids=[args.local_rank], 52 | # output_device=args.local_rank) 53 | 54 | loss_fn = torch.nn.MSELoss() 55 | 56 | for t in range(500): 57 | optimizer.zero_grad() 58 | y_pred = model(x) 59 | loss = loss_fn(y_pred, y) 60 | with amp.scale_loss(loss, optimizer) as scaled_loss: 61 | scaled_loss.backward() 62 | optimizer.step() 63 | 64 | if args.local_rank == 0: 65 | print("final loss = ", loss) 66 | -------------------------------------------------------------------------------- /examples/simple/distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cxxfilt>=0.2.0 2 | tqdm>=4.28.1 3 | numpy>=1.15.3 4 | PyYAML>=5.1 5 | pytest>=3.5.1 6 | packaging>=14.0 7 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | flake8>=3.7.9 3 | Sphinx>=3.0.3 -------------------------------------------------------------------------------- /tests/L0/run_optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/tests/L0/run_optimizers/__init__.py -------------------------------------------------------------------------------- /tests/L0/run_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/e74a67bba3ee679f778670e17edc21639008ae0a/tests/L0/run_transformer/__init__.py -------------------------------------------------------------------------------- /tests/L0/run_transformer/test_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.testing 4 | from torch.testing._internal import common_utils 5 | 6 | logging.getLogger("torch").setLevel(logging.WARNING) 7 | 8 | from apex.transformer import parallel_state 9 | from apex.transformer.tensor_parallel import data as 
data_utils 10 | from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase 11 | from apex.transformer.testing.distributed_test_base import UccDistributedTestBase 12 | 13 | logging.getLogger("torch").setLevel(logging.WARNING) 14 | 15 | 16 | class BroadcastDataTestBase: 17 | def test_broadcast_data(self): 18 | tensor_model_parallel_world_size: int = self.world_size // ( 19 | 1 + self.world_size > 1 20 | ) 21 | parallel_state.initialize_model_parallel( 22 | tensor_model_parallel_size_=tensor_model_parallel_world_size 23 | ) 24 | 25 | target_key_size = { 26 | "key1": [7, 11], 27 | "key2": [8, 2, 1], 28 | "key3": [13], 29 | "key4": [5, 1, 2], 30 | "key5": [5, 12], 31 | } 32 | keys = [k for k in target_key_size] 33 | 34 | data = {} 35 | data_t = {} 36 | with torch.no_grad(): 37 | for key in target_key_size: 38 | data[key] = torch.randint(0, 1000, size=target_key_size[key]) 39 | data_t[key] = data[key].clone() 40 | # "key_x" is supposed to be ignored. 41 | data["key_x"] = torch.rand(5) 42 | data_t["key_x"] = data["key_x"].clone() 43 | if parallel_state.get_tensor_model_parallel_rank() != 0: 44 | data = None 45 | 46 | data_utils._check_data_types(keys, data_t, torch.int64) 47 | key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data) 48 | 49 | for key in keys: 50 | self.assertEqual(target_key_size[key], key_size[key]) 51 | 52 | broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64) 53 | for key in keys: 54 | self.assertEqual(broadcasted_data[key], data_t[key].cuda()) 55 | 56 | parallel_state.destroy_model_parallel() 57 | 58 | 59 | class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass 60 | class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass 61 | 62 | 63 | if __name__ == "__main__": 64 | common_utils.run_tests() 65 | -------------------------------------------------------------------------------- /tests/L0/run_transformer/test_transformer_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.testing._internal import common_utils 5 | 6 | logging.getLogger("torch").setLevel(logging.WARNING) 7 | 8 | from apex.transformer import parallel_state 9 | from apex.transformer.tensor_parallel import utils 10 | from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase 11 | 12 | logging.getLogger("apex").setLevel(logging.WARNING) 13 | 14 | 15 | class TransformerUtilsTest(NcclDistributedTestBase): 16 | def test_split_tensor_along_last_dim(self): 17 | for tensor_model_paralell_world_size in range(1, self.world_size + 1): 18 | if self.world_size % tensor_model_paralell_world_size > 0: 19 | continue 20 | parallel_state.initialize_model_parallel( 21 | tensor_model_parallel_size_=tensor_model_paralell_world_size 22 | ) 23 | 24 | device = "cpu" 25 | input_tensor = torch.randn((100, 100, 100), device=device) 26 | splits = utils.split_tensor_along_last_dim(input_tensor, 10) 27 | last_dim_shapes = torch.tensor( 28 | [int(split.size()[-1]) for split in splits] 29 | ) 30 | 31 | self.assertTrue( 32 | torch.equal(last_dim_shapes, torch.full((10,), 10),), 33 | msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}", 34 | ) 35 | 36 | parallel_state.destroy_model_parallel() 37 | 38 | 39 | if __name__ == "__main__": 40 | common_utils.run_tests() 41 | -------------------------------------------------------------------------------- /tests/L1/common/compare.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | parser = argparse.ArgumentParser(description='Compare') 5 | parser.add_argument('--opt-level', type=str) 6 | parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) 7 | parser.add_argument('--loss-scale', type=str, default=None) 8 | parser.add_argument('--fused-adam', action='store_true') 9 | parser.add_argument('--use_baseline', action='store_true') 10 | args = parser.parse_args() 11 | 12 | base_file = str(args.opt_level) + "_" +\ 13 | str(args.loss_scale) + "_" +\ 14 | str(args.keep_batchnorm_fp32) + "_" +\ 15 | str(args.fused_adam) 16 | 17 | file_e = "True_" + base_file 18 | file_p = "False_" + base_file 19 | if args.use_baseline: 20 | file_b = "baselines/True_" + base_file 21 | 22 | dict_e = torch.load(file_e) 23 | dict_p = torch.load(file_p) 24 | if args.use_baseline: 25 | dict_b = torch.load(file_b) 26 | 27 | torch.set_printoptions(precision=10) 28 | 29 | print(file_e) 30 | print(file_p) 31 | if args.use_baseline: 32 | print(file_b) 33 | 34 | # ugly duplication here... 35 | if not args.use_baseline: 36 | for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): 37 | assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) 38 | 39 | loss_e = dict_e["Loss"][n] 40 | loss_p = dict_p["Loss"][n] 41 | assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p) 42 | print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( 43 | i_e, 44 | loss_e, 45 | loss_p, 46 | dict_e["Speed"][n], 47 | dict_p["Speed"][n])) 48 | else: 49 | for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): 50 | assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) 51 | 52 | loss_e = dict_e["Loss"][n] 53 | loss_p = dict_p["Loss"][n] 54 | loss_b = dict_b["Loss"][n] 55 | assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p) 56 | assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b) 57 | print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( 58 | i_e, 59 | loss_b, 60 | loss_e, 61 | loss_p, 62 | dict_b["Speed"][n], 63 | dict_e["Speed"][n], 64 | dict_p["Speed"][n])) 65 | -------------------------------------------------------------------------------- /tests/L1/cross_product/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/" 4 | # DATADIR="/opt/home/apex/examples/imagenet/" 5 | cp ../common/* . 6 | bash run_test.sh single_gpu $1 7 | -------------------------------------------------------------------------------- /tests/L1/cross_product_distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cp ../common/* . 
4 | bash run_test.sh distributed $1 5 | -------------------------------------------------------------------------------- /tests/distributed/DDP/ddp_race_condition_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn import Parameter 4 | from torch.nn import Module 5 | from apex.parallel import DistributedDataParallel as DDP 6 | import argparse 7 | import os 8 | 9 | 10 | parser = argparse.ArgumentParser(description='allreduce hook example') 11 | parser.add_argument("--local_rank", default=0, type=int) 12 | args = parser.parse_args() 13 | 14 | args.distributed = False 15 | if 'WORLD_SIZE' in os.environ: 16 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 17 | 18 | if args.distributed: 19 | args.gpu = args.local_rank % torch.cuda.device_count() 20 | torch.cuda.set_device(args.gpu) 21 | torch.distributed.init_process_group(backend='nccl', 22 | init_method='env://') 23 | args.world_size = torch.distributed.get_world_size() 24 | 25 | torch.set_printoptions(precision=10) 26 | torch.manual_seed(args.local_rank) 27 | 28 | class Model(Module): 29 | def __init__(self): 30 | super(Model, self).__init__() 31 | self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0)) 32 | self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0)) 33 | def forward(self, input): 34 | return (input*self.a)*self.b 35 | 36 | model = Model() 37 | # model = DDP(model, message_size=1, gradient_predivide_factor=8.0) 38 | # model = DDP(model, delay_allreduce=True) 39 | # model = DDP(model, message_size=1, allreduce_trigger_params=[model.b]) 40 | model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3) 41 | 42 | x = torch.cuda.FloatTensor(4096*4096) 43 | 44 | passed = True 45 | torch.cuda.cudart().cudaProfilerStart() 46 | for i in range(10): 47 | x.fill_(i + args.local_rank) # fill x with new values every iteration for sanity 48 | model.zero_grad() 49 | out = model(x) 50 | loss = out.sum() 51 | # torch.cuda.nvtx.range_push("backward") 52 | loss.backward() 53 | # torch.cuda.nvtx.range_pop() 54 | 55 | # torch.cuda.nvtx.range_push("synchronize() + info") 56 | # torch.cuda.synchronize() 57 | print("i = {}".format(i)) 58 | def info(name, param, val): 59 | expected = val*4096*4096*(2.*i+1)/2. 
--------------------------------------------------------------------------------
/tests/distributed/DDP/run_race_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py
4 | 
--------------------------------------------------------------------------------
/tests/distributed/amp_master_params/amp_master_params.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | from apex import amp
5 | # FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
6 | from apex.parallel import DistributedDataParallel
7 | 
8 | parser = argparse.ArgumentParser()
9 | # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
10 | # automatically by torch.distributed.launch.
11 | parser.add_argument("--local_rank", default=0, type=int)
12 | args = parser.parse_args()
13 | 
14 | # FOR DISTRIBUTED: If we are running under torch.distributed.launch,
15 | # the 'WORLD_SIZE' environment variable will also be set automatically.
16 | args.distributed = False
17 | if 'WORLD_SIZE' in os.environ:
18 |     args.distributed = int(os.environ['WORLD_SIZE']) > 1
19 | 
20 | if args.distributed:
21 |     # FOR DISTRIBUTED: Set the device according to local_rank.
22 |     torch.cuda.set_device(args.local_rank)
23 | 
24 |     # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
25 |     # environment variables, and requires that you use init_method=`env://`.
26 |     torch.distributed.init_process_group(backend='nccl',
27 |                                          init_method='env://')
28 | 
29 |     torch.manual_seed(torch.distributed.get_rank())
30 | 
31 | torch.backends.cudnn.benchmark = True
32 | 
33 | N, D_in, D_out = 64, 1024, 16
34 | 
35 | # Each process receives its own batch of "fake input data" and "fake target data."
36 | # The "training loop" in each process just uses this fake batch over and over.
37 | # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
38 | # example of distributed data sampling for both training and validation.
39 | x = torch.randn(N, D_in, device='cuda')
40 | y = torch.randn(N, D_out, device='cuda')
41 | 
42 | model = torch.nn.Linear(D_in, D_out).cuda()
43 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
44 | 
45 | model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
46 | 
47 | if args.distributed:
48 |     # FOR DISTRIBUTED: After amp.initialize, wrap the model with
49 |     # apex.parallel.DistributedDataParallel.
50 |     model = DistributedDataParallel(model)
51 |     # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
52 |     # model = torch.nn.parallel.DistributedDataParallel(model,
53 |     #                                                   device_ids=[args.local_rank],
54 |     #                                                   output_device=args.local_rank)
55 | 
56 | loss_fn = torch.nn.MSELoss()
57 | 
58 | for t in range(500):
59 |     optimizer.zero_grad()
60 |     y_pred = model(x)
61 |     loss = loss_fn(y_pred, y)
62 |     with amp.scale_loss(loss, optimizer) as scaled_loss:
63 |         scaled_loss.backward()
64 |     optimizer.step()
65 | 
66 | if args.local_rank == 0:
67 |     print("final loss = ", loss)
68 | 
69 | torch.save(list(model.parameters()), "rank{}model.pth".format(torch.distributed.get_rank()))
70 | torch.save(list(amp.master_params(optimizer)), "rank{}master.pth".format(torch.distributed.get_rank()))
71 | 
--------------------------------------------------------------------------------
/tests/distributed/amp_master_params/compare.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | model_params_rank0 = torch.load("rank0model.pth",
4 |                                 map_location = lambda storage, loc: storage.cuda(0))
5 | model_params_rank1 = torch.load("rank1model.pth",
6 |                                 map_location = lambda storage, loc: storage.cuda(0))
7 | master_params_rank0 = torch.load("rank0master.pth",
8 |                                  map_location = lambda storage, loc: storage.cuda(0))
9 | master_params_rank1 = torch.load("rank1master.pth",
10 |                                  map_location = lambda storage, loc: storage.cuda(0))
11 | 
12 | for model_rank0, model_rank1, master_rank0, master_rank1 in zip(
13 |         model_params_rank0,
14 |         model_params_rank1,
15 |         master_params_rank0,
16 |         master_params_rank1):
17 |     assert torch.allclose(model_rank0, model_rank1), "Model param mismatch"
18 |     assert torch.allclose(master_rank0, master_rank1), "Master param mismatch"
19 |     # Some debugging/investigation assistance code:
20 |     # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0)
21 |     # offending_val_half = model_rank0.view(-1)[maxind.item()]
22 |     # offending_val_float = master_rank0.view(-1)[maxind.item()]
23 |     # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),
24 |     #       offending_val_float.half().item())
25 |     # rtol needs to be > 2^-11 because of denormals...
26 |     assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch"
27 | 
28 | print("OK: Model and master params match across ranks.")
29 | 
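The rtol=.005 in compare.py is intentionally loose. Under opt_level="O2" the model holds fp16 weights while the optimizer keeps fp32 master weights, so the two copies can legitimately differ by up to one fp16 rounding step: a relative error of at most 2**-11 (about 4.9e-4) for normal values, and larger still for denormals, which is why exact equality is not expected. A rough illustration of that bound, using an arbitrary made-up value rather than anything from the test:

    import torch

    master = torch.tensor([0.12345678], dtype=torch.float32)
    model_copy = master.half()              # fp16 copy, as the O2 model would hold it
    rel_err = ((model_copy.float() - master).abs() / master.abs()).item()
    assert rel_err < 2 ** -11               # fp16 round-to-nearest bound for normal values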
--------------------------------------------------------------------------------
/tests/distributed/amp_master_params/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py
3 | 
4 | python compare.py
5 | 
--------------------------------------------------------------------------------
/tests/distributed/synced_batchnorm/test_batchnorm1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import apex
3 | 
4 | model = apex.parallel.SyncBatchNorm(4).cuda()
5 | model.weight.data.uniform_()
6 | model.bias.data.uniform_()
7 | data = torch.rand((8,4)).cuda()
8 | 
9 | model_ref = torch.nn.BatchNorm1d(4).cuda()
10 | model_ref.load_state_dict(model.state_dict())
11 | data_ref = data.clone()
12 | 
13 | output = model(data)
14 | output_ref = model_ref(data_ref)
15 | 
16 | assert(output.allclose(output_ref))
17 | assert(model.running_mean.allclose(model_ref.running_mean))
18 | assert(model.running_var.allclose(model_ref.running_var))
--------------------------------------------------------------------------------
/tests/distributed/synced_batchnorm/unit_test.sh:
--------------------------------------------------------------------------------
1 | python python_single_gpu_unit_test.py
2 | python single_gpu_unit_test.py
3 | python test_batchnorm1d.py
4 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
5 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
6 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
7 | #beware, you need a system with at least 4 gpus to test group_size