├── python ├── flux_triton │ ├── __init__.py │ ├── kernels │ │ └── __init__.py │ ├── tools │ │ ├── __init__.py │ │ ├── compile │ │ │ ├── compile.h │ │ │ ├── __init__.py │ │ │ └── compile.c │ │ └── runtime │ │ │ └── triton_aot_runtime.h │ └── extra │ │ ├── cuda_extra.bc │ │ └── __init__.py └── flux │ ├── .gitignore │ ├── triton │ └── __init__.py │ ├── testing │ └── __init__.py │ └── __init__.py ├── src ├── .gitignore ├── gemm_rs │ ├── test │ │ └── CMakeLists.txt │ ├── cutlass_impls │ │ └── .clang-format │ ├── ring_reduce.hpp │ ├── bsr_reduce.hpp │ ├── CMakeLists.txt │ ├── ths_op │ │ ├── helper_ops.h │ │ └── gemm_reduce_scatter.h │ ├── padding_util.hpp │ ├── reduce_scatter_topos.hpp │ └── tuning_config │ │ └── config_gemm_rs_sm80_A100_tp4_nnodes1.cu ├── cuda │ ├── version.ld │ ├── gemm_op_registry.cu.in │ ├── CMakeLists.txt │ ├── random_initialize.cu │ ├── cuda_common.cu │ ├── moe_utils.cu │ └── cuda_common.cc ├── moe_ag_scatter │ ├── test │ │ └── CMakeLists.txt │ ├── cutlass_impls │ │ └── .clang-format │ ├── CMakeLists.txt │ ├── triton_util.h │ ├── tuning_config │ │ └── config_ag_scatter_sm90_H800.cu │ ├── dispatch_policy.hpp │ └── ths_op │ │ └── gemm_grouped_v3_ag_scatter.h ├── moe_gather_rs │ ├── cutlass_impls │ │ └── .clang-format │ ├── CMakeLists.txt │ ├── ths_op │ │ ├── topk_reduce_gather_rs.h │ │ ├── moe_utils.h │ │ ├── topk_reduce_gather_rs.cc │ │ └── topk_scatter_reduce.cc │ ├── topk_gather_rs.hpp │ └── workspace_helper.h ├── quantization │ ├── CMakeLists.txt │ ├── quantization.hpp │ └── ths_op │ │ └── quantization.h ├── comm_none │ ├── test │ │ └── CMakeLists.txt │ ├── CMakeLists.txt │ ├── cutlass_blockscale_gemm_impl.h │ ├── ths_op │ │ ├── gemm_grouped_v3.h │ │ ├── gemm_grouped_v2.h │ │ └── gemm_only.h │ └── tuning_config │ │ └── config_gemm_only_sm80_A100.cu ├── inplace_cast │ ├── CMakeLists.txt │ ├── ths_op │ │ ├── helper_ops.h │ │ ├── inplace_cast.h │ │ └── helper_ops.cc │ └── inplace_cast.hpp ├── ths_op │ ├── CMakeLists.txt │ └── helper_ops.cc ├── generator │ └── CMakeLists.txt ├── gemm_a2a_transpose │ ├── ths_op │ │ └── pre_attn_a2a_types.h │ ├── CMakeLists.txt │ ├── pre_attn_a2a_transpose_impls.hpp │ └── pre_attn_qkv_pack_a2a_impls.hpp ├── coll │ ├── CMakeLists.txt │ ├── all2all_impl.hpp │ ├── all2all_single_2d_impl.hpp │ ├── local_copy_and_reset.hpp │ ├── ths_op │ │ ├── all2all_single_2d.h │ │ ├── isendrecv.h │ │ ├── all2all_op.h │ │ ├── all_gather_op.h │ │ ├── all_gather_types.h │ │ └── reduce_scatter_op.h │ └── all_gather_impls.hpp ├── ag_gemm │ ├── CMakeLists.txt │ ├── all_gather_swizzle.hpp │ └── ths_op │ │ ├── all_gather_gemm_op_internode.h │ │ └── all_gather_gemm_op.h ├── a2a_transpose_gemm │ ├── CMakeLists.txt │ ├── ths_op │ │ └── all_to_all_types.h │ ├── post_attn_a2a_transpose_impls.hpp │ └── tuning_config │ │ └── config_a2a_transpose_gemm_kernel_sm90_H800_tp8_nnodes1.cu └── pybind │ ├── inplace_cast.cc │ ├── quantization.cc │ └── gemm_grouped_v3_ag_scatter.cc ├── test ├── CMakeLists.txt ├── samples │ ├── config_ag_gemm_sm80_tp8_nnodes1.prototxt │ └── config_gemm_rs_sm80_tp8_nnodes1.prototxt ├── tools │ └── aot │ │ └── run.sh ├── python │ ├── util │ │ ├── cuda_kernels │ │ │ ├── copy_kernel.cu │ │ │ └── reduce_kernel.cu │ │ ├── test_uniform_initialize.py │ │ ├── test_flux_ring_barrier.py │ │ ├── test_bitwise_check.py │ │ └── test_bsr_reduce.py │ ├── quantization │ │ └── test_quantization.py │ └── inplace_cast │ │ └── test_inplace_cast.py └── unit │ └── test_cuda_common.cu ├── docs ├── assets │ ├── flux_moe.png │ ├── torch_moe.png │ ├── dense_layer0.png │ ├── 
dense_layer1.png │ ├── e2e_latency.png │ └── toy_example.png ├── FAQ.md └── mlsys_comet_ae.md ├── pre-commit ├── .gitmodules ├── MANIFEST.in ├── .github └── ISSUE_TEMPLATE │ ├── question.md │ ├── bug_report.md │ └── feature_request.md ├── proto └── CMakeLists.txt ├── Dockerfile ├── pyproject.toml ├── setup.cfg ├── .gitignore ├── examples └── run_moe.sh ├── include └── flux │ ├── cuda │ ├── gemm_impls │ │ └── cutlass_impls │ │ │ └── .clang-format │ ├── helper_kernels.h │ ├── moe_utils.h │ ├── reduce_utils.cuh │ ├── cuda_stub.h │ └── nvml_stub.h │ ├── op_registry_proto_utils.h │ ├── args │ ├── gemm_a2a_transpose.h │ ├── a2a_transpose_gemm.h │ ├── ag_gemm.h │ ├── moe_ag_scatter.h │ └── gemm_rs.h │ ├── ths_op │ ├── topo_utils.h │ └── util.h │ └── gemm_operator_base.h ├── .clang-format ├── cmake ├── modules │ └── FindNUMA.cmake └── FluxConfig.cmake └── launch.sh /python/flux_triton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/flux_triton/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/flux_triton/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | triton_aot_generated/ 2 | -------------------------------------------------------------------------------- /python/flux/.gitignore: -------------------------------------------------------------------------------- 1 | lib/ 2 | include/ 3 | share/ 4 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | add_subdirectory(unit) 4 | -------------------------------------------------------------------------------- /docs/assets/flux_moe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/docs/assets/flux_moe.png -------------------------------------------------------------------------------- /docs/assets/torch_moe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/docs/assets/torch_moe.png -------------------------------------------------------------------------------- /docs/assets/dense_layer0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/docs/assets/dense_layer0.png -------------------------------------------------------------------------------- /docs/assets/dense_layer1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/docs/assets/dense_layer1.png -------------------------------------------------------------------------------- /docs/assets/e2e_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/docs/assets/e2e_latency.png -------------------------------------------------------------------------------- 
/docs/assets/toy_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/docs/assets/toy_example.png -------------------------------------------------------------------------------- /python/flux_triton/extra/cuda_extra.bc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/flux/HEAD/python/flux_triton/extra/cuda_extra.bc -------------------------------------------------------------------------------- /src/gemm_rs/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | add_executable(test_gemm_rs test_gemm_rs.cc) 4 | target_link_libraries(test_gemm_rs PUBLIC flux_cuda) 5 | -------------------------------------------------------------------------------- /pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(./code-format.sh --show-only)" != "" ]; then 4 | echo "code format check failed, please run the following command before commit: ./code-format.sh" 5 | exit 1 6 | fi 7 | 8 | -------------------------------------------------------------------------------- /src/cuda/version.ld: -------------------------------------------------------------------------------- 1 | { 2 | global: 3 | extern "C++" { 4 | bytedance::flux::*; 5 | nvshmemi_init_thread*; 6 | nvshmemi_finalize*; 7 | }; 8 | 9 | local: 10 | *; 11 | }; 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/nccl"] 2 | path = 3rdparty/nccl 3 | url = https://github.com/NVIDIA/nccl 4 | [submodule "3rdparty/cutlass"] 5 | path = 3rdparty/cutlass 6 | url = https://github.com/NVIDIA/cutlass 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-exclude *.so* 2 | recursive-include python/flux/include * 3 | recursive-include python/flux/share * 4 | include python/flux_triton/extra/cuda_extra.bc 5 | include python/flux_triton/extra/cuda_extra.ll 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a general question about Flux 4 | title: "[QUESTION]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Your question** 11 | Ask a clear and concise question about Flux. 12 | -------------------------------------------------------------------------------- /proto/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) 3 | PROTOBUF_GENERATE_CPP(PROTO_SRC PROTO_HEADER flux.proto) 4 | ADD_LIBRARY(flux_proto ${PROTO_HEADER} ${PROTO_SRC}) 5 | set_property(TARGET flux_proto PROPERTY POSITION_INDEPENDENT_CODE ON) 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.07-py3 2 | 3 | WORKDIR /workspace/flux 4 | 5 | COPY . . 
6 | 7 | RUN pip install ninja packaging 8 | RUN git submodule update --init --recursive 9 | RUN OMP_NUM_THREADS=128 ./build.sh --arch "80;89;90" --nvshmem 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml 2 | 3 | [tool.black] 4 | line-length = 100 5 | include = '\.pyi?$' 6 | exclude = ''' 7 | /( 8 | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | )/ 18 | ''' 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,*.pyi 3 | max-line-length = 100 4 | ignore = E203,W503 5 | 6 | [autopep8] 7 | exclude = .git,__pycache__,docs/source/conf.py,old,build,dist 8 | max-line-length = 100 9 | 10 | [metadata] 11 | name = flux 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCache files 2 | build/ 3 | .cache/ 4 | tmp/ 5 | report*.sqlite 6 | report*.nsys-rep 7 | 8 | # run files 9 | log/ 10 | prof/ 11 | workspace/ 12 | 13 | # general things to ignore 14 | dist/ 15 | *.egg-info/ 16 | .eggs/ 17 | *.egg 18 | *.py[cod] 19 | __pycache__/ 20 | *.so 21 | *.so.* 22 | *~ 23 | python/flux/version.py 24 | 25 | # due to using tox and pytest and clangd 26 | .tox 27 | -------------------------------------------------------------------------------- /python/flux_triton/tools/compile/compile.h: -------------------------------------------------------------------------------- 1 | #ifndef TT_KERNEL_INCLUDES 2 | #define TT_KERNEL_INCLUDES 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #endif 10 | 11 | void unload_{kernel_name}(void); 12 | void load_{kernel_name}(void); 13 | // tt-linker: {kernel_name}:{full_signature}:{algo_info} 14 | CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); 15 | -------------------------------------------------------------------------------- /src/moe_ag_scatter/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | find_package(MPI) 4 | if (MPI_FOUND) 5 | add_executable(test_moe_ag_scatter test_moe_ag_scatter.cc) 6 | target_link_libraries(test_moe_ag_scatter PUBLIC flux_cuda MPI::MPI_CXX) 7 | else() 8 | message(STATUS "MPI not found, mpi_target will not be built") 9 | endif() 10 | 11 | add_executable(test_sort_utils test_sort_utils.cc) 12 | target_link_libraries(test_sort_utils PUBLIC flux_cuda) 13 | target_compile_options(test_sort_utils PRIVATE -g) 14 | -------------------------------------------------------------------------------- /src/gemm_rs/cutlass_impls/.clang-format: -------------------------------------------------------------------------------- 1 | # cutlass does not respect .clang-format: even in a single doc format is not consistent. 
2 | # but still we format document with best effort to make it like cutlass 3 | ColumnLimit: 120 4 | SortIncludes: false 5 | BreakBeforeBraces: Custom 6 | BraceWrapping: 7 | BeforeElse: true 8 | AfterEnum: true 9 | AfterStruct: true 10 | AfterFunction: false 11 | 12 | PointerAlignment: Middle 13 | ReferenceAlignment: Right 14 | 15 | BinPackArguments: false 16 | BinPackParameters: false 17 | ExperimentalAutoDetectBinPacking: true 18 | AllowAllParametersOfDeclarationOnNextLine: true 19 | -------------------------------------------------------------------------------- /src/moe_ag_scatter/cutlass_impls/.clang-format: -------------------------------------------------------------------------------- 1 | # cutlass does not respect .clang-format: even in a single doc format is not consistent. 2 | # but still we format document with best effort to make it like cutlass 3 | ColumnLimit: 120 4 | SortIncludes: false 5 | BreakBeforeBraces: Custom 6 | BraceWrapping: 7 | BeforeElse: true 8 | AfterEnum: true 9 | AfterStruct: true 10 | AfterFunction: false 11 | 12 | PointerAlignment: Middle 13 | ReferenceAlignment: Right 14 | 15 | BinPackArguments: false 16 | BinPackParameters: false 17 | ExperimentalAutoDetectBinPacking: true 18 | AllowAllParametersOfDeclarationOnNextLine: true 19 | -------------------------------------------------------------------------------- /src/moe_gather_rs/cutlass_impls/.clang-format: -------------------------------------------------------------------------------- 1 | # cutlass does not respect .clang-format: even in a single doc format is not consistent. 2 | # but still we format document with best effort to make it like cutlass 3 | ColumnLimit: 120 4 | SortIncludes: false 5 | BreakBeforeBraces: Custom 6 | BraceWrapping: 7 | BeforeElse: true 8 | AfterEnum: true 9 | AfterStruct: true 10 | AfterFunction: false 11 | 12 | PointerAlignment: Middle 13 | ReferenceAlignment: Right 14 | 15 | BinPackArguments: false 16 | BinPackParameters: false 17 | ExperimentalAutoDetectBinPacking: true 18 | AllowAllParametersOfDeclarationOnNextLine: true 19 | -------------------------------------------------------------------------------- /examples/run_moe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This run script provides a minimal example of how to build MoE layers with Flux, 4 | # compared with the native pytorch implementation. 5 | 6 | # Suppress NCCL debugging info 7 | export NCCL_DEBUG=WARN 8 | 9 | # The MoE layer0: 10 | ../launch.sh moe_layer0.py 11 | 12 | # The MoE layer1: 13 | ../launch.sh moe_layer1.py 14 | 15 | # A minimal MoE layer - Compare torch and Flux: 16 | # ../launch.sh moe.py 17 | 18 | # A minimal MoE layer with only flux: 19 | # ../launch.sh moe_flux_only.py 20 | 21 | # For a complete and more detailed implementation of the MoE layer, please refer to docs/moe_usage.md -------------------------------------------------------------------------------- /include/flux/cuda/gemm_impls/cutlass_impls/.clang-format: -------------------------------------------------------------------------------- 1 | # cutlass does not respect .clang-format: even in a single doc format is not consistent. 
2 | # but still we format document with best effort to make it like cutlass 3 | ColumnLimit: 120 4 | SortIncludes: false 5 | BreakBeforeBraces: Custom 6 | BraceWrapping: 7 | BeforeElse: true 8 | AfterEnum: true 9 | AfterStruct: true 10 | AfterFunction: false 11 | 12 | PointerAlignment: Middle 13 | ReferenceAlignment: Right 14 | 15 | BinPackArguments: false 16 | BinPackParameters: false 17 | ExperimentalAutoDetectBinPacking: true 18 | AllowAllParametersOfDeclarationOnNextLine: true 19 | -------------------------------------------------------------------------------- /test/samples/config_ag_gemm_sm80_tp8_nnodes1.prototxt: -------------------------------------------------------------------------------- 1 | tune_configs { 2 | meta { 3 | dtype { 4 | A: FP16 5 | B: FP16 6 | C: Void 7 | D: FP16 8 | Acc: FP32 9 | } 10 | arch: Sm80 11 | comm_op: AGKernel 12 | gemm_layout: RRR 13 | impl: GemmV2 14 | } 15 | rt_conf { 16 | m: 512 17 | n: 6144 18 | k: 12288 19 | all_gather_rt_conf { 20 | world_size: 8 21 | nnodes: 1 22 | ring_mode: 1 23 | } 24 | } 25 | best_hparams { 26 | gemm_v2_hparams { 27 | warp_shape: [64, 64, 32] 28 | instruction_shape: [16, 8, 16] 29 | streamk_mode: DP 30 | } 31 | tile_shape: [256, 128, 32] 32 | gemm_kind: GemmStreamK 33 | mainloop_stage: 3 34 | raster_order: AlongM 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve Flux 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior. The easier it is to reproduce the faster it will get maintainer attention. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Stack trace/logs** 20 | If applicable, add the stack trace or logs from the time of the error. 21 | 22 | **Environment** 23 | 24 | **Proposed fix** 25 | If you have a proposal for how to fix the issue state it here or link to a PR. 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 
29 | -------------------------------------------------------------------------------- /src/quantization/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | set(OP_REGS "") 4 | 5 | file(GLOB KERNEL quantization.cu) 6 | set(CU_FILES 7 | ${OP_REGS} 8 | ${KERNEL} 9 | ) 10 | 11 | set(LIB_NAME "flux_cuda_quantization") 12 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 13 | target_compile_options(${LIB_NAME} PRIVATE $<$:-rdc=true>) 14 | 15 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 16 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 17 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 18 | 19 | if (BUILD_THS) 20 | file(GLOB THS_FILES ths_op/*.cc) 21 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 22 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 23 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 24 | flux_add_ths_op_target("quantization") 25 | endif() -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for Flux 4 | title: "[ENHANCEMENT]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Proposed implementation** 20 | If you have a proposed implementation for the feature state it here or link to a PR. 21 | 22 | **Additional context** 23 | Add any other context or screenshots about the feature request here. 
24 | -------------------------------------------------------------------------------- /src/comm_none/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | add_executable(test_gemm_only test_gemm_only.cc) 4 | target_link_libraries(test_gemm_only PUBLIC flux_cuda) 5 | 6 | if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.3") 7 | add_executable(test_grouped_gemm_v3_comm_none test_grouped_gemm_comm_none.cc) 8 | target_link_libraries(test_grouped_gemm_v3_comm_none PUBLIC flux_cuda) 9 | 10 | add_executable(test_blockscale_gemm_v3_comm_none test_blockscale_gemm_comm_none.cu) 11 | target_link_libraries(test_blockscale_gemm_v3_comm_none PUBLIC flux_cuda) 12 | endif() 13 | 14 | if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.4") 15 | add_executable(test_sm89_fp8_gemm_v2_comm_none test_sm89_fp8_gemm_comm_none.cu) 16 | target_link_libraries(test_sm89_fp8_gemm_v2_comm_none PUBLIC flux_cuda) 17 | endif() 18 | -------------------------------------------------------------------------------- /src/inplace_cast/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | set(OP_REGS "") 4 | 5 | file(GLOB KERNEL inplace_cast.cu) 6 | set(CU_FILES 7 | ${OP_REGS} 8 | ${KERNEL} 9 | ) 10 | 11 | set(LIB_NAME "flux_cuda_inplace_cast") 12 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 13 | target_compile_options(${LIB_NAME} PRIVATE $<$:-rdc=true>) 14 | 15 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 16 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 17 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 18 | 19 | if (BUILD_THS) 20 | file(GLOB THS_FILES ths_op/*.cc) 21 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 22 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 23 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 24 | flux_add_ths_op_target("inplace_cast") 25 | endif() 26 | -------------------------------------------------------------------------------- /test/tools/aot/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" 3 | PROJECT_ROOT=${SCRIPT_DIR}/../../.. 
4 | 5 | workspace=${PROJECT_ROOT}/workspace/vec-add 6 | 7 | # compile kernels 8 | python3 ${PROJECT_ROOT}/python/flux_triton/tools/compile_aot.py \ 9 | --workspace ${workspace} \ 10 | --kernels ${PROJECT_ROOT}/python/flux_triton/tools/compile_aot.py:add_kernel \ 11 | --library flux_triton_kernel \ 12 | --build 13 | 14 | pushd ${workspace} 15 | # compile test bin 16 | g++ ${PROJECT_ROOT}/test/tools/aot/add_kernel_test.cc \ 17 | -I${workspace} \ 18 | -I/usr/local/cuda/include \ 19 | -L${workspace}/build \ 20 | -L/usr/local/cuda/lib64 \ 21 | -lflux_triton_kernel \ 22 | -lcudart -lcuda \ 23 | -o add_kernel_test 24 | 25 | # run test 26 | LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(realpath ./build) ./add_kernel_test 27 | popd 28 | -------------------------------------------------------------------------------- /src/ths_op/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | file(GLOB OTHER_THS_FILES *.cc) 4 | set(FLUX_THS_FILES 5 | ${FLUX_THS_OP_FILES} 6 | ${OTHER_THS_FILES} 7 | ) 8 | 9 | message(STATUS "ths_files: ${FLUX_THS_FILES}") 10 | 11 | add_library(flux_cuda_ths_op SHARED ${FLUX_THS_FILES}) 12 | if (WITH_TRITON_AOT) 13 | SET(TRITON_AOT_LIB "flux_triton_aot") 14 | else() 15 | SET(TRITON_AOT_LIB ) 16 | endif() 17 | target_link_libraries(flux_cuda_ths_op 18 | flux_cuda 19 | ${TORCH_LIBRARIES} ${TRITON_AOT_LIB}) 20 | 21 | # Write the uncached variable to a file 22 | get_property(FLUX_THS_TARGETS GLOBAL PROPERTY FLUX_THS_TARGETS) 23 | file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/flux_ths_targets.txt" "FLUX_THS_TARGETS=ths_op;${FLUX_THS_TARGETS}\n") 24 | 25 | install(TARGETS flux_cuda_ths_op 26 | PUBLIC_HEADER DESTINATION include 27 | ) 28 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | BasedOnStyle: Google 4 | IndentWidth: 2 5 | TabWidth: 2 6 | ColumnLimit: 99 7 | ContinuationIndentWidth: 4 8 | AccessModifierOffset: -1 # private/protected/public get no extra indent inside a class 9 | Standard: c++17 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: true 12 | AllowShortFunctionsOnASingleLine: true 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AllowAllParametersOfDeclarationOnNextLine: true 16 | BinPackParameters: false 17 | BinPackArguments: false 18 | AlignAfterOpenBracket: AlwaysBreak 19 | AlwaysBreakTemplateDeclarations: true 20 | AlwaysBreakAfterDefinitionReturnType: All 21 | DerivePointerAlignment: false 22 | PointerAlignment: Right 23 | 24 | # clang-format 3.9+ 25 | SortIncludes: false 26 | ReflowComments: true 27 | ... 28 | -------------------------------------------------------------------------------- /python/flux/triton/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ################################################################################ 17 | 18 | from .ag_gemm import AgGemmTriton 19 | -------------------------------------------------------------------------------- /src/cuda/gemm_op_registry.cu.in: -------------------------------------------------------------------------------- 1 | #include "@IMPL_HEADER@" 2 | 3 | namespace bytedance::flux { 4 | 5 | static int _@IMPL_NAME@_@SPLIT_IDX@_@NSPLITS@_@ARCH@_ops [[maybe_unused]] = []() { 6 | static_assert(is_flux_op_space_v<@IMPL_NAME@_Space>); 7 | 8 | tuple_for_each( 9 | @IMPL_NAME@_Space::enumerate_split_meta_hparams_pairs<@SPLIT_IDX@, @NSPLITS@, @ARCH@>(), 10 | [](auto item) { 11 | auto [idx, meta_hparams_pair] = item; 12 | auto [meta, hparams] = meta_hparams_pair; 13 | using GemmMetaT = decltype(meta); 14 | using GemmHParamsT = decltype(hparams); 15 | 16 | OpRegistry::instance().register_creator( 17 | []() { 18 | OpRegistry::OpPtr op = std::make_unique<@IMPL_NAME@>(); 19 | return op; 20 | }, 21 | GemmMetaT{}, 22 | GemmHParamsT{}, 23 | idx); 24 | }); 25 | return 0; 26 | }(); 27 | } 28 | -------------------------------------------------------------------------------- /test/samples/config_gemm_rs_sm80_tp8_nnodes1.prototxt: -------------------------------------------------------------------------------- 1 | tune_configs { 2 | meta { 3 | dtype { 4 | A: BF16 5 | B: BF16 6 | C: Void 7 | D: BF16 8 | Acc: FP32 9 | } 10 | arch: Sm80 11 | comm_op: ReduceScatter 12 | gemm_layout: RRR 13 | impl: GemmV2 14 | gemm_v3_meta { 15 | fast_accum: false 16 | } 17 | reduce_scatter_meta { 18 | fuse_reduction: false 19 | comm_kind: IntraNodePcie 20 | } 21 | } 22 | rt_conf { 23 | m: 4096 24 | n: 12288 25 | k: 6144 26 | reduce_scatter_rt_conf { 27 | world_size: 8 28 | nnodes: 1 29 | } 30 | } 31 | best_hparams { 32 | gemm_v2_hparams { 33 | warp_shape: [64, 64, 32] 34 | instruction_shape: [16, 8, 16] 35 | streamk_mode: SK 36 | } 37 | 38 | tile_shape: [128, 128, 64] 39 | gemm_kind: GemmStreamK 40 | mainloop_stage: 3 41 | raster_order: AlongN 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /python/flux_triton/extra/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | ################################################################################ 17 | 18 | try: 19 | from ._v3 import * 20 | except: 21 | print("warning: using triton 2.x version asm") 22 | from ._v2 import * 23 | -------------------------------------------------------------------------------- /cmake/modules/FindNUMA.cmake: -------------------------------------------------------------------------------- 1 | # refer to: https://github.com/facebook/rocksdb/blob/main/cmake/modules/FindNUMA.cmake 2 | # - Find NUMA 3 | # Find the NUMA library and includes 4 | # 5 | # NUMA_INCLUDE_DIRS - where to find numa.h, etc. 6 | # NUMA_LIBRARIES - List of libraries when using NUMA. 7 | # NUMA_FOUND - True if NUMA found. 8 | 9 | find_path(NUMA_INCLUDE_DIRS 10 | NAMES numa.h numaif.h 11 | HINTS ${NUMA_ROOT_DIR}/include) 12 | 13 | find_library(NUMA_LIBRARIES 14 | NAMES numa 15 | HINTS ${NUMA_ROOT_DIR}/lib) 16 | 17 | include(FindPackageHandleStandardArgs) 18 | find_package_handle_standard_args(NUMA DEFAULT_MSG NUMA_LIBRARIES NUMA_INCLUDE_DIRS) 19 | 20 | mark_as_advanced( 21 | NUMA_LIBRARIES 22 | NUMA_INCLUDE_DIRS) 23 | 24 | if(NUMA_FOUND AND NOT (TARGET NUMA::NUMA)) 25 | add_library (NUMA::NUMA UNKNOWN IMPORTED) 26 | set_target_properties(NUMA::NUMA 27 | PROPERTIES 28 | IMPORTED_LOCATION ${NUMA_LIBRARIES} 29 | INTERFACE_INCLUDE_DIRECTORIES ${NUMA_INCLUDE_DIRS}) 30 | endif() 31 | -------------------------------------------------------------------------------- /src/generator/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | project(FLUX_GENERATOR LANGUAGES CXX) 3 | 4 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE INTERNAL "") 5 | set(CMAKE_CXX_STANDARD "17") 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") 8 | 9 | find_package(CUDAToolkit REQUIRED) 10 | 11 | set(FLUX_PROJECT_DIR ${PROJECT_SOURCE_DIR}/../../) 12 | 13 | include_directories( 14 | ${CUDAToolkit_INCLUDE_DIRS} 15 | ${FLUX_PROJECT_DIR}/include 16 | ${FLUX_PROJECT_DIR}/3rdparty/cutlass/include 17 | ${FLUX_PROJECT_DIR}/3rdparty/cutlass/tools/util/include 18 | ) 19 | 20 | add_executable(gen_comm_none gen_comm_none.cc) 21 | add_executable(gen_ag_gemm gen_ag_gemm.cc) 22 | add_executable(gen_gemm_rs gen_gemm_rs.cc) 23 | add_executable(gen_a2a_transpose_gemm gen_a2a_transpose_gemm.cc) 24 | add_executable(gen_moe_ag_scatter gen_moe_ag_scatter.cc) 25 | add_executable(gen_moe_gather_rs gen_moe_gather_rs.cc) 26 | add_executable(gen_gemm_a2a_transpose gen_gemm_a2a_transpose.cc) 27 | -------------------------------------------------------------------------------- /src/inplace_cast/ths_op/helper_ops.h: -------------------------------------------------------------------------------- 1 | //===- helper_ops.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | namespace bytedance::flux::ths_op { 21 | void inplace_cast_fp32_to_bf16(torch::Tensor data); 22 | } // namespace bytedance::flux::ths_op 23 | -------------------------------------------------------------------------------- /python/flux/testing/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ################################################################################ 17 | 18 | 19 | from .utils import * 20 | from .moe_utils import * 21 | from .moe_ag_scatter_utils import * 22 | from .moe_gather_rs_utils import * 23 | from .gpu_perf_model import * 24 | from .ulysses_sp_utils import * 25 | -------------------------------------------------------------------------------- /include/flux/op_registry_proto_utils.h: -------------------------------------------------------------------------------- 1 | //===- op_registry_proto_utils.h --------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | #include "flux/op_registry.h" 19 | namespace bytedance::flux { 20 | __attribute__((visibility("default"))) void load_tune_config_from_file( 21 | TuningConfigRegistry ®istry, const char *file_name); 22 | } 23 | -------------------------------------------------------------------------------- /python/flux/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ################################################################################ 17 | __version__ = "1.1.2" 18 | from .cpp_mod import * 19 | 20 | if not isinstance(cpp_mod.AGRingMode, cpp_mod.NotCompiled): 21 | from .ag_kernel_internode import * 22 | 23 | from .gemm_rs_sm80 import * 24 | from .util import * 25 | from .dist_utils import * 26 | -------------------------------------------------------------------------------- /src/gemm_a2a_transpose/ths_op/pre_attn_a2a_types.h: -------------------------------------------------------------------------------- 1 | //===- pre_attn_a2a_types.h --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/ths_op/topo_utils.h" 20 | #include 21 | 22 | namespace bytedance::flux { 23 | enum class PreAttnAllToAllCommOp { 24 | A2ATranspose = 0, 25 | QKVPackA2A = 1, 26 | }; 27 | } // namespace bytedance::flux 28 | -------------------------------------------------------------------------------- /python/flux_triton/tools/compile/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | ################################################################################ 17 | import triton 18 | 19 | from packaging.version import Version 20 | 21 | _triton_ver = Version(triton.__version__) 22 | if _triton_ver.major < 3: 23 | raise RuntimeError("AOT compilation requires triton>=3.0.0") 24 | 25 | from .compile import make_ast_source, kernel_name_suffix, materialize_c_params, dump_c_code 26 | -------------------------------------------------------------------------------- /src/coll/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | set(CU_FILES 3 | all_gather_impls.cu 4 | local_copy_and_reset.cu 5 | ) 6 | if(ENABLE_NVSHMEM) 7 | list(APPEND CU_FILES 8 | dis_scatter_forward_impl.cu 9 | dis_scatter_backward_impl.cu 10 | all2all_impl.cu 11 | all2all_single_2d_impl.cu 12 | ) 13 | endif() 14 | 15 | set(LIB_NAME "flux_coll") 16 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 17 | target_compile_options(${LIB_NAME} PRIVATE $<$:-rdc=true>) 18 | 19 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 20 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 21 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 22 | if (BUILD_THS) 23 | file(GLOB THS_FILES ths_op/*.cc) 24 | file(GLOB ALL2ALL_THS_FILES ths_op/dis_scatter*.cc ths_op/all2all_op.cc ths_op/all2all_single_2d.cc) 25 | list(REMOVE_ITEM THS_FILES ${ALL2ALL_THS_FILES}) 26 | if(ENABLE_NVSHMEM) 27 | list(APPEND THS_FILES ${ALL2ALL_THS_FILES}) 28 | endif() 29 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 30 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 31 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 32 | flux_add_ths_op_target("flux_coll_op") 33 | endif() 34 | -------------------------------------------------------------------------------- /src/comm_none/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 4 | execute_process( 5 | COMMAND ${FLUX_GENERATOR_BINARY_DIR}/gen_comm_none "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 6 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 7 | COMMAND_ERROR_IS_FATAL ANY 8 | COMMAND_ECHO STDOUT 9 | ) 10 | 11 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 12 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 13 | set(CU_FILES 14 | cutlass_blockscale_gemm_impl.cu 15 | ${OP_REGS} 16 | ${TUNING_CONFIGS} 17 | ) 18 | set(LIB_NAME "flux_cuda_comm_none") 19 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 20 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 21 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 22 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 23 | 24 | if (BUILD_THS) 25 | file(GLOB THS_FILES ths_op/*.cc) 26 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 27 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 28 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 29 | flux_add_ths_op_target("gemm_only") 30 | endif() 31 | 32 | if (BUILD_TEST) 33 | add_subdirectory(test) 34 | endif() 35 | -------------------------------------------------------------------------------- /src/ag_gemm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | set(OP_REGS "") 4 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 5 | execute_process( 6 | COMMAND 
${FLUX_GENERATOR_BINARY_DIR}/gen_ag_gemm "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 8 | COMMAND_ERROR_IS_FATAL ANY 9 | COMMAND_ECHO STDOUT 10 | ) 11 | 12 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 13 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 14 | set(CU_FILES 15 | ${OP_REGS} 16 | ${TUNING_CONFIGS} 17 | ) 18 | 19 | set(LIB_NAME "flux_cuda_all_gather") 20 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 21 | target_compile_options(${LIB_NAME} PRIVATE $<$:-rdc=true>) 22 | 23 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 24 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 25 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 26 | 27 | if (BUILD_THS) 28 | file(GLOB THS_FILES ths_op/*.cc) 29 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 30 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 31 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 32 | flux_add_ths_op_target("ag_gemm") 33 | endif() 34 | -------------------------------------------------------------------------------- /src/moe_gather_rs/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 4 | execute_process( 5 | COMMAND ${FLUX_GENERATOR_BINARY_DIR}/gen_moe_gather_rs "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 6 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 7 | COMMAND_ERROR_IS_FATAL ANY 8 | COMMAND_ECHO STDOUT 9 | ) 10 | 11 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 12 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 13 | set(CU_FILES 14 | moe_utils.cu 15 | topk_gather_rs.cu 16 | topk_gather_rs_v2.cu 17 | workspace_helper.cu 18 | ${OP_REGS} 19 | ${TUNING_CONFIGS} 20 | ) 21 | 22 | set(LIB_NAME "flux_cuda_moe_gather_rs") 23 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 24 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 25 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 26 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 27 | 28 | if (BUILD_THS) 29 | file(GLOB THS_FILES ths_op/*.cc) 30 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 31 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 32 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 33 | flux_add_ths_op_target("gemm_grouped_v2_gather_rs;gemm_grouped_v3_gather_rs") 34 | endif() 35 | -------------------------------------------------------------------------------- /src/inplace_cast/ths_op/inplace_cast.h: -------------------------------------------------------------------------------- 1 | //===- inplace_cast.h --------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | 22 | namespace bytedance::flux::ths_op { 23 | class InplaceCast { 24 | public: 25 | InplaceCast(int data_size); 26 | ~InplaceCast(); 27 | void from_fp32_to_bf16(torch::Tensor input); 28 | 29 | private: 30 | class InplaceCastImpl; 31 | InplaceCastImpl *impl = nullptr; 32 | }; 33 | } // namespace bytedance::flux::ths_op 34 | -------------------------------------------------------------------------------- /src/a2a_transpose_gemm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | set(OP_REGS "") 4 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 5 | execute_process( 6 | COMMAND ${FLUX_GENERATOR_BINARY_DIR}/gen_a2a_transpose_gemm "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 8 | COMMAND_ERROR_IS_FATAL ANY 9 | COMMAND_ECHO STDOUT 10 | ) 11 | 12 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 13 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 14 | set(CU_FILES 15 | post_attn_a2a_transpose_impls.cu 16 | ${OP_REGS} 17 | ${TUNING_CONFIGS} 18 | ) 19 | 20 | set(LIB_NAME "flux_cuda_all_to_all_gemm") 21 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 22 | target_compile_options(${LIB_NAME} PRIVATE $<$:-rdc=true>) 23 | 24 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 25 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 26 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 27 | 28 | if (BUILD_THS) 29 | file(GLOB THS_FILES ths_op/*.cc) 30 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 31 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 32 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 33 | flux_add_ths_op_target("a2a_transpose_gemm") 34 | endif() 35 | -------------------------------------------------------------------------------- /src/moe_gather_rs/ths_op/topk_reduce_gather_rs.h: -------------------------------------------------------------------------------- 1 | //===- topk_reduce_gather_rs.h ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | 21 | #include "flux/args/moe_gather_rs.h" 22 | namespace bytedance::flux::ths_op { 23 | void topk_reduce_gather_rs(TopKReduceGatherRSArguments const &args, torch::Tensor output); 24 | void ep_topk_reduce_gather_rs( 25 | TopKReduceGatherRSArguments const &args, torch::Tensor output, int ep_m_start, int ep_m_end); 26 | 27 | } // namespace bytedance::flux::ths_op 28 | -------------------------------------------------------------------------------- /launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # libflux_cuda.so may be installed under /usr/local/lib or ~/.local/lib/ by pip3 3 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:~/.local/lib/ 4 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) 5 | FLUX_SRC_DIR=${SCRIPT_DIR} 6 | 7 | # add flux python package to PYTHONPATH 8 | export NVSHMEM_BOOTSTRAP=UID 9 | export NVSHMEM_DISABLE_CUDA_VMM=1 # moved from the C++ code into this launch script 10 | export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1} 11 | export CUDA_MODULE_LOADING=LAZY # set to EAGER if the consumer kernel is launched before the producer kernel on the host 12 | 13 | # set default communication env vars 14 | export BYTED_TORCH_BYTECCL=O0 15 | export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:=23} 16 | 17 | nproc_per_node=$(nvidia-smi --list-gpus | wc -l) 18 | nnodes=1 19 | node_rank=0 20 | master_addr="127.0.0.1" 21 | master_port="23456" 22 | additional_args="--rdzv_endpoint=${master_addr}:${master_port}" 23 | IB_HCA=mlx5 24 | 25 | 26 | export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:=3} 27 | export NVSHMEM_IB_GID_INDEX=3 28 | 29 | 30 | CMD="torchrun \ 31 | --node_rank=${node_rank} \ 32 | --nproc_per_node=${nproc_per_node} \ 33 | --nnodes=${nnodes} \ 34 | ${FLUX_EXTRA_TORCHRUN_ARGS} ${additional_args} $@" 35 | 36 | echo ${CMD} 37 | ${CMD} 38 | 39 | ret=$?
40 | exit $ret 41 | -------------------------------------------------------------------------------- /src/moe_ag_scatter/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 4 | execute_process( 5 | COMMAND ${FLUX_GENERATOR_BINARY_DIR}/gen_moe_ag_scatter "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 6 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 7 | COMMAND_ERROR_IS_FATAL ANY 8 | COMMAND_ECHO STDOUT 9 | ) 10 | 11 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 12 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 13 | set(CU_FILES 14 | ${OP_REGS} 15 | ${TUNING_CONFIGS} 16 | sort_util.cu 17 | triton_util.cu 18 | workspace_util.cu 19 | ) 20 | 21 | set(LIB_NAME "flux_cuda_moe_ag_scatter") 22 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 23 | 24 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 25 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 26 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 27 | 28 | if (BUILD_THS) 29 | file(GLOB THS_FILES ths_op/*.cc) 30 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 31 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 32 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 33 | flux_add_ths_op_target("gemm_grouped_v2_ag_scatter;gemm_grouped_v3_ag_scatter") 34 | endif() 35 | 36 | if (BUILD_TEST) 37 | add_subdirectory(test) 38 | endif() 39 | -------------------------------------------------------------------------------- /src/gemm_a2a_transpose/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | set(OP_REGS "") 4 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 5 | execute_process( 6 | COMMAND ${FLUX_GENERATOR_BINARY_DIR}/gen_gemm_a2a_transpose "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 8 | COMMAND_ERROR_IS_FATAL ANY 9 | COMMAND_ECHO STDOUT 10 | ) 11 | 12 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 13 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 14 | set(CU_FILES 15 | pre_attn_a2a_transpose_impls.cu 16 | pre_attn_qkv_pack_a2a_impls.cu 17 | ${OP_REGS} 18 | ${TUNING_CONFIGS} 19 | ) 20 | 21 | set(LIB_NAME "flux_cuda_gemm_a2a_transpose") 22 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 23 | target_compile_options(${LIB_NAME} PRIVATE $<$:-rdc=true>) 24 | 25 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 26 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 27 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 28 | 29 | if (BUILD_THS) 30 | file(GLOB THS_FILES ths_op/*.cc) 31 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 32 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 33 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 34 | flux_add_ths_op_target("gemm_a2a_transpose") 35 | endif() -------------------------------------------------------------------------------- /src/gemm_rs/ring_reduce.hpp: -------------------------------------------------------------------------------- 1 | //===- ring_reduce.hpp -------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | 23 | #include "flux/flux.h" 24 | 25 | namespace bytedance { 26 | namespace flux { 27 | 28 | void ring_reduce( 29 | void *input, 30 | void *output, 31 | int32_t rank, 32 | int32_t node_num, 33 | int32_t world_size, 34 | int32_t chunk_size, 35 | DataTypeEnum dtype, 36 | cudaStream_t stream); 37 | 38 | } 39 | } // namespace bytedance 40 | -------------------------------------------------------------------------------- /src/gemm_rs/bsr_reduce.hpp: -------------------------------------------------------------------------------- 1 | //===- bsr_reduce.hpp --------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | 25 | #include "flux/flux.h" 26 | 27 | namespace bytedance { 28 | namespace flux { 29 | 30 | void bsr2dense_reduce( 31 | void *input, 32 | void *output, 33 | std::vector shape, 34 | int block_h, 35 | int block_w, 36 | DataTypeEnum dtype, 37 | cudaStream_t stream); 38 | 39 | } 40 | } // namespace bytedance 41 | -------------------------------------------------------------------------------- /src/gemm_rs/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | 3 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/registers") 4 | execute_process( 5 | COMMAND ${FLUX_GENERATOR_BINARY_DIR}/gen_gemm_rs "--dir=./registers" "--archs=${CUDAARCHS}" "--sm_cores=${SM_CORES}" 6 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 7 | COMMAND_ERROR_IS_FATAL ANY 8 | COMMAND_ECHO STDOUT 9 | ) 10 | 11 | file(GLOB OP_REGS ${CMAKE_CURRENT_BINARY_DIR}/registers/*.cu) 12 | file(GLOB TUNING_CONFIGS tuning_config/*.cu) 13 | set(CU_FILES 14 | bsr_reduce.cu 15 | padding_util.cu 16 | ring_reduce.cu 17 | ${OP_REGS} 18 | ${TUNING_CONFIGS} 19 | ) 20 | 21 | set(LIB_NAME "flux_cuda_reduce_scatter") 22 | flux_add_op_cu_obj_lib(${LIB_NAME} "${CU_FILES}") 23 | target_compile_options(${LIB_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>) 24 | 25 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS}) 26 | list(APPEND FLUX_CUDA_OP_TARGETS ${LIB_NAME}) 27 | set(FLUX_CUDA_OP_TARGETS ${FLUX_CUDA_OP_TARGETS} PARENT_SCOPE) 28 | 29 | if (BUILD_THS) 30 | file(GLOB THS_FILES ths_op/*.cc) 31 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES}) 32 | list(APPEND FLUX_THS_OP_FILES ${THS_FILES}) 33 | set(FLUX_THS_OP_FILES ${FLUX_THS_OP_FILES} PARENT_SCOPE) 34 | flux_add_ths_op_target("gemm_rs") 35 | endif() 36 | 37 | if (BUILD_TEST) 38 | add_subdirectory(test) 39 | endif() 40 | -------------------------------------------------------------------------------- /include/flux/args/gemm_a2a_transpose.h: -------------------------------------------------------------------------------- 1 | //===- gemm_a2a_transpose.h --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | namespace bytedance::flux { 21 | 22 | struct GemmAllToAllTransposeArguments { 23 | int m; 24 | int n; 25 | int k; 26 | int nnodes; 27 | int rank; 28 | int world_size; 29 | float alpha; 30 | float beta; 31 | void const *input; 32 | void const *weight; 33 | void const *bias; 34 | void *gemm_output; 35 | void **barrier_ptrs; 36 | int32_t sm_margin; 37 | }; 38 | 39 | } // namespace bytedance::flux 40 | -------------------------------------------------------------------------------- /test/python/util/cuda_kernels/copy_kernel.cu: -------------------------------------------------------------------------------- 1 | //===- copy_kernel.cu ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include 19 | 20 | extern "C" __global__ void 21 | copy_kernel(void *__restrict__ to_ptr, void *__restrict__ from_ptr, size_t nbytes) { 22 | size_t elems = nbytes / sizeof(uint4); 23 | uint4 *dst_ptr = (uint4 *)to_ptr; 24 | uint4 *src_ptr = (uint4 *)from_ptr; 25 | #pragma unroll(4) 26 | for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < elems; 27 | tid += blockDim.x * gridDim.x) { 28 | dst_ptr[tid] = src_ptr[tid]; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/ag_gemm/all_gather_swizzle.hpp: -------------------------------------------------------------------------------- 1 | //===- all_gather_swizzle.hpp ------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | namespace bytedance::flux { 21 | 22 | constexpr static int kNodes = 4; 23 | 24 | struct NodeSwizzle { 25 | int swizzle[4 * 4]; 26 | }; 27 | 28 | constexpr static __device__ NodeSwizzle nodes_swizzle[] = { 29 | {0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, // 1x1 30 | {0, 1, 1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, // 2x2 31 | {0, 3, 2, 1, 1, 0, 3, 2, 2, 1, 0, 3, 3, 2, 1, 0}}; // 4x4 32 | 33 | } // namespace bytedance::flux 34 | -------------------------------------------------------------------------------- /src/gemm_rs/ths_op/helper_ops.h: -------------------------------------------------------------------------------- 1 | //===- helper_ops.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | 22 | namespace bytedance::flux::ths_op { 23 | void bsr_reduce(torch::Tensor input, torch::Tensor output, int block_h, int block_w); 24 | std::pair> pad_m_to_TPxTile( 25 | torch::Tensor input, c10::optional input_scale, int tp_size, int tile_size); 26 | void ring_reduce(torch::Tensor input, torch::Tensor output, int32_t dim, int rank); 27 | } // namespace bytedance::flux::ths_op 28 | -------------------------------------------------------------------------------- /include/flux/args/a2a_transpose_gemm.h: -------------------------------------------------------------------------------- 1 | //===- a2a_transpose_gemm.h --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | namespace bytedance::flux { 21 | 22 | struct A2ATransposeGemmKernelArguments { 23 | int m; 24 | int n; 25 | int k; 26 | int m_per_barrier; 27 | int sm_margin; 28 | int rank; 29 | int world_size; 30 | int nnodes; 31 | float alpha; 32 | float beta; 33 | void *input; 34 | void const *weight; 35 | void const *bias; 36 | void *output; 37 | void *barrier_buffer; 38 | int32_t *a2a_signal; // wait until this signal is set 39 | }; 40 | 41 | } // namespace bytedance::flux 42 | -------------------------------------------------------------------------------- /src/comm_none/cutlass_blockscale_gemm_impl.h: -------------------------------------------------------------------------------- 1 | //===- cutlass_blockscale_gemm_impl.h ----------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "flux/args/comm_none.h" 19 | #include "flux/gemm_meta.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | 24 | struct CutlassBlockScaleGemm { 25 | CutlassBlockScaleGemm(const UnifiedGemmMeta meta, const GemmBlockScaleNEnum scale_type_b); 26 | 27 | void run(const BlockScaleGemmArguments &flux_args, void *workspace, cudaStream_t stream); 28 | 29 | size_t get_workspace_size(const BlockScaleGemmArguments &flux_args); 30 | 31 | const UnifiedGemmMeta meta_; 32 | GemmBlockScaleNEnum scale_type_b_; 33 | }; 34 | 35 | } // namespace flux 36 | } // namespace bytedance 37 | -------------------------------------------------------------------------------- /src/comm_none/ths_op/gemm_grouped_v3.h: -------------------------------------------------------------------------------- 1 | //===- gemm_grouped_v3.h ----------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | #include 21 | #include "flux/ths_op/ths_op.h" 22 | 23 | namespace bytedance::flux::ths_op { 24 | class GemmGroupedV3 { 25 | public: 26 | GemmGroupedV3(torch::Tensor weight, int64_t num_experts); 27 | ~GemmGroupedV3(); 28 | torch::Tensor forward(torch::Tensor input, torch::Tensor splits_cpu); 29 | torch::Tensor profiling( 30 | torch::Tensor input, torch::Tensor splits_cpu, c10::intrusive_ptr opt_ctx); 31 | 32 | private: 33 | class GemmGroupedV3Impl; 34 | GemmGroupedV3Impl *impl_ = nullptr; 35 | }; 36 | } // namespace bytedance::flux::ths_op 37 | -------------------------------------------------------------------------------- /src/gemm_rs/padding_util.hpp: -------------------------------------------------------------------------------- 1 | //===- padding_util.hpp --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | 23 | #include "flux/flux.h" 24 | 25 | namespace bytedance { 26 | namespace flux { 27 | 28 | // pad m-dim of the input tensor to be multiple of TPxtile_size 29 | void pad_m_to_TPxTile( 30 | void const *input, 31 | void const *scale, // shape: (m_size, 1) 32 | void *output, 33 | void *padded_scale, 34 | int m_size, 35 | int n_size, 36 | int tp_size, 37 | int tile_size, 38 | DataTypeEnum input_dtype, 39 | DataTypeEnum scale_dtype, 40 | cudaStream_t stream); 41 | } // namespace flux 42 | } // namespace bytedance 43 | -------------------------------------------------------------------------------- /include/flux/cuda/helper_kernels.h: -------------------------------------------------------------------------------- 1 | //===- helper_kernels.h ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include "flux/flux.h" 21 | #include 22 | 23 | namespace bytedance::flux { 24 | 25 | bool bitwise_check(DataTypeEnum dtype, void *ptr_A, void *ptr_B, size_t capacity); 26 | 27 | void uniform_initialize( 28 | DataTypeEnum dtype, 29 | void *ptr, 30 | size_t capacity, 31 | uint64_t seed, 32 | double min = 0.0, 33 | double max = 1.0, 34 | void *stream = nullptr); 35 | 36 | void cudaipc_barrier_all_on_stream_impl( 37 | cudaStream_t stream, 38 | int32_t **sync_buffer_ptr, 39 | int rank, 40 | int world_size, 41 | bool ring_mode = false); 42 | } // namespace bytedance::flux 43 | -------------------------------------------------------------------------------- /test/python/util/test_uniform_initialize.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ################################################################################ 17 | 18 | import argparse 19 | 20 | import torch 21 | import torch.distributed 22 | 23 | import flux 24 | from flux.testing import DTYPE_MAP 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("-M", type=int, default=1024) 30 | parser.add_argument("-N", type=int, default=2048) 31 | parser.add_argument("--dtype", default="bfloat16", type=str, choices=list(DTYPE_MAP.keys())) 32 | 33 | return parser.parse_args() 34 | 35 | 36 | if __name__ == "__main__": 37 | args = parse_args() 38 | dtype = DTYPE_MAP[args.dtype] 39 | tensor = torch.zeros(args.M, args.N).cuda().to(dtype) 40 | flux.uniform_initialize(tensor, 2024, 0.0, 1.0) 41 | print(tensor) 42 | -------------------------------------------------------------------------------- /src/moe_ag_scatter/triton_util.h: -------------------------------------------------------------------------------- 1 | //===- triton_util.h ------------------------------------------- C++ ------===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | #include 19 | 20 | #include 21 | 22 | namespace bytedance::flux { 23 | void get_moe_ag_scatter_args( 24 | const int32_t *splits_gpu_ptr, 25 | const int32_t *cumsum_per_rank_gpu_ptr, 26 | void *problem_schedule_ptr, 27 | int num_scheds, 28 | int32_t *gather_index_ptr, 29 | int32_t *scatter_index_ptr, 30 | int ep_start, 31 | int ep_experts, 32 | int world_size, 33 | int M_this_ep, 34 | int tile_size_m, 35 | int32_t *m_pad_ptr, 36 | int32_t *gather_a_ptr, 37 | int32_t *scatter_d_ptr, 38 | int32_t *expert_idx_ptr, 39 | int32_t *rank_start_ptr, 40 | int32_t *rank_end_ptr, 41 | cudaStream_t stream); 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | if (WITH_PROTOBUF) 3 | add_compile_options("-DWITH_PROTOBUF") 4 | set(flux_proto_lib flux_proto ${Protobuf_LIBRARIES}) 5 | set(flux_proto_inc ${Protobuf_INCLUDE_DIRS}) 6 | else () 7 | set(flux_proto_lib) 8 | set(flux_proto_inc) 9 | endif() 10 | 11 | set(LIB_FILES 12 | op_registry.cu 13 | op_registry_proto_utils.cc 14 | cudaipc_barrier_all.cu 15 | bitwise_check.cu 16 | random_initialize.cu 17 | utils.cc 18 | cuda_common.cc 19 | cuda_common.cu 20 | cuda_stub.cc 21 | nvml_stub.cc 22 | moe_utils.cu 23 | ) 24 | 25 | set(_VER_FILE ${CMAKE_CURRENT_SOURCE_DIR}/version.ld) 26 | 27 | add_library(flux_cuda SHARED ${LIB_FILES}) 28 | set_target_properties(flux_cuda PROPERTIES 29 | CUDA_RESOLVE_DEVICE_SYMBOLS ON 30 | LINK_DEPENDS ${_VER_FILE} 31 | POSITION_INDEPENDENT_CODE ON 32 | ) 33 | 34 | if(ENABLE_NVSHMEM) 35 | target_link_libraries(flux_cuda 36 | PUBLIC -lnvshmem_host CUDA::cudart CUDA::cuda_driver ${FLUX_CUDA_OP_TARGETS} 37 | PRIVATE -lnvshmem_device ${flux_proto_lib} 38 | ) 39 | else() 40 | target_link_libraries(flux_cuda 41 | PUBLIC CUDA::cudart CUDA::cuda_driver ${FLUX_CUDA_OP_TARGETS} 42 | PRIVATE ${flux_proto_lib} 43 | ) 44 | endif() 45 | target_include_directories(flux_cuda PRIVATE ${flux_proto_inc}) 46 | target_link_options(flux_cuda 47 | PRIVATE 48 | $<$:LINKER:--exclude-libs=libflux_proto:libprotobuf 49 | LINKER:--version-script=${_VER_FILE} 50 | LINKER:--no-as-needed> 51 | ) 52 | install(TARGETS flux_cuda 53 | PUBLIC_HEADER DESTINATION include 54 | ) 55 | -------------------------------------------------------------------------------- /src/coll/all2all_impl.hpp: -------------------------------------------------------------------------------- 1 | //===- all2all_impl.hpp --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | 22 | namespace bytedance { 23 | namespace flux { 24 | 25 | struct All2allParams { 26 | void *input_ptr; 27 | void *output_ptr; 28 | int32_t *splits_input_buffer; 29 | int32_t *splits_output_buffer; 30 | uint64_t *signal_buffer; 31 | int32_t *input_splits_cumsum; 32 | float *scale_input_buffer; 33 | float *scale_output_buffer; 34 | int32_t rank; 35 | int32_t world_size; 36 | int32_t ndim; 37 | int32_t element_size; 38 | int32_t max_token; 39 | int32_t expert_per_rank; 40 | uint64_t signal_to_wait; 41 | bool with_scale; 42 | }; 43 | 44 | void all2all_impl(const All2allParams ¶ms, cudaStream_t stream); 45 | 46 | } // namespace flux 47 | } // namespace bytedance 48 | -------------------------------------------------------------------------------- /src/comm_none/ths_op/gemm_grouped_v2.h: -------------------------------------------------------------------------------- 1 | //===- gemm_grouped_v2.h ----------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | #include 21 | 22 | namespace bytedance::flux::ths_op { 23 | class GemmGroupedV2 { 24 | public: 25 | GemmGroupedV2( 26 | torch::Tensor weight, int64_t num_experts, at::ScalarType in_type, at::ScalarType out_type); 27 | ~GemmGroupedV2(); 28 | torch::Tensor forward( 29 | torch::Tensor input, 30 | torch::Tensor splits_cpu, 31 | c10::optional> input_scale, 32 | c10::optional> weight_scale, 33 | c10::optional output_scale, 34 | bool fast_accum, 35 | int64_t sm_margin); 36 | 37 | private: 38 | class GemmGroupedV2Impl; 39 | GemmGroupedV2Impl *impl_ = nullptr; 40 | }; 41 | } // namespace bytedance::flux::ths_op 42 | -------------------------------------------------------------------------------- /include/flux/cuda/moe_utils.h: -------------------------------------------------------------------------------- 1 | //===- moe_utils.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | 19 | #include 20 | 21 | namespace bytedance::flux { 22 | /** 23 | * @brief a non-deterministic way to calculate scatter_index from choosed_experts. 24 | * 25 | * @param[in] choosed_experts : of topk * ntokens 26 | * @param[in] count : count per expert. 27 | * @param[out] scatter_index : of topk * ntokens 28 | * @param[in] total_num : topk * ntokens 29 | * @param[in] expert_num 30 | * @param[in] stream 31 | */ 32 | void calc_scatter_index( 33 | const int *choosed_experts, // of total_num 34 | const int *count, // of expert_num 35 | int *scatter_index, // of total_num 36 | const int total_num, // topk * ntokens 37 | int expert_num, 38 | cudaStream_t stream); 39 | 40 | } // namespace bytedance::flux 41 | -------------------------------------------------------------------------------- /src/coll/all2all_single_2d_impl.hpp: -------------------------------------------------------------------------------- 1 | //===- all2all_single_2d_impl.hpp --------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | #include 21 | #include 22 | namespace bytedance { 23 | namespace flux { 24 | 25 | struct All2AllSingleParams { 26 | void *input_comm_ptr; // symm buf 27 | void *output_comm_ptr; // symm buf 28 | void *output_ptr; // normal buf 29 | uint64_t *barrier_ptr; 30 | int32_t *input_splits; 31 | int32_t *output_splits; 32 | int64_t n_dim; 33 | int64_t max_split; 34 | 35 | int32_t rank; 36 | int32_t local_rank; 37 | int32_t local_world_size; 38 | int32_t world_size; 39 | int32_t nvshmem_team; 40 | }; 41 | 42 | void a2a_single_impl( 43 | const All2AllSingleParams params, 44 | DataTypeEnum input_dtype, 45 | int32_t num_comm_sm, 46 | cudaStream_t stream); 47 | } // namespace flux 48 | } // namespace bytedance 49 | -------------------------------------------------------------------------------- /test/python/util/cuda_kernels/reduce_kernel.cu: -------------------------------------------------------------------------------- 1 | //===- reduce_kernel.cu ----------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include 19 | #include 20 | 21 | __device__ __forceinline__ void 22 | global_red(half2 const &D, void *ptr) { 23 | uint32_t const &data = reinterpret_cast<uint32_t const &>(D); 24 | asm volatile( 25 | "{\n" 26 | " red.global.sys.add.noftz.f16x2 [%0], %1;\n" 27 | "}\n" 28 | : 29 | : "l"(ptr), "r"(data)); 30 | } 31 | 32 | extern "C" __global__ void 33 | reduce_kernel(void *dst, const void *src, size_t nbytes) { 34 | constexpr int kElemsPerVec = sizeof(half2); 35 | size_t elems = nbytes / kElemsPerVec; 36 | half2 *src_vec = (half2 *)src; 37 | half2 *dst_vec = (half2 *)dst; 38 | 39 | #pragma unroll(8) 40 | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < elems; i += gridDim.x * blockDim.x) { 41 | global_red(src_vec[i], dst_vec + i); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/coll/local_copy_and_reset.hpp: -------------------------------------------------------------------------------- 1 | //===- local_copy_and_reset.hpp ----------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | 24 | void local_copy_and_reset_impl( 25 | void *input_src, 26 | void *input_dst, 27 | void *scale_src, 28 | void *scale_dst, 29 | int32_t *counter, 30 | int32_t *ag_barrier, 31 | int32_t world_size, 32 | int32_t rank, 33 | int32_t m, 34 | int32_t n, 35 | int32_t **sync_barriers, // a list of barrier pointers to sync between devices; if set to 36 | // nullptr, the sync is not performed. 37 | DataTypeEnum input_dtype, 38 | DataTypeEnum scale_dtype, 39 | bool sync_ring_mode, // sync in ring_mode or not. 40 | cudaStream_t stream); 41 | 42 | size_t get_local_copy_max_block_num(size_t num_input, int32_t pack_size = 1); 43 | } // namespace flux 44 | } // namespace bytedance 45 | -------------------------------------------------------------------------------- /src/coll/ths_op/all2all_single_2d.h: -------------------------------------------------------------------------------- 1 | //===- all2all_single_2d.h ---------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | #include "flux/ths_op/flux_shm.h" 23 | namespace bytedance::flux::ths_op { 24 | class All2AllSingle { 25 | public: 26 | All2AllSingle( 27 | std::shared_ptr pg, 28 | int64_t max_split, 29 | int64_t n_dim, 30 | int64_t local_world_size, 31 | at::ScalarType input_dtype, 32 | int64_t ep_team); 33 | 34 | ~All2AllSingle(); 35 | 36 | torch::Tensor forward( 37 | torch::Tensor input, 38 | torch::Tensor output, 39 | torch::Tensor input_splits, 40 | torch::Tensor output_splits, 41 | int32_t num_comm_sm); 42 | 43 | private: 44 | class All2AllSingleImpl; 45 | All2AllSingleImpl *impl_; 46 | }; 47 | 48 | } // namespace bytedance::flux::ths_op 49 | -------------------------------------------------------------------------------- /src/pybind/inplace_cast.cc: -------------------------------------------------------------------------------- 1 | //===- inplace_cast.cc -------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "inplace_cast/ths_op/inplace_cast.h" 19 | 20 | #include "comm_none/ths_op/gemm_only.h" 21 | #include "flux/ths_op/ths_pybind.h" 22 | #include "inplace_cast/ths_op/helper_ops.h" 23 | 24 | namespace bytedance::flux::ths_op { 25 | 26 | namespace py = pybind11; 27 | using InplaceCastCls = TorchClassWrapper<InplaceCast>; 28 | 29 | static int _register_gemm_only_ops [[maybe_unused]] = []() { 30 | ThsOpsInitRegistry::instance().register_one("inplace_cast", [](py::module &m) { 31 | py::class_<InplaceCastCls>(m, "InplaceCast") 32 | .def( 33 | py::init([](int32_t data_size) { return new InplaceCast(data_size); }), 34 | py::arg("data_size")) 35 | .def("from_fp32_to_bf16", &InplaceCast::from_fp32_to_bf16, py::arg("input")); 36 | m.def("inplace_cast_fp32_to_bf16", &inplace_cast_fp32_to_bf16); 37 | }); 38 | return 0; 39 | }(); 40 | } // namespace bytedance::flux::ths_op 41 | -------------------------------------------------------------------------------- /src/moe_gather_rs/topk_gather_rs.hpp: -------------------------------------------------------------------------------- 1 | //===- topk_gather_rs.hpp --------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates.
All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | 23 | #include "flux/args/moe_gather_rs.h" 24 | #include "flux/flux.h" 25 | 26 | namespace bytedance { 27 | namespace flux { 28 | 29 | void topk_gather_rs( 30 | TopKReduceGatherRSArguments const &args, DataTypeEnum dtype, cudaStream_t stream); 31 | 32 | void topk_gather_rs_v2( 33 | TopKReduceGatherRSV2Arguments const &args, DataTypeEnum dtype, cudaStream_t stream); 34 | 35 | void ep_topk_gather_rs( 36 | TopKReduceGatherRSArguments const &args, 37 | DataTypeEnum dtype, 38 | int32_t ep_m_start, 39 | int32_t ep_m_end, 40 | cudaStream_t stream); 41 | 42 | void ep_topk_gather_rs_v2( 43 | TopKReduceGatherRSV2Arguments const &args, 44 | DataTypeEnum dtype, 45 | int32_t ep_m_start, 46 | int32_t ep_m_end, 47 | cudaStream_t stream); 48 | } // namespace flux 49 | } // namespace bytedance 50 | -------------------------------------------------------------------------------- /src/gemm_a2a_transpose/pre_attn_a2a_transpose_impls.hpp: -------------------------------------------------------------------------------- 1 | //===- pre_attn_a2a_transpose_impls.hpp ------------------------ C++ ------===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | struct PreAttnAll2AllTransposeParam { 24 | void **input_ptrs; 25 | void *output_ptr; 26 | void *barrier_ptrs[kMaxWorldSize]; 27 | int32_t bs; 28 | int32_t local_nheads; 29 | int32_t local_seq_len; 30 | int32_t head_dim; 31 | int32_t rank; 32 | int32_t world_size; 33 | int32_t TILE_M; // along the seq dim of the output, ensure that local_seq_len % TILE_M == 0; 34 | int32_t TILE_N; // along the head_dim and local_nheads dim of output, ensure that 35 | // TILE_N % head_dim == 0 36 | int32_t m_per_barrier; 37 | int32_t n_per_barrier; 38 | int32_t NUM_COMM_SM; 39 | }; 40 | 41 | void pre_attn_all2all_transpose_impl( 42 | const PreAttnAll2AllTransposeParam param, DataTypeEnum input_dtype, cudaStream_t stream); 43 | 44 | } // namespace flux 45 | } // namespace bytedance 46 | -------------------------------------------------------------------------------- /src/moe_gather_rs/ths_op/moe_utils.h: -------------------------------------------------------------------------------- 1 | //===- moe_utils.h ------------------------------------------------ C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | 22 | namespace bytedance::flux::ths_op { 23 | 24 | class TransportOp { 25 | public: 26 | TransportOp(int64_t rank, int64_t world_size, torch::Tensor recv_buffer); 27 | 28 | void copy_by_sm( 29 | torch::Tensor send_buffer, torch::Tensor transport_offsets, torch::Tensor transport_nbytes); 30 | 31 | void copy_by_ce( 32 | torch::Tensor send_buffer, torch::Tensor transport_offsets, torch::Tensor transport_nbytes); 33 | 34 | private: 35 | class TransportOpImpl; 36 | TransportOpImpl *impl_; 37 | }; 38 | 39 | class All2AllOp { 40 | public: 41 | All2AllOp(int64_t rank, int64_t world_size, torch::Tensor recv_buffer); 42 | ~All2AllOp(); 43 | 44 | void forward(c10::List send_buffer); 45 | 46 | private: 47 | class All2AllOpImpl; 48 | All2AllOpImpl *impl_; 49 | }; 50 | } // namespace bytedance::flux::ths_op 51 | -------------------------------------------------------------------------------- /src/gemm_a2a_transpose/pre_attn_qkv_pack_a2a_impls.hpp: -------------------------------------------------------------------------------- 1 | //===- pre_attn_qkv_pack_a2a_impls.hpp ------------------------- C++ ------===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | struct PreAttnQKVPackA2AParams { 24 | void **input_ptrs; 25 | void *q_ptr; 26 | void *k_ptr; 27 | void *v_ptr; 28 | void *barrier_ptrs[kMaxWorldSize]; 29 | int32_t bs; 30 | int32_t local_q_nheads; 31 | int32_t local_k_nheads; 32 | int32_t local_v_nheads; 33 | int32_t head_dim; 34 | int32_t rank; 35 | int32_t world_size; 36 | int32_t TILE_S; // tile size of output(qkv pack) seq dim 37 | int32_t TILE_NH; // tile size of output(qkv pack) nheads dim 38 | int32_t m_per_barrier; 39 | int32_t n_per_barrier; 40 | int32_t num_comm_sm; 41 | int32_t cusum_seq_lens[kMaxWorldSize + 1]; 42 | bool skip_barrier = false; 43 | }; 44 | 45 | void pre_attn_qkv_pack_a2a_impl( 46 | const PreAttnQKVPackA2AParams params, DataTypeEnum input_dtype, cudaStream_t stream); 47 | 48 | } // namespace flux 49 | } // namespace bytedance 50 | -------------------------------------------------------------------------------- /src/coll/ths_op/isendrecv.h: -------------------------------------------------------------------------------- 1 | //===- isendrecv.h -------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | namespace bytedance::flux::ths_op { 23 | class AsyncSendRecv { 24 | public: 25 | AsyncSendRecv( 26 | int64_t max_m, 27 | int64_t n_dim, 28 | int64_t rank, // rank in pp 29 | int64_t world_size, // world_size of pp 30 | at::ScalarType input_dtype, 31 | int64_t duplicate); 32 | 33 | ~AsyncSendRecv(); 34 | 35 | torch::Tensor get_comm_buffer(int64_t comm_buff_id); 36 | torch::Tensor read_comm_buffer( 37 | int64_t tgt_rank, int64_t src_comm_buff_id, int64_t tgt_comm_buff_id); 38 | void write_comm_buffer(int64_t tgt_rank, int64_t src_comm_buff_id, int64_t tgt_comm_buff_id); 39 | void set_signal(int64_t tgt_rank, int64_t comm_buffer_id, int64_t value); 40 | void wait_signal_eq(int64_t comm_buffer_id, int64_t value); 41 | void reset_signal(int64_t comm_buffer_id); 42 | 43 | private: 44 | class AsyncSendRecvOpImpl; 45 | AsyncSendRecvOpImpl *impl_; 46 | }; 47 | 48 | } // namespace bytedance::flux::ths_op 49 | -------------------------------------------------------------------------------- /src/comm_none/tuning_config/config_gemm_only_sm80_A100.cu: -------------------------------------------------------------------------------- 1 | //===- config_gemm_only_sm80_A100.cu ----------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | // clang-format off 19 | #include "flux/op_registry.h" 20 | namespace bytedance::flux { 21 | using namespace cute; 22 | 23 | static int config_gemm_only_sm80_a100 = []() { 24 | auto &inst = TuningConfigRegistry::instance(); 25 | inst.add(make_gemm_meta(make_gemm_dtype_config(_S8{}(),_S8{}(),_Void{}(),_BF16{}(), _S32{}()),_Sm80{}(),_A100{}(),_CommNone{}(),_RCR{}(),_GemmV2{}()),make_runtime_config(512,8192,1024),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64,64),cute::make_tuple(16l,8l,32l),_StreamkSK{}()),None{},cute::make_tuple(128l,128l,64l),_GemmStreamK{}(),5,_RasterAlongM{}())); 26 | inst.add(make_gemm_meta(make_gemm_dtype_config(_S8{}(),_S8{}(),_BF16{}(),_BF16{}(), _S32{}()),_Sm80{}(),_A100{}(),_CommNone{}(),_RCR{}(),_GemmV2{}()),make_runtime_config(512,8192,1024),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64,64),cute::make_tuple(16l,8l,32l),_StreamkSK{}()),None{},cute::make_tuple(128l,128l,64l),_GemmStreamK{}(),5,_RasterAlongM{}())); 27 | return 0; 28 | }(); 29 | } 30 | // clang-format on 31 | -------------------------------------------------------------------------------- /src/moe_gather_rs/ths_op/topk_reduce_gather_rs.cc: -------------------------------------------------------------------------------- 1 | //===- topk_reduce_gather_rs.cc ---------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "moe_gather_rs/ths_op/topk_reduce_gather_rs.h" 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | 27 | #include "flux/args/moe_gather_rs.h" 28 | #include "flux/ths_op/ths_op.h" 29 | #include "moe_gather_rs/topk_gather_rs.hpp" 30 | namespace bytedance::flux::ths_op { 31 | using torch::Tensor; 32 | 33 | void 34 | topk_reduce_gather_rs(TopKReduceGatherRSArguments const &args, torch::Tensor output) { 35 | topk_gather_rs(args, from_torch_dtype(output.scalar_type()), c10::cuda::getCurrentCUDAStream()); 36 | } 37 | 38 | void 39 | ep_topk_reduce_gather_rs( 40 | TopKReduceGatherRSArguments const &args, 41 | torch::Tensor output, 42 | int32_t ep_m_start, 43 | int ep_m_end) { 44 | ep_topk_gather_rs( 45 | args, 46 | from_torch_dtype(output.scalar_type()), 47 | ep_m_start, 48 | ep_m_end, 49 | c10::cuda::getCurrentCUDAStream()); 50 | } 51 | 52 | } // namespace bytedance::flux::ths_op 53 | -------------------------------------------------------------------------------- /src/moe_ag_scatter/tuning_config/config_ag_scatter_sm90_H800.cu: -------------------------------------------------------------------------------- 1 | //===- config_ag_scatter_sm90_H800.cu ---------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. 
and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | // clang-format off 19 | #include "flux/op_registry.h" 20 | namespace bytedance::flux { 21 | using namespace cute; 22 | 23 | static int config_ag_scatter_sm90_h800 = []() { 24 | auto &inst = TuningConfigRegistry::instance(); 25 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_BF16{}(),_BF16{}(),_FP32{}()),_Sm90{}(),_H800{}(),_AGScatter{}(),_RCR{}(),_GemmGroupedV3{}(),make_gemm_v3_meta(false),None{}),make_runtime_config(768,2048,5120,None{}),make_gemm_hparams(make_gemm_v3_hparams(cute::make_tuple(2l,1l,1l),_PingPong{}()),None{},cute::make_tuple(64l,256l,64l),_GemmDefault{}(),0,_RasterAlongM{}())); 26 | inst.add(make_gemm_meta(make_gemm_dtype_config(_E4M3{}(),_E4M3{}(),_BF16{}(),_BF16{}(),_FP32{}()),_Sm90{}(),_H800{}(),_AGScatter{}(),_RCR{}(),_GemmGroupedV3{}(),make_gemm_v3_meta(false),None{}),make_runtime_config(6144,288,6144,None{}),make_gemm_hparams(make_gemm_v3_hparams(cute::make_tuple(1l,1l,1l),_Cooperative{}()),None{},cute::make_tuple(128l,128l,128l),_GemmDefault{}(),0,_RasterAlongN{}())); 27 | return 0; 28 | }(); 29 | } 30 | // clang-format on 31 | -------------------------------------------------------------------------------- /src/coll/all_gather_impls.hpp: -------------------------------------------------------------------------------- 1 | //===- all_gather_impls.hpp --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | 24 | struct AllGatherParams { 25 | void *input_ptrs[kMaxWorldSize]; // input_ptrs[rank]: (m * world_size, n) 26 | void *scale_ptrs[kMaxWorldSize]; 27 | int32_t *ag_barriers[kMaxWorldSize]; // ag signal 28 | int32_t *counter; // sync between block to write signal 29 | 30 | int32_t world_size; 31 | int32_t rank; 32 | int sub_world_size; 33 | int32_t m; // input.size(0), actually m_per_rank 34 | int32_t n; // input.size(1) 35 | bool has_scale; 36 | int32_t *ag_signal; // the signal to ensure that GEMM is launched after AllGather 37 | }; 38 | 39 | void ag_a2a_mode( 40 | const AllGatherParams ¶ms, 41 | DataTypeEnum input_dtype, 42 | DataTypeEnum scale_dtype, 43 | cudaStream_t stream); 44 | 45 | void ag_ring_with_scale( 46 | const AllGatherParams ¶ms, 47 | int input_elem_size, 48 | int scale_elem_size, 49 | int num_grids, 50 | bool use_2d_mode, 51 | cudaStream_t stream); 52 | 53 | } // namespace flux 54 | } // namespace bytedance 55 | -------------------------------------------------------------------------------- /src/ag_gemm/ths_op/all_gather_gemm_op_internode.h: -------------------------------------------------------------------------------- 1 | //===- all_gather_gemm_op_internode.h ----------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | #include 19 | #include 20 | #include "coll/ths_op/all_gather_types.h" 21 | 22 | namespace bytedance::flux::ths_op { 23 | 24 | class AllGatherGemmOpInterNode { 25 | public: 26 | AllGatherGemmOpInterNode( 27 | std::shared_ptr tp_group, 28 | std::shared_ptr intra_node_group, 29 | int32_t nnodes, 30 | torch::Tensor output_buffer, 31 | int32_t full_m, 32 | int32_t n_dim, 33 | int32_t k_dim, 34 | c10::ScalarType input_dtype, 35 | bool transpose_weight = true, 36 | bool local_copy = false, 37 | c10::optional ring_mode_ = c10::nullopt); 38 | ~AllGatherGemmOpInterNode(); 39 | void reset_signals(); 40 | void copy_local(torch::Tensor input); 41 | torch::Tensor gemm_only(torch::Tensor input, torch::Tensor full_input, torch::Tensor weight); 42 | torch::Tensor forward(torch::Tensor input, torch::Tensor weight); 43 | 44 | private: 45 | class AllGatherGemmOpInterNodeImpl; 46 | AllGatherGemmOpInterNodeImpl *impl_ = nullptr; 47 | }; 48 | } // namespace bytedance::flux::ths_op 49 | -------------------------------------------------------------------------------- /src/quantization/quantization.hpp: -------------------------------------------------------------------------------- 1 | //===- quantization.hpp ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | 24 | void block_scaled_1d_cast_transpose_impl( 25 | void *input, 26 | void *output, 27 | void *output_t, 28 | void *scale_inv, 29 | void *scale_inv_t, 30 | const size_t row_length, 31 | const size_t num_rows, 32 | const size_t scale_stride_x, 33 | const size_t scale_stride_y, 34 | const size_t scale_t_stride_x, 35 | const size_t scale_t_stride_y, 36 | const float epsilon, 37 | dim3 grid, 38 | const bool return_transpose, 39 | cudaStream_t stream); 40 | 41 | void block_scaled_cast_transpose_impl( 42 | void *input, 43 | void *output, 44 | void *output_t, 45 | void *scale_inv, 46 | void *scale_inv_t, 47 | const size_t row_length, 48 | const size_t num_rows, 49 | const size_t scale_stride_x, 50 | const size_t scale_stride_y, 51 | const size_t scale_t_stride_x, 52 | const size_t scale_t_stride_y, 53 | const float epsilon, 54 | dim3 grid, 55 | const bool return_transpose, 56 | cudaStream_t stream); 57 | 58 | } // namespace flux 59 | } // namespace bytedance 60 | -------------------------------------------------------------------------------- /src/inplace_cast/ths_op/helper_ops.cc: -------------------------------------------------------------------------------- 1 | //===- helper_ops.cc ---------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include 19 | #include "flux/cuda/cuda_common.h" 20 | #include "inplace_cast/ths_op/helper_ops.h" 21 | #include "inplace_cast/inplace_cast.hpp" 22 | 23 | namespace bytedance::flux::ths_op { 24 | 25 | void 26 | inplace_cast_fp32_to_bf16(torch::Tensor data) { 27 | int block_size = INPLACE_CAST_BLOCK_SIZE; 28 | size_t data_size = data.numel(); 29 | 30 | size_t num_chunks = 31 | (data_size + block_size * INPLACE_CAST_TS - 1) / (block_size * INPLACE_CAST_TS); 32 | 33 | unsigned *flags; 34 | CUDA_CHECK(cudaMalloc(&flags, num_chunks * sizeof(unsigned))); 35 | CUDA_CHECK(cudaMemset(flags, 0, num_chunks * sizeof(unsigned))); 36 | 37 | unsigned *chunk_counter; 38 | CUDA_CHECK(cudaMalloc(&chunk_counter, sizeof(unsigned))); 39 | CUDA_CHECK(cudaMemset(chunk_counter, 0, sizeof(unsigned))); 40 | 41 | inplace_cast_fp32_to_bf16_impl( 42 | data.data_ptr(), 43 | data_size, 44 | flags, 45 | chunk_counter, 46 | c10::cuda::getCurrentCUDAStream(), 47 | INPLACE_CAST_NUM_BLOCKS, 48 | block_size); 49 | 50 | cudaFree(flags); 51 | cudaFree(chunk_counter); 52 | } 53 | 54 | } // namespace bytedance::flux::ths_op 55 | -------------------------------------------------------------------------------- /src/coll/ths_op/all2all_op.h: -------------------------------------------------------------------------------- 1 | //===- all2all_op.h -------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | #include 23 | namespace bytedance::flux::ths_op { 24 | class All2AllInference { 25 | public: 26 | All2AllInference( 27 | int64_t max_m, 28 | int64_t n_dim, 29 | int64_t rank, 30 | int64_t total_num_experts, 31 | int64_t world_size, 32 | int64_t local_world_size, 33 | int64_t max_element_size); 34 | 35 | ~All2AllInference(); 36 | std::vector get_input_buffer( 37 | std::vector input_shape, int64_t element_size, bool with_scale); 38 | std::vector forward( 39 | std::vector input_size, 40 | torch::Tensor input_split_cumsum, 41 | int64_t element_size, 42 | bool with_scale); 43 | std::vector forward_with_stream( 44 | std::vector input_size, 45 | torch::Tensor input_split_cumsum, 46 | int64_t element_size, 47 | bool with_scale, 48 | cudaStream_t stream); 49 | 50 | private: 51 | class All2AllInferenceOpImpl; 52 | All2AllInferenceOpImpl *impl_; 53 | }; 54 | 55 | } // namespace bytedance::flux::ths_op 56 | -------------------------------------------------------------------------------- /src/quantization/ths_op/quantization.h: -------------------------------------------------------------------------------- 1 | //===- quantization.h --------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | 22 | namespace bytedance::flux::ths_op { 23 | class Quantization { 24 | public: 25 | Quantization(c10::ScalarType input_dtype, c10::ScalarType output_dtype, int32_t num_streams); 26 | ~Quantization(); 27 | 28 | std::tuple< 29 | torch::Tensor, 30 | torch::Tensor, 31 | c10::optional, 32 | c10::optional> 33 | quantize_vector_blockwise(torch::Tensor input, bool return_transpose, float eps = 0.0f); 34 | 35 | std::tuple< 36 | torch::Tensor, 37 | torch::Tensor, 38 | c10::optional, 39 | c10::optional> 40 | quantize_square_blockwise(torch::Tensor input, bool return_transpose, float eps = 0.0f); 41 | 42 | std::tuple< 43 | torch::Tensor, 44 | torch::Tensor, 45 | c10::optional, 46 | c10::optional> 47 | batch_quantize_square_blockwise(torch::Tensor input, bool return_transpose, float eps = 0.0f); 48 | 49 | private: 50 | class QuantizationImpl; 51 | QuantizationImpl *impl = nullptr; 52 | }; 53 | 54 | } // namespace bytedance::flux::ths_op 55 | -------------------------------------------------------------------------------- /src/moe_ag_scatter/dispatch_policy.hpp: -------------------------------------------------------------------------------- 1 | //===- dispatch_policy.hpp ------------------------------------- C++ ------===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 
4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | #include "cutlass/arch/arch.h" 21 | #include "cutlass/gemm/gemm.h" 22 | 23 | #include "cute/layout.hpp" 24 | #include "cute/numeric/integral_constant.hpp" 25 | ////////////////////////////////////////////////////////////////////////////// 26 | 27 | namespace cutlass::gemm { 28 | using namespace cute; 29 | 30 | struct KernelPtrArrayCpAsyncWarpSpecializedCooperative {}; 31 | struct KernelPtrArrayCpAsyncWarpSpecializedPingpong {}; 32 | 33 | template < 34 | int Stages_, 35 | class ClusterShape_ = Shape<_1, _1, _1>, 36 | class KernelSchedule = KernelPtrArrayCpAsyncWarpSpecializedCooperative> 37 | struct MainloopSm90ArrayCpAsyncGmmaWarpSpecialized { 38 | constexpr static int Stages = Stages_; 39 | using ClusterShape = ClusterShape_; 40 | using ArchTag = arch::Sm90; 41 | using Schedule = KernelSchedule; 42 | static_assert( 43 | cute::is_base_of_v or 44 | cute::is_base_of_v, 45 | "KernelSchedule must be one of the Ptr-Array or Grouped Gemm Cp async Warp Specialized " 46 | "Cooperative or Pingpong policies"); 47 | }; 48 | 49 | } // namespace cutlass::gemm 50 | -------------------------------------------------------------------------------- /include/flux/ths_op/topo_utils.h: -------------------------------------------------------------------------------- 1 | //===- topo_utils.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/ths_op/flux_shm.h" 20 | #include 21 | 22 | namespace bytedance::flux::topo_utils { 23 | 24 | bool is_topo_initialized(); 25 | /** 26 | * calling this function multiple times only produces warnings; the initialization really runs only once 27 | * @param group: this should be a local group.
if not, split it into a local group from outside first 28 | */ 29 | void initialize_topo(bytedance::flux::Group *pg); 30 | void initialize_topo(const std::vector &device_ids); 31 | 32 | // whether any NVLink-capable GPU exists 33 | bool has_nvlink(); 34 | // whether an NVSwitch is present (i.e. all GPUs are connected to each other via NVLink) 35 | bool has_nvswitch(); 36 | // has NVLink but is not all-to-all connected, such as the A100 PCIe version with NVLink 37 | bool has_heterogeneous_nvlink(); 38 | // NVLink world_size: with an NVSwitch this equals the local world_size; otherwise it is the size 39 | // of the P2P-connected cluster 40 | int topo_nvlink_local_world_size(); 41 | 42 | // has PCIe GPUs that are not all under the same NUMA node 43 | bool has_heterogeneous_pcie(); 44 | // the number of GPUs under the same NUMA node 45 | // NOTE: the same number of GPUs under each NUMA node is expected 46 | int topo_numa_local_world_size(); 47 | } // namespace bytedance::flux::topo_utils 48 | -------------------------------------------------------------------------------- /src/moe_gather_rs/ths_op/topk_scatter_reduce.cc: -------------------------------------------------------------------------------- 1 | //===- topk_scatter_reduce.cc ------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
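For illustration only (not part of the repository), here is a sketch of how a caller might combine the `topo_utils` predicates declared above to classify the intra-node link topology; the enum and helper are invented for the example, and the element type of the device-id vector is assumed to be `int`.

```cpp
#include <vector>
#include "flux/ths_op/topo_utils.h"

// Hypothetical helper: collapse the detected intra-node topology into a coarse category.
enum class LinkKind { NvSwitch, PartialNvLink, PcieSameNuma, PcieCrossNuma };

inline LinkKind detect_link_kind(const std::vector<int> &local_device_ids) {
  using namespace bytedance::flux::topo_utils;
  if (!is_topo_initialized()) {
    initialize_topo(local_device_ids);  // detection runs once; later calls would only warn
  }
  if (has_nvswitch()) return LinkKind::NvSwitch;                   // fully NVLink-connected
  if (has_heterogeneous_nvlink()) return LinkKind::PartialNvLink;  // NVLink, but not all-to-all
  if (has_heterogeneous_pcie()) return LinkKind::PcieCrossNuma;    // GPUs span NUMA nodes
  return LinkKind::PcieSameNuma;
}
```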
4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "cute/container/tuple.hpp" 19 | #include "cutlass/numeric_conversion.h" 20 | #include "flux/cuda/cuda_common.h" 21 | #include "flux/flux.h" 22 | #include "cutlass/util/reference/device/tensor_fill.h" 23 | 24 | namespace bytedance::flux { 25 | 26 | void 27 | uniform_initialize( 28 | DataTypeEnum dtype, 29 | void *ptr, 30 | size_t capacity, 31 | uint64_t seed, 32 | double min, 33 | double max, 34 | void *stream = nullptr) { 35 | tuple_return_if( 36 | cute::make_tuple(_FP16{}, _BF16{}, _E4M3{}, _E5M2{}), 37 | [dtype](auto c_dtype) { return dtype == c_dtype; }, 38 | [&](auto c_dtype) { 39 | auto cu_stream = reinterpret_cast(stream); 40 | using Element = decltype(to_cutlass_element(c_dtype)); 41 | cutlass::NumericConverter converter; 42 | cutlass::reference::device::BlockFillRandomUniform( 43 | static_cast(ptr), 44 | capacity, 45 | seed, 46 | converter(max), 47 | converter(min), 48 | /*bits=*/-1, 49 | /*pnan=*/0, 50 | /*stream=*/cu_stream); 51 | }, 52 | [dtype]() { FLUX_CHECK(false) << "unsupported dtype: " << dtype; }); 53 | } 54 | 55 | } // namespace bytedance::flux 56 | -------------------------------------------------------------------------------- /include/flux/ths_op/util.h: -------------------------------------------------------------------------------- 1 | //===- util.h ----------------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
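As a purely illustrative aside (not repository code), host-side usage of the `uniform_initialize` defined above might look like the sketch below; the declaring header and the `_FP16{}()` dtype-tag call are assumptions based on how other files in this listing use these APIs.

```cpp
#include <cuda_runtime.h>
#include "flux/cuda/helper_kernels.h"  // assumed to declare uniform_initialize
#include "flux/flux.h"

// Hypothetical helper: fill a device fp16 buffer with uniform random values in [min, max).
void fill_uniform_fp16(void *device_ptr, size_t numel, cudaStream_t stream) {
  // _FP16{}() is assumed to yield the matching DataTypeEnum tag, following the
  // _BF16{}() pattern used in the tuning-config files of this listing.
  bytedance::flux::uniform_initialize(
      bytedance::flux::_FP16{}(), device_ptr, numel,
      /*seed=*/1234, /*min=*/-1.0, /*max=*/1.0, stream);
}
```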
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | #include "flux/flux.h" 21 | 22 | #define CHECK_TYPE(x, st) FLUX_CHECK_EQ(x.scalar_type(), st) << "Inconsistency type of Tensor " #x 23 | #define CHECK_CPU(x) FLUX_CHECK(x.is_cpu()) << #x << " must be a cpu tensor" 24 | #define CHECK_CUDA(x) FLUX_CHECK(x.is_cuda()) << #x << " must be a CUDA tensor" 25 | #define CHECK_CONTIGUOUS(x) FLUX_CHECK(x.is_contiguous()) << #x << " must be contiguous" 26 | #define CHECK_INPUT(x, st) \ 27 | CHECK_CUDA(x); \ 28 | CHECK_CONTIGUOUS(x); \ 29 | CHECK_TYPE(x, st) 30 | #define CHECK_INPUT_LOOSE(x) \ 31 | CHECK_CUDA(x); \ 32 | CHECK_CONTIGUOUS(x) 33 | #define CHECK_NDIM(x, ndim) FLUX_CHECK_EQ((x).dim(), ndim) << "ndim check failed" 34 | #define CHECK_DIM(x, dim, sz) FLUX_CHECK_EQ(x.size(dim), sz) 35 | #define CHECK_1D(x, dim0) \ 36 | CHECK_NDIM(x, 1); \ 37 | CHECK_DIM(x, 0, (dim0)) 38 | #define CHECK_2D(x, dim0, dim1) \ 39 | CHECK_NDIM(x, 2); \ 40 | CHECK_DIM(x, 0, (dim0)); \ 41 | CHECK_DIM(x, 1, (dim1)) 42 | #define CHECK_3D(x, dim0, dim1, dim2) \ 43 | CHECK_NDIM(x, 3); \ 44 | CHECK_DIM(x, 0, (dim0)); \ 45 | CHECK_DIM(x, 1, (dim1)); \ 46 | CHECK_DIM(x, 2, (dim2)) 47 | #define CHECK_DIV(x, y) TORCH_CHECK(x % y == 0, #x, " % ", #y, " != 0") 48 | -------------------------------------------------------------------------------- /src/moe_gather_rs/workspace_helper.h: -------------------------------------------------------------------------------- 1 | 2 | //===- workspace_helper.h -------------------------------------- C++ ------===// 3 | // 4 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | // 17 | //===----------------------------------------------------------------------===// 18 | #pragma once 19 | #include 20 | #include 21 | 22 | #include "flux/args/moe_gather_rs.h" 23 | namespace bytedance::flux { 24 | 25 | struct MoeGatherRSWorkspaceArgs { 26 | int num_groups; 27 | int N_split; 28 | int ep_start; 29 | int ep_nexperts; 30 | int N, K; 31 | int32_t *splits_gpu; 32 | void *input[kMaxNumGroups]; 33 | void *weights[kMaxNumGroups]; 34 | void *output[kMaxNumGroups]; 35 | float *input_scales[kMaxNumGroups]; 36 | float *weight_scales[kMaxNumGroups]; 37 | }; 38 | 39 | /** 40 | 41 | workspace architecture 42 | 43 | problem_sizes, cutlass::gemm::GemmCoord *, problem_count 44 | ptr_A, void *, problem_count 45 | ptr_B, void *, problem_count 46 | ptr_C, void *, problem_count 47 | ptr_D, void *, problem_count 48 | lda, int64_t, problem_count 49 | ldb, int64_t, problem_count 50 | ldc, int64_t, problem_count 51 | ldd, int64_t, problem_count 52 | scale_A, float *, problem_count 53 | scale_B, float *, problem_count 54 | non_empty_problem_count, int, 1 55 | */ 56 | 57 | void make_workspace( 58 | const MoeGatherRSWorkspaceArgs &args, 59 | GemmLayoutEnum layout, 60 | int input_elem_size, 61 | int output_elem_size, 62 | void *workspace, 63 | cudaStream_t stream); 64 | } // namespace bytedance::flux 65 | -------------------------------------------------------------------------------- /src/comm_none/ths_op/gemm_only.h: -------------------------------------------------------------------------------- 1 | //===- gemm_only.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
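To make the workspace-layout comment above concrete, here is a rough, illustrative size calculation for that blob. It is not the repository's actual packing code; it ignores any alignment padding the real `make_workspace` may insert and assumes `cutlass::gemm::GemmCoord` is three ints.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical size estimate for the grouped-GEMM workspace described above:
// per-problem arrays of problem sizes, A/B/C/D pointers, leading dimensions and
// scales, plus one trailing int for non_empty_problem_count.
inline size_t grouped_gemm_workspace_bytes(int problem_count) {
  size_t bytes = 0;
  bytes += problem_count * 3 * sizeof(int);      // problem_sizes (GemmCoord = 3 ints, assumed)
  bytes += problem_count * 4 * sizeof(void *);   // ptr_A, ptr_B, ptr_C, ptr_D
  bytes += problem_count * 4 * sizeof(int64_t);  // lda, ldb, ldc, ldd
  bytes += problem_count * 2 * sizeof(float *);  // scale_A, scale_B
  bytes += sizeof(int);                          // non_empty_problem_count
  return bytes;
}
```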
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | #include 21 | #include "flux/ths_op/ths_op.h" 22 | 23 | namespace bytedance::flux::ths_op { 24 | class GemmOnly { 25 | public: 26 | GemmOnly( 27 | c10::ScalarType input_dtype, 28 | c10::ScalarType weight_dtype, 29 | c10::ScalarType output_dtype, 30 | bool transpose_weight, 31 | bool use_fp8_gemm); 32 | ~GemmOnly(); 33 | 34 | torch::Tensor forward( 35 | torch::Tensor input, 36 | torch::Tensor weight, 37 | c10::optional bias, 38 | c10::optional output_buf, 39 | c10::optional input_scale, 40 | c10::optional weight_scale, 41 | c10::optional output_scale, 42 | bool fast_accum); 43 | torch::Tensor profiling( 44 | torch::Tensor input, 45 | torch::Tensor weight, 46 | c10::optional bias, 47 | c10::optional output_buf, 48 | c10::optional input_scale, 49 | c10::optional weight_scale, 50 | c10::optional output_scale, 51 | bool fast_accum, 52 | c10::intrusive_ptr opt_ctx); 53 | 54 | private: 55 | class GemmOnlyImpl; 56 | GemmOnlyImpl *impl_ = nullptr; 57 | }; 58 | 59 | } // namespace bytedance::flux::ths_op 60 | -------------------------------------------------------------------------------- /src/a2a_transpose_gemm/ths_op/all_to_all_types.h: -------------------------------------------------------------------------------- 1 | //===- all_to_all_types.h ----------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/ths_op/topo_utils.h" 20 | #include 21 | 22 | namespace bytedance::flux { 23 | // All2All for nvlink mode. for NVLINK machine, default is 0 24 | // Ring1D for 1d-ring. for PCI-e machine without GPUs cross NUMA nodes use ring 1d 25 | // Ring2D for 2d-ring. 
for PCI-e machine with GPUs cross NUMA nodes defaults to ring_2d 26 | enum class A2ARingMode { 27 | All2All = 0, 28 | Ring1D = 1, 29 | Ring2D = 2, 30 | }; 31 | 32 | namespace detail { 33 | template 34 | using optionally_optional = std::conditional_t, T>; 35 | 36 | template 37 | struct A2AOptionType { 38 | optionally_optional input_buffer_copied; 39 | optionally_optional use_cuda_core; 40 | optionally_optional fuse_sync; 41 | optionally_optional use_read; 42 | optionally_optional skip_barrier; 43 | optionally_optional return_comm_buf; 44 | optionally_optional mode; 45 | }; 46 | 47 | } // namespace detail 48 | 49 | using AllToAllOption = detail::A2AOptionType; 50 | using AllToAllOptionWithOptional = detail::A2AOptionType; 51 | 52 | inline A2ARingMode 53 | get_default_a2a_ring_mode() { 54 | return A2ARingMode::All2All; 55 | } 56 | 57 | } // namespace bytedance::flux 58 | -------------------------------------------------------------------------------- /src/ths_op/helper_ops.cc: -------------------------------------------------------------------------------- 1 | //===- helper_ops.cc ---------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include 19 | 20 | #include "flux/cuda/helper_kernels.h" 21 | #include "flux/ths_op/ths_op.h" 22 | 23 | namespace bytedance::flux::ths_op { 24 | using torch::Tensor; 25 | 26 | bool 27 | bitwise_check(torch::Tensor A, torch::Tensor B) { 28 | TORCH_CHECK(A.dim() == B.dim(), "Tensor dimension not matching! 
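A small illustrative sketch (not part of Flux) of how the optional variant of `A2AOptionType` might be collapsed into a concrete `AllToAllOption`. The field types are assumed to be bool flags plus an `A2ARingMode`, since the template arguments were lost in this listing, and the default values are made up.

```cpp
#include <optional>
#include "a2a_transpose_gemm/ths_op/all_to_all_types.h"  // assumed include path

namespace example {
using namespace bytedance::flux;

// Hypothetical helper: fill unset fields of an AllToAllOptionWithOptional with defaults.
inline AllToAllOption resolve(const AllToAllOptionWithOptional &opt) {
  AllToAllOption out;
  out.input_buffer_copied = opt.input_buffer_copied.value_or(false);
  out.use_cuda_core = opt.use_cuda_core.value_or(true);
  out.fuse_sync = opt.fuse_sync.value_or(false);
  out.use_read = opt.use_read.value_or(false);
  out.skip_barrier = opt.skip_barrier.value_or(false);
  out.return_comm_buf = opt.return_comm_buf.value_or(false);
  out.mode = opt.mode.value_or(get_default_a2a_ring_mode());
  return out;
}
}  // namespace example
```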
A:", A.dim(), " vs B:", B.dim()); 29 | return bitwise_check(from_torch_dtype(A.scalar_type()), A.data_ptr(), B.data_ptr(), A.numel()); 30 | } 31 | 32 | void 33 | uniform_initialize(torch::Tensor tensor, uint64_t seed, double min, double max) { 34 | cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); 35 | uniform_initialize( 36 | from_torch_dtype(tensor.scalar_type()), 37 | tensor.data_ptr(), 38 | tensor.numel(), 39 | seed, 40 | min, 41 | max, 42 | stream); 43 | } 44 | 45 | void 46 | cudaipc_barrier_all_on_stream( 47 | cudaStream_t stream, int rank, std::vector &sync_buffers, bool ring_mode) { 48 | std::vector sync_buffer_ptrs; 49 | int world_size = sync_buffers.size(); 50 | for (int i = 0; i < sync_buffers.size(); i++) { 51 | sync_buffer_ptrs.push_back(reinterpret_cast(sync_buffers[i].data_ptr())); 52 | } 53 | cudaipc_barrier_all_on_stream_impl(stream, sync_buffer_ptrs.data(), rank, world_size, ring_mode); 54 | } 55 | 56 | } // namespace bytedance::flux::ths_op 57 | -------------------------------------------------------------------------------- /include/flux/args/ag_gemm.h: -------------------------------------------------------------------------------- 1 | //===- ag_gemm.h -------------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | namespace bytedance::flux { 21 | 22 | struct AGKernelArguments { 23 | int m; 24 | int n; 25 | int k; 26 | int rank; 27 | int world_size; 28 | int nnodes; 29 | float alpha; 30 | float beta; 31 | void *input; 32 | void const *weight; 33 | void const *bias; 34 | void *output; 35 | void *barrier_buffer; 36 | }; 37 | 38 | struct AGS8KernelArguments { 39 | int m; 40 | int n; 41 | int k; 42 | int rank; 43 | int world_size; 44 | int nnodes; 45 | float alpha; 46 | float beta; 47 | void *A; 48 | void const *B; 49 | void const *bias; 50 | void *output; 51 | void const *scale_A; 52 | void const *scale_B; 53 | void *barrier_buffer; 54 | }; 55 | 56 | struct AGFP8KernelArguments { 57 | int m; 58 | int n; 59 | int k; 60 | int rank; 61 | int world_size; 62 | int nnodes; 63 | float alpha; 64 | float beta; 65 | void *A; // all gathered A, aka input_buffer 66 | void const *B; // weight 67 | void const *C; 68 | void *Aux = nullptr; 69 | void *D; // output 70 | void *barrier_buffer; 71 | void *Vector; // bias 72 | float *abs_max_Aux = nullptr; 73 | float *abs_max_D = nullptr; 74 | float const *scaleA; 75 | float const *scaleB; 76 | float const *scaleC; 77 | float const *scaleD = nullptr; 78 | float const *scaleAux = nullptr; 79 | }; 80 | 81 | } // namespace bytedance::flux 82 | -------------------------------------------------------------------------------- /cmake/FluxConfig.cmake: -------------------------------------------------------------------------------- 1 | # FindFlux 2 | # ------- 3 | # 4 | # Finds the Flux library 5 | # 6 | # This will define the following variables: 7 | # 8 | # FLUX_FOUND -- True if the system has the Flux library 9 | # FLUX_INCLUDE_DIRS -- The include directories for flux 10 | # FLUX_LIBRARIES -- Libraries to link against 11 | # FLUX_CXX_FLAGS -- Additional (required) compiler flags 12 | # 13 | # and the following imported targets: 14 | # 15 | # flux 16 | macro(append_fluxlib_if_found) 17 | foreach (_arg ${ARGN}) 18 | find_library(${_arg}_LIBRARY ${_arg} PATHS "${FLUX_INSTALL_PREFIX}/lib") 19 | if(${_arg}_LIBRARY) 20 | list(APPEND FLUX_LIBRARIES ${${_arg}_LIBRARY}) 21 | else() 22 | message(WARNING "library ${${_arg}_LIBRARY} not found.") 23 | endif() 24 | endforeach() 25 | endmacro() 26 | 27 | include(FindPackageHandleStandardArgs) 28 | 29 | if(DEFINED ENV{FLUX_INSTALL_PREFIX}) 30 | set(FLUX_INSTALL_PREFIX $ENV{FLUX_INSTALL_PREFIX}) 31 | else() 32 | # Assume we are in /share/cmake/Flux/FluxConfig.cmake 33 | get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) 34 | get_filename_component(FLUX_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) 35 | endif() 36 | 37 | # Include directories. 38 | set(FLUX_INCLUDE_DIRS 39 | ${FLUX_INSTALL_PREFIX}/include 40 | ${FLUX_INSTALL_PREFIX}/include/flux 41 | ) 42 | 43 | 44 | # Library dependencies. 45 | append_fluxlib_if_found(flux_cuda) 46 | append_fluxlib_if_found(flux_cuda_ths_op) 47 | if(EXISTS ${FLUX_INSTALL_PREFIX}/lib/libflux_triton_aot.so) 48 | append_fluxlib_if_found(flux_triton_aot) 49 | endif() 50 | 51 | # When we build libflux with the old libstdc++ ABI, dependent libraries must too. 
52 | # if(CMAKE_SYSTEM_NAME STREQUAL "Linux") 53 | # set(FLUX_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@") 54 | # endif() 55 | 56 | find_library(FLUX_LIBRARY flux_cuda_ths_op PATHS "${FLUX_INSTALL_PREFIX}/lib") 57 | # set_target_properties(flux PROPERTIES 58 | # INTERFACE_INCLUDE_DIRECTORIES "${FLUX_INCLUDE_DIRS}" 59 | # CXX_STANDARD 17 60 | # ) 61 | # if(FLUX_CXX_FLAGS) 62 | # set_property(TARGET flux PROPERTY INTERFACE_COMPILE_OPTIONS "${FLUX_CXX_FLAGS}") 63 | # endif() 64 | 65 | find_package_handle_standard_args(Flux DEFAULT_MSG FLUX_LIBRARY FLUX_INCLUDE_DIRS) 66 | -------------------------------------------------------------------------------- /src/a2a_transpose_gemm/post_attn_a2a_transpose_impls.hpp: -------------------------------------------------------------------------------- 1 | //===- post_attn_a2a_transpose_impls.hpp -------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | namespace bytedance { 22 | namespace flux { 23 | struct PostAttnAll2AllParams { 24 | void *input_ptr; 25 | void **output_ptrs; 26 | void **barrier_ptrs; 27 | int32_t **sync_barriers; 28 | int32_t bs; 29 | int32_t nheads; 30 | int32_t seq_len; 31 | int32_t head_dim; 32 | int32_t rank; 33 | int32_t world_size; 34 | int32_t TILE_M; 35 | 36 | int32_t num_comm_sm; 37 | int32_t *a2a_signal; 38 | 39 | int32_t cusum_seq_lens[kMaxWorldSize + 1]; // [world_size + 1, ] 40 | bool skip_barrier; 41 | }; 42 | 43 | enum class SyncMethod : int32_t { SyncNone = 0, SyncAtomic }; 44 | 45 | void post_attn_a2a_transpose_impl( 46 | const PostAttnAll2AllParams ¶m, 47 | DataTypeEnum input_dtype, 48 | SyncMethod sync_method, 49 | cudaStream_t stream); 50 | 51 | void post_attn_a2a_impl( 52 | const PostAttnAll2AllParams ¶ms, 53 | DataTypeEnum input_dtype, 54 | SyncMethod sync_method, 55 | cudaStream_t stream); 56 | 57 | void post_attn_a2a_dyn_impl( 58 | const PostAttnAll2AllParams ¶ms, 59 | DataTypeEnum input_dtype, 60 | SyncMethod sync_method, 61 | cudaStream_t stream); 62 | 63 | // num block of post_attn_a2a_impl is equal to post_attn_a2a_transpose_impl 64 | int32_t get_post_attn_all2all_transpose_block_num(int32_t bs, int32_t seq_len, int32_t tile_m); 65 | } // namespace flux 66 | } // namespace bytedance 67 | -------------------------------------------------------------------------------- /src/inplace_cast/inplace_cast.hpp: -------------------------------------------------------------------------------- 1 | //===- inplace_cast.hpp ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 
4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/flux.h" 20 | 21 | #define INPLACE_CAST_TS 8 22 | #define INPLACE_CAST_NUM_BLOCKS 184 23 | #define INPLACE_CAST_BLOCK_SIZE 256 24 | #define INPLACE_CAST_CHUNK_SIZE 2048 25 | // TS controls how many registers are used to hold data temporarily. 26 | // If too large, each block will use too many registers. This will reduce occupancy and so 27 | // significantly reduce performance. 28 | // If too small, each thread gets too few works to do and may increase scheduling overheads. 29 | // On L20, empirically, it was found that TS=8 and block_size=256 gives the best performance. TS=16 30 | // and block_size=256 will give slightly worse performance although both will have occupancy of 31 | // 100%. If occupancy is not 100%, the performance will be very bad. 32 | // To calculate occupancy, first, figure out the register usage per thread with 33 | // `cuobjdump -res-usage inplace_cast`. Then, input the register per thread and threads per block 34 | // (block size) into the occupancy calculator in nsight-compute. Please see 35 | // https://docs.nvidia.com/nsight-compute/NsightCompute/index.html#occupancy-calculator for 36 | // detailed usage. 37 | 38 | namespace bytedance { 39 | namespace flux { 40 | 41 | void inplace_cast_fp32_to_bf16_impl( 42 | void *data, 43 | size_t size, 44 | unsigned *flags, 45 | unsigned *counter, 46 | cudaStream_t stream, 47 | int num_blocks, 48 | int block_size); 49 | 50 | } 51 | } // namespace bytedance 52 | -------------------------------------------------------------------------------- /include/flux/args/moe_ag_scatter.h: -------------------------------------------------------------------------------- 1 | //===- moe_ag_scatter.h ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
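For concreteness, a toy calculation (illustration only) of how these macros translate into the per-chunk flag count used by the host wrapper elsewhere in this listing:

```cpp
#include <cstddef>

// With INPLACE_CAST_BLOCK_SIZE = 256 and INPLACE_CAST_TS = 8, each chunk covers
// 256 * 8 = 2048 elements (matching INPLACE_CAST_CHUNK_SIZE). A 1,000,000-element
// fp32 tensor therefore needs ceil(1,000,000 / 2048) = 489 chunks, i.e. 489
// one-word flags for the inter-block hand-off.
constexpr size_t elems_per_chunk = 256 * 8;
constexpr size_t num_chunks_for(size_t numel) {
  return (numel + elems_per_chunk - 1) / elems_per_chunk;
}
static_assert(num_chunks_for(1000000) == 489, "ceil(1e6 / 2048)");
```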
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "./comm_none.h" 20 | #include "flux/utils.h" 21 | 22 | namespace bytedance::flux { 23 | constexpr int kMaxNumGroups = 2; 24 | 25 | struct GemmGroupedAgScatterArguments : GemmGroupedV3Arguments { 26 | DistEnv dist_env; 27 | int ntokens; 28 | int h; 29 | void *nvshmem_input_buffer; 30 | int32_t const **gather_A; 31 | int32_t const **scatter_D; 32 | void const *problem_schedule; 33 | void *barrier_ptr = nullptr; 34 | int sm_margin = 0; 35 | }; 36 | 37 | struct GemmGroupedV2AGScatterArguments { 38 | int rank; 39 | int world_size; 40 | int sm_margin; 41 | 42 | int num_groups; // make sure num_groups <= kMaxNumGroups 43 | int ep_start; 44 | int ep_nexperts; 45 | void *input; // before gather_A 46 | void *weight[kMaxNumGroups]; // with groups 47 | void *output[kMaxNumGroups]; // with groups 48 | // FP8 arguments 49 | float *scaleD[kMaxNumGroups]; // with groups 50 | int M_this_ep, N, K; 51 | int lda, ldb, ldc, ldd; 52 | int *splits; 53 | // calculated on prepare workspace 54 | int32_t *gather_A; // on device memory expected 55 | int32_t *scatter_D; // on device memory expected 56 | void *problem_schedules; 57 | int num_problem_schedules; 58 | int *accum_per_rank_ptr; // on device memory expected 59 | int tile_size_m, tile_size_n; 60 | int *barrier_ptr; 61 | // fill inside op. only Op has the information 62 | float alpha = 1.f; 63 | float beta = 0.f; 64 | }; 65 | 66 | } // namespace bytedance::flux 67 | -------------------------------------------------------------------------------- /test/python/quantization/test_quantization.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | ################################################################################ 17 | 18 | import argparse 19 | import torch 20 | import flux 21 | from flux.testing import DTYPE_MAP, init_seed 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("M", type=int) 27 | parser.add_argument("N", type=int) 28 | parser.add_argument("--num_streams", type=int, default=1) 29 | parser.add_argument( 30 | "--input_dtype", 31 | default="float8_e4m3fn", 32 | type=str, 33 | choices=["float8_e4m3fn", "float8_e5m2"], 34 | ) 35 | parser.add_argument( 36 | "--output_dtype", 37 | default="", 38 | type=str, 39 | choices=["float8_e4m3fn", "float8_e5m2", "bfloat16", "float32"], 40 | ) 41 | parser.add_argument( 42 | "--output_transpose", default=False, action="store_true", help="transpose output" 43 | ) 44 | 45 | return parser.parse_args() 46 | 47 | 48 | if __name__ == "__main__": 49 | init_seed() 50 | args = parse_args() 51 | input_dtype = DTYPE_MAP[args.input_dtype] 52 | output_dtype = DTYPE_MAP[args.output_dtype] 53 | 54 | flux_op = flux.Quantization( 55 | input_dtype=input_dtype, 56 | output_dtype=output_dtype, 57 | num_streams=args.num_streams, 58 | ) 59 | 60 | input_tensor = torch.rand([args.M, args.N], dtype=input_dtype).cuda() 61 | flux_out = flux_op.quantize_square_blockwise(input_tensor, args.output_transpose) 62 | ## TODO: add reference implementation 63 | flux.bitwise_check(flux_out[0], flux_out[0]) 64 | -------------------------------------------------------------------------------- /python/flux_triton/tools/runtime/triton_aot_runtime.h: -------------------------------------------------------------------------------- 1 | //===- triton_aot_runtime.h ------------------------------------- C++ ------===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | #include 19 | // CUDA 12.0+ has CUDA context independent module loading. 
but what about CUDA 11.8 20 | // https://developer.nvidia.com/blog/cuda-context-independent-module-loading/ 21 | #ifdef __cplusplus 22 | extern "C" { 23 | #endif 24 | 25 | // CUDA driver stubs to avoid direct dependency on libcuda.so 26 | CUresult cuGetErrorString_stub(CUresult error, const char **pStr); 27 | CUresult cuDeviceGetAttribute_stub(int *pi, CUdevice_attribute attrib, CUdevice dev); 28 | 29 | // CUDA patch for Multiple CUDA context support: using any CUDA context 30 | typedef struct CUDAModule *CUDAModuleHandle; 31 | typedef struct CUDAFunction *CUDAFunctionHandle; 32 | 33 | CUresult CUDAModuleLoadData(CUDAModuleHandle *module, const void *image); 34 | 35 | CUresult CUDAModuleUnload(CUDAModuleHandle module); 36 | 37 | CUresult CUDAModuleGetFunction(CUDAFunctionHandle *hfunc, CUDAModuleHandle hmod, const char *name); 38 | 39 | CUresult CUDALaunchKernel( 40 | CUDAFunctionHandle f, 41 | unsigned int gridDimX, 42 | unsigned int gridDimY, 43 | unsigned int gridDimZ, 44 | unsigned int blockDimX, 45 | unsigned int blockDimY, 46 | unsigned int blockDimZ, 47 | unsigned int sharedMemBytes, 48 | CUstream hStream, 49 | void **kernelParams, 50 | void **extra); 51 | 52 | CUresult CUDAFuncSetAttribute(CUDAFunctionHandle func, CUfunction_attribute attrib, int value); 53 | 54 | CUresult CUDAFuncSetCacheConfig(CUDAFunctionHandle func, CUfunc_cache config); 55 | 56 | #ifdef __cplusplus 57 | } 58 | #endif 59 | -------------------------------------------------------------------------------- /src/coll/ths_op/all_gather_op.h: -------------------------------------------------------------------------------- 1 | //===- all_gather_op.h -------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | 19 | #include "all_gather_types.h" 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace bytedance::flux::ths_op { 29 | /** AllGather is not as common as it seems: 30 | * 1. 
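An illustrative launcher (not from the repository) showing the intended call sequence for these wrappers: load the embedded cubin once through the context-independent path, resolve the kernel by name, then launch on a stream. The cubin pointer, kernel name, and launch dimensions are placeholders.

```cpp
#include <cuda.h>
#include "triton_aot_runtime.h"

// Hypothetical launcher over the stubs declared above; error checking omitted.
void launch_aot_kernel(const void *cubin_image, CUstream stream, void **kernel_args) {
  static CUDAModuleHandle mod = nullptr;
  static CUDAFunctionHandle func = nullptr;
  if (mod == nullptr) {
    CUDAModuleLoadData(&mod, cubin_image);           // works under any CUDA context
    CUDAModuleGetFunction(&func, mod, "my_kernel");  // placeholder kernel name
  }
  // Grid/block sizes are made up for the example.
  CUDALaunchKernel(func, /*gridDimX=*/128, 1, 1, /*blockDimX=*/256, 1, 1,
                   /*sharedMemBytes=*/0, stream, kernel_args, /*extra=*/nullptr);
}
```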
it's actually with business logic, such as input_buffer and input_scale_buffer, which is for 31 | * int8 GEMM all-gather 32 | */ 33 | 34 | class AllGatherOp { 35 | public: 36 | AllGatherOp( 37 | std::shared_ptr tp_group, 38 | int nnodes, 39 | size_t max_m, 40 | size_t k, 41 | at::ScalarType input_dtype); 42 | 43 | ~AllGatherOp(); 44 | 45 | void run_with_optional_options( 46 | torch::Tensor input, 47 | c10::optional input_scale, 48 | const AllGatherOptionWithOptional &opt, 49 | cudaStream_t stream); 50 | 51 | void run( 52 | const torch::Tensor &input, 53 | c10::optional input_scale, 54 | const AllGatherOption &opt, 55 | cudaStream_t stream); 56 | 57 | // only provide local tensor 58 | torch::Tensor local_input_buffer(); 59 | torch::Tensor local_input_scale_buffer(); 60 | torch::Tensor local_barrier_buffer(); 61 | 62 | int32_t *ag_signal_ptr() const; 63 | 64 | cudaEvent_t &get_local_prepare_event(); 65 | 66 | private: 67 | class AllGatherOpImpl; 68 | AllGatherOpImpl *impl_; 69 | }; 70 | } // namespace bytedance::flux::ths_op 71 | -------------------------------------------------------------------------------- /include/flux/gemm_operator_base.h: -------------------------------------------------------------------------------- 1 | //===- gemm_operator_base.h --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include "flux/flux.h" 21 | #include "flux/gemm_hparams.h" 22 | #include 23 | #include 24 | 25 | namespace bytedance::flux { 26 | 27 | struct GemmOperatorBase { 28 | public: 29 | FLUX_DEFINE_DEFAULT_SPECIAL_FUNCS(GemmOperatorBase) 30 | 31 | virtual ~GemmOperatorBase() = default; 32 | virtual void run( 33 | std::any const &args, 34 | void *workspace = nullptr, 35 | void *stream = nullptr, 36 | bool launch_with_pdl = false) = 0; 37 | 38 | virtual void run(void *stream = nullptr, bool launch_with_pdl = false) = 0; 39 | 40 | // for device allocation required by Gemm::Arguments 41 | virtual size_t 42 | get_args_workspace_size(std::any const &args) const { 43 | return 0; 44 | } 45 | 46 | virtual void 47 | initialize_args_workspace( 48 | std::any const &args, void *args_workspace = nullptr, void *stream = nullptr) const { 49 | // noop 50 | } 51 | 52 | virtual void initialize( 53 | std::any const &args, void *workspace = nullptr, void *stream = nullptr) = 0; 54 | 55 | // total workspace: args_workspace + gemm workspace 56 | virtual size_t 57 | get_workspace_size(std::any const &args) const { 58 | return 0; 59 | } 60 | 61 | virtual size_t 62 | get_barrier_workspace_size(std::any const &args) const { 63 | return 0; 64 | } 65 | 66 | virtual UnifiedGemmHParams 67 | get_runtime_gemm_hparams() const { 68 | throw std::logic_error("get_runtime_gemm_hparams not implemented"); 69 | } 70 | }; 71 | 72 | } // namespace bytedance::flux 73 | -------------------------------------------------------------------------------- /include/flux/args/gemm_rs.h: -------------------------------------------------------------------------------- 1 | //===- gemm_rs.h -------------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
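Illustration only: a driver sequence one could write against the `GemmOperatorBase` interface above, querying the workspace size, initializing, then running. The helper below is not a Flux API, and the allocation/synchronization strategy is a simplification.

```cpp
#include <any>
#include <cuda_runtime.h>
#include "flux/gemm_operator_base.h"

// Hypothetical driver: `op` is any concrete GemmOperatorBase implementation and
// `args` the std::any-wrapped argument struct it expects.
void run_gemm_operator(bytedance::flux::GemmOperatorBase &op, const std::any &args,
                       cudaStream_t stream) {
  void *workspace = nullptr;
  size_t ws_bytes = op.get_workspace_size(args);  // args workspace + gemm workspace
  if (ws_bytes > 0) {
    cudaMalloc(&workspace, ws_bytes);
  }
  op.initialize(args, workspace, stream);
  op.run(args, workspace, stream, /*launch_with_pdl=*/false);
  cudaStreamSynchronize(stream);  // only so the workspace can be freed safely here
  if (workspace != nullptr) {
    cudaFree(workspace);
  }
}
```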
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | namespace bytedance::flux { 20 | 21 | struct ReduceScatterArguments { 22 | int reduce_scatter_num_blocks = 12; 23 | void *rs_stream = nullptr; 24 | void *event = nullptr; 25 | bool use_barrier_queue = false; 26 | bool use_gemmk = true; // use the gemmk mechanism 27 | bool per_tile_flags = true; // set a flag per tile 28 | bool use_cudaMemcpyAsync = false; // use cudaMemcpyAsync for memcpy or not 29 | int n_split = 1; // whether to also split along n 30 | int sub_world_size = 1; 31 | void *opaque = nullptr; // used to pass ncclComm_t for cross-node PCIe 32 | bool use_1d_ring = true; 33 | bool use_p2p_read = true; 34 | }; 35 | 36 | struct GemmReduceScatterArguments { 37 | int m; 38 | int n; 39 | int k; 40 | int rank; 41 | int world_size; 42 | int nnodes; 43 | float alpha; 44 | float beta; 45 | void const *input; 46 | void const *weight; 47 | void const *bias; 48 | void **output_scatter_ptrs; 49 | void **reduce_buffer_ptrs; 50 | void **barrier_ptrs = nullptr; 51 | int avail_sms = -1; 52 | 53 | void *Aux = nullptr; // m * n 54 | void *Vector = nullptr; // bias: 1 * n 55 | float *abs_max_Aux = nullptr; 56 | float *abs_max_D = nullptr; 57 | // scaling tensors 58 | float const *scaleA = nullptr; 59 | float const *scaleB = nullptr; 60 | float const *scaleC = nullptr; 61 | float const *scaleD = nullptr; // required if D is fp8 62 | float const *scaleAux = nullptr; // required if Aux is fp8 63 | ReduceScatterArguments reduce_scatter_args; 64 | }; 65 | 66 | } // namespace bytedance::flux 67 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | ### FAQ (Frequently Asked Questions) 2 | 3 | #### Common questions: 4 | 5 | 1. **Q:** What kernels does Flux support, and how are they named? 6 | 7 | **A:** Flux mainly supports the following kernels: 8 | - Dense MLP layer0 (AllGather + GEMM) in `src/ag_gemm` 9 | - Dense MLP layer1 (GEMM + ReduceScatter) in `src/gemm_rs` 10 | - MoE layer0 (AllGather + Scatter + GroupGEMM) in `src/moe_ag_scatter` 11 | - MoE layer1 (GroupGEMM + Gather + Topk-reduce + ReduceScatter) in `src/moe_gather_rs` 12 | 13 | Flux supports MoE kernels with tensor parallelism, expert parallelism, and combined tensor+expert parallelism. You can find a minimal example of a MoE layer with EP=4 in `examples/moe_flux_only.py` (note that sequence parallelism is enabled and the ffn_tp_size is 2). There is also an illustration, `docs/assets/toy_example.png`, of this toy example to help you better understand the workflow of this TP+EP MoE case. In this case, the EP communication is also overlapped with the GroupGEMM in Flux's implementation. 14 | Detailed information about the kernels can be found in the [Design Guide](https://github.com/bytedance/flux/blob/main/docs/design.md). 15 | 16 | #### Connection problems 17 | 18 | 1. **Q:** The NCCL/NVSHMEM connection hangs/fails when initializing Flux. 19 | 20 | **A:** If you encounter a NCCL/NVSHMEM connection problem, it may be caused by the network configuration inside the `launch.sh` script. A possible solution is to export a proper `NCCL_SOCKET_IFNAME` variable manually. Try setting it to the first interface name you get from `ifconfig` (e.g., `export NCCL_SOCKET_IFNAME=bond0`).
21 | 22 | 23 | If you still cannot establish a connection, try setting a few more environment variables in the launch script to better describe your network configuration: 24 | 25 | - `export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=bond0` 26 | - `export NVSHMEM_IB_ADDR_FAMILY=AF_INET6` 27 | - `export NVSHMEM_SYMMETRIC_SIZE=10000000000` 28 | 29 | #### Installation problems 30 | 31 | 1. **Q:** The installation takes too long. 32 | 33 | **A:** Set `export OMP_NUM_THREADS=128` before installation; a higher thread count can speed up compilation. 34 | 35 | #### Performance problems 36 | 37 | 1. **Q:** The performance of kernels is not as good as I expected. 38 | 39 | **A:** You may need to tune the kernels, because kernel performance varies across hardware and problem shapes. For how to tune the kernels, please refer to the [tuning guide](https://github.com/bytedance/flux/blob/main/docs/tuning_guide.md). -------------------------------------------------------------------------------- /test/python/util/test_flux_ring_barrier.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | ################################################################################ 17 | 18 | # usage: torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 --rdzv_id=none --master_addr=127.0.0.1 --master_port=23456 test/test_flux_ring_barrier.py 19 | import argparse 20 | import time 21 | 22 | import torch 23 | import torch.distributed 24 | 25 | import flux 26 | from flux.testing import initialize_distributed 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--ring_mode", default=False, action="store_true") 32 | parser.add_argument("--iters", type=int, default=1000) 33 | parser.add_argument( 34 | "--profile", default=False, action="store_true", help="dump torch.profiler.profile" 35 | ) 36 | return parser.parse_args() 37 | 38 | 39 | if __name__ == "__main__": 40 | args = parse_args() 41 | TP_GROUP = initialize_distributed() 42 | RANK, WORLD_SIZE, NNODES = TP_GROUP.rank(), TP_GROUP.size(), flux.testing.NNODES() 43 | LOCAL_WORLD_SIZE = WORLD_SIZE // NNODES 44 | LOCAL_RANK = RANK % LOCAL_WORLD_SIZE 45 | 46 | ctx = flux.util.get_torch_prof_ctx(args.profile) 47 | 48 | print(f"GroupBarrier: {args.ring_mode}", flush=True) 49 | group_barrier = flux.GroupBarrier(TP_GROUP, args.ring_mode) 50 | 51 | TP_GROUP.barrier() 52 | t1 = time.time() 53 | torch.cuda.synchronize() 54 | stream = torch.cuda.current_stream() 55 | with ctx: 56 | for i in range(args.iters): 57 | group_barrier.barrier_all(stream.cuda_stream) 58 | torch.cuda.synchronize() 59 | t2 = time.time() 60 | print(f"Done in {t2 - t1}s", flush=True) 61 | if args.profile: 62 | ctx.export_chrome_trace(f"prof/trace_rank{TP_GROUP.rank()}.json.gz") 63 | -------------------------------------------------------------------------------- /src/a2a_transpose_gemm/tuning_config/config_a2a_transpose_gemm_kernel_sm90_H800_tp8_nnodes1.cu: -------------------------------------------------------------------------------- 1 | //===- config_a2a_transpose_gemm_kernel_sm90_H800_tp8_nnodes1.cu ------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | // clang-format off 19 | #include "flux/op_registry.h" 20 | namespace bytedance::flux { 21 | using namespace cute; 22 | 23 | static int config_a2a_transpose_gemm_kernel_sm90_h800_tp8_nnodes1 = []() { 24 | auto &inst = TuningConfigRegistry::instance(); 25 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm90{}(),_H800{}(),_PostAttnAllToAllTranspose{}(),_RCR{}(),_GemmV3{}(),make_gemm_v3_meta(false),None{}),make_runtime_config(2048,4096,8192,make_all_to_all_transpose_runtime_config(8,1)),make_gemm_hparams(make_gemm_v3_hparams(cute::make_tuple(2l,1l,1l),_Cooperative{}()),None{},cute::make_tuple(128l,128l,64l),_GemmDefault{}(),0,_RasterAlongN{}())); 26 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm90{}(),_H800{}(),_PostAttnAllToAllOnly{}(),_RCR{}(),_GemmV3{}(),make_gemm_v3_meta(false),None{}),make_runtime_config(2048,4096,8192,make_all_to_all_transpose_runtime_config(8,1)),make_gemm_hparams(make_gemm_v3_hparams(cute::make_tuple(2l,1l,1l),_Cooperative{}()),None{},cute::make_tuple(128l,128l,64l),_GemmDefault{}(),0,_RasterAlongN{}())); 27 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm90{}(),_H800{}(),_PostAttnAllToAllOnly{}(),_RCR{}(),_GemmV3{}(),make_gemm_v3_meta(false),None{}),make_runtime_config(8192,6144,12288,make_all_to_all_transpose_runtime_config(8,1,0)),make_gemm_hparams(make_gemm_v3_hparams(cute::make_tuple(2l,1l,1l),_Cooperative{}()),None{},cute::make_tuple(256l,128l,64l),_GemmStreamK{}(),0,_RasterAlongN{}())); 28 | return 0; 29 | }(); 30 | } 31 | // clang-format on 32 | -------------------------------------------------------------------------------- /src/pybind/quantization.cc: -------------------------------------------------------------------------------- 1 | //===- quantization.cc -------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "quantization/ths_op/quantization.h" 19 | 20 | #include "comm_none/ths_op/gemm_only.h" 21 | #include "flux/ths_op/ths_pybind.h" 22 | 23 | namespace bytedance::flux::ths_op { 24 | 25 | namespace py = pybind11; 26 | using QuantizationCls = TorchClassWrapper; 27 | 28 | static int _ [[maybe_unused]] = []() { 29 | ThsOpsInitRegistry::instance().register_one("quantization", [](py::module &m) { 30 | py::class_(m, "Quantization") 31 | .def( 32 | py::init([](torch::ScalarType input_dtype, 33 | torch::ScalarType output_dtype, 34 | int32_t num_streams) { 35 | return new Quantization(input_dtype, output_dtype, num_streams); 36 | }), 37 | py::arg("input_dtype"), 38 | py::arg("output_dtype"), 39 | py::arg("num_streams") = 2) 40 | .def( 41 | "quantize_vector_blockwise", 42 | &Quantization::quantize_vector_blockwise, 43 | py::arg("input"), 44 | py::arg("return_tranpose") = true, 45 | py::arg("eps") = 0.0) 46 | .def( 47 | "quantize_square_blockwise", 48 | &Quantization::quantize_square_blockwise, 49 | py::arg("input"), 50 | py::arg("return_tranpose") = true, 51 | py::arg("eps") = 0.0) 52 | .def( 53 | "batch_quantize_square_blockwise", 54 | &Quantization::batch_quantize_square_blockwise, 55 | py::arg("input"), 56 | py::arg("return_tranpose") = true, 57 | py::arg("eps") = 0.0); 58 | }); 59 | return 0; 60 | }(); 61 | } // namespace bytedance::flux::ths_op 62 | -------------------------------------------------------------------------------- /python/flux_triton/tools/compile/compile.c: -------------------------------------------------------------------------------- 1 | /* clang-format off */ 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "triton_aot_runtime.h" 9 | 10 | // helpers to check for cuda errors 11 | #define CUDA_CHECK(ans) {{\ 12 | gpuAssert((ans), __FILE__, __LINE__);\ 13 | }}\ 14 | 15 | static inline void gpuAssert(CUresult code, const char *file, int line) {{ 16 | if (code != CUDA_SUCCESS) {{ 17 | const char *prefix = "Triton Error [CUDA]: "; 18 | const char *str; 19 | cuGetErrorString_stub(code, &str); 20 | char err[1024] = {{0}}; 21 | strcat(err, prefix); 22 | strcat(err, str); 23 | printf("%s\\n", err); 24 | exit(code); 25 | }} 26 | }} 27 | 28 | // globals 29 | #define CUBIN_NAME {kernel_name}_cubin 30 | CUDAModuleHandle {kernel_name}_mod = NULL; 31 | CUDAFunctionHandle {kernel_name}_func = NULL; 32 | unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; 33 | 34 | 35 | void unload_{kernel_name}(void) {{ 36 | CUDA_CHECK(CUDAModuleUnload({kernel_name}_mod)); 37 | }} 38 | 39 | // TODO: some code duplication with `runtime/backend/cuda.c` 40 | void load_{kernel_name}() {{ 41 | int dev = 0; 42 | void *bin = (void *)&CUBIN_NAME; 43 | int shared = {shared}; 44 | CUDA_CHECK(CUDAModuleLoadData(&{kernel_name}_mod, bin)); 45 | CUDA_CHECK(CUDAModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); 46 | // set dynamic shared memory if necessary 47 | int shared_optin; 48 | CUDA_CHECK(cuDeviceGetAttribute_stub(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); 49 | if (shared > 49152 && shared_optin > 49152) {{ 50 | CUDA_CHECK(CUDAFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); 51 | CUDA_CHECK(CUDAFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) 52 | }} 53 | }} 54 | 55 | /* 56 | {kernel_docstring} 57 | */ 58 | CUresult 
{kernel_name}(CUstream stream, {signature}) {{ 59 | if ({kernel_name}_func == NULL) 60 | load_{kernel_name}(); 61 | unsigned int gX = {gridX}; 62 | unsigned int gY = {gridY}; 63 | unsigned int gZ = {gridZ}; 64 | void *args[{num_args}] = {{ {arg_pointers} }}; 65 | // TODO: shared memory 66 | if(gX * gY * gZ > 0) 67 | return CUDALaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); 68 | 69 | fprintf(stderr, "invalid grid size: %d, %d, %d\n", gX, gY, gZ); 70 | return CUDA_ERROR_INVALID_VALUE; 71 | }} 72 | -------------------------------------------------------------------------------- /include/flux/cuda/reduce_utils.cuh: -------------------------------------------------------------------------------- 1 | //===- reduce_utils.cuh ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | 21 | namespace bytedance::flux { 22 | 23 | constexpr int kWarpSize = 32; 24 | 25 | // Input: each thread in a warp call this function 26 | // with its `id` (lane_id) and its `count` to be summed. 27 | // Output: each thread get a presum of all threads' `count` 28 | // that have `id` less than or equal to its own `id` 29 | template 30 | __inline__ __device__ T 31 | warp_prefix_sum(int id, T count) { 32 | for (int i = 1; i < kWarpSize; i <<= 1) { 33 | T val = __shfl_up_sync(0xffffffff, count, i); 34 | if (id >= i) 35 | count += val; 36 | } 37 | return count; 38 | } 39 | 40 | template 41 | __inline__ __device__ void 42 | aligned_block_prefix_sum_and_sync(const T *data_in, T *data_out, int count, int align) { 43 | int warp_idx = threadIdx.x / kWarpSize; 44 | int lane_idx = threadIdx.x % kWarpSize; 45 | if (warp_idx == 0) { 46 | int cur_offset = 0; 47 | int count_pad = (count + kWarpSize - 1) / kWarpSize * kWarpSize; 48 | for (int i = lane_idx; i < count_pad; i += kWarpSize) { 49 | int len = i < count ? 
data_in[i] : 0; 50 | len = (len + align - 1) / align * align; 51 | int temp_offset = warp_prefix_sum(threadIdx.x, len); 52 | if (i < count) { 53 | data_out[i] = cur_offset + temp_offset; 54 | } 55 | cur_offset += __shfl_sync(0xffffffff, temp_offset, kWarpSize - 1); 56 | } 57 | } 58 | __syncthreads(); 59 | } 60 | 61 | template 62 | __inline__ __device__ void 63 | block_prefix_sum_and_sync(const T *data_in, T *data_out, int count) { 64 | aligned_block_prefix_sum_and_sync(data_in, data_out, count, 1); 65 | } 66 | 67 | } // namespace bytedance::flux 68 | -------------------------------------------------------------------------------- /src/coll/ths_op/all_gather_types.h: -------------------------------------------------------------------------------- 1 | //===- all_gather_types.h ---------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/ths_op/topo_utils.h" 20 | #include 21 | 22 | namespace bytedance::flux { 23 | // All2All for nvlink mode. for NVLINK machine, default is 0 24 | // Ring1D for 1d-ring. for PCI-e machine without GPUs cross NUMA nodes use ring 1d 25 | // Ring2D for 2d-ring. for PCI-e machine with GPUs cross NUMA nodes defaults to ring_2d 26 | enum class AGRingMode { 27 | All2All = 0, 28 | Ring1D = 1, 29 | Ring2D = 2, 30 | }; 31 | namespace detail { 32 | template 33 | using optionally_optional = std::conditional_t, T>; 34 | 35 | template 36 | struct AllGatherOptionType { 37 | optionally_optional input_buffer_copied; 38 | optionally_optional use_cuda_core_local; 39 | optionally_optional use_cuda_core_ag; 40 | optionally_optional fuse_sync; // only valid when use_cuda_core_local=True 41 | optionally_optional use_read; 42 | optionally_optional mode; 43 | }; 44 | 45 | } // namespace detail 46 | 47 | using AllGatherOption = detail::AllGatherOptionType; 48 | using AllGatherOptionWithOptional = detail::AllGatherOptionType; 49 | 50 | static const int kNumaWorldSize = 4; 51 | 52 | inline AGRingMode 53 | get_default_ag_ring_mode() { 54 | if (topo_utils::has_nvswitch()) { 55 | return AGRingMode::All2All; 56 | } 57 | 58 | if (topo_utils::has_heterogeneous_pcie()) { 59 | if (topo_utils::topo_numa_local_world_size() != kNumaWorldSize) { 60 | return AGRingMode::Ring1D; // PCI-e ring mode with no optimization 61 | } 62 | return AGRingMode::Ring2D; 63 | } 64 | return AGRingMode::Ring1D; 65 | } 66 | 67 | } // namespace bytedance::flux 68 | -------------------------------------------------------------------------------- /src/coll/ths_op/reduce_scatter_op.h: -------------------------------------------------------------------------------- 1 | //===- reduce_scatter_op.h ---------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. 
All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include "flux/ths_op/topo_utils.h" 20 | #include 21 | 22 | namespace bytedance::flux { 23 | // All2All for nvlink mode. for NVLINK machine, default is 0 24 | // Ring1D for 1d-ring. for PCI-e machine without GPUs cross NUMA nodes use ring 1d 25 | // Ring2D for 2d-ring. for PCI-e machine with GPUs cross NUMA nodes defaults to ring_2d 26 | enum class RingMode { 27 | All2All = 0, 28 | Ring1D = 1, 29 | Ring2D = 2, 30 | }; 31 | namespace detail { 32 | template 33 | using optionally_optional = std::conditional_t, T>; 34 | 35 | template 36 | struct ReduceScatterOptionType { 37 | optionally_optional use_barrier_queue; 38 | optionally_optional use_1d_ring; 39 | optionally_optional use_p2p_read; 40 | optionally_optional use_cudaMemcpyAsync; 41 | optionally_optional use_gemmk; 42 | optionally_optional per_tile_flags; 43 | optionally_optional n_split; 44 | optionally_optional num_blocks; 45 | optionally_optional ring_mode; 46 | }; 47 | 48 | } // namespace detail 49 | 50 | using ReduceScatterOption = detail::ReduceScatterOptionType; 51 | using ReduceScatterOptionWithOptional = detail::ReduceScatterOptionType; 52 | 53 | inline RingMode 54 | get_default_rs_ring_mode() { 55 | if (topo_utils::has_nvswitch()) { 56 | return RingMode::All2All; 57 | } 58 | static const int kNumaWorldSize = 4; 59 | 60 | if (topo_utils::has_heterogeneous_pcie()) { 61 | if (topo_utils::topo_numa_local_world_size() != kNumaWorldSize) { 62 | return RingMode::Ring1D; // PCI-e ring mode with no optimization 63 | } 64 | return RingMode::Ring2D; 65 | } 66 | return RingMode::Ring1D; 67 | } 68 | 69 | } // namespace bytedance::flux 70 | -------------------------------------------------------------------------------- /include/flux/cuda/cuda_stub.h: -------------------------------------------------------------------------------- 1 | //===- cuda_stub.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | 19 | /*! 
20 | * \file cuda_stub.h 21 | * \brief CUDA stub to avoid direct CUDA driver call 22 | */ 23 | #pragma once 24 | 25 | #include <cuda.h> 26 | 27 | namespace bytedance::flux { 28 | 29 | #define FLUX_FORALL_CUDA(_) \ 30 | _(cuDeviceGetName) \ 31 | _(cuGetErrorString) \ 32 | _(cuGetErrorName) \ 33 | _(cuStreamWaitValue32_v2) \ 34 | _(cuStreamWriteValue32_v2) \ 35 | _(cuStreamWaitValue64_v2) \ 36 | _(cuStreamWriteValue64_v2) \ 37 | _(cuStreamBatchMemOp_v2) \ 38 | _(cuCtxGetDevice) 39 | 40 | extern "C" { 41 | using CUDA = struct CUDA { 42 | #define CREATE_MEMBER(name) decltype(&(name)) name; 43 | FLUX_FORALL_CUDA(CREATE_MEMBER) 44 | #undef CREATE_MEMBER 45 | }; 46 | } 47 | 48 | CUDA &cuda_stub(); 49 | namespace { 50 | const char * 51 | get_cu_error_string(CUresult statuse) { 52 | const char *msg; 53 | if (cuda_stub().cuGetErrorString(statuse, &msg) == CUDA_SUCCESS) { 54 | return msg; 55 | } else { 56 | return "unknown error"; 57 | } 58 | } 59 | } // namespace 60 | 61 | #define CU_CHECK(status) \ 62 | do { \ 63 | CUresult error = status; \ 64 | FLUX_CHECK(error == CUDA_SUCCESS) << "Got bad cuda status: " << get_cu_error_string(error) \ 65 | << "(" << error << ") at " #status; \ 66 | } while (0) 67 | 68 | CUresult CUStreamWaitValue( 69 | CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); 70 | CUresult CUStreamWriteValue( 71 | CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); 72 | 73 | } // namespace bytedance::flux 74 | -------------------------------------------------------------------------------- /docs/mlsys_comet_ae.md: -------------------------------------------------------------------------------- 1 | # Guide for Comet (MLSys25 Artifact Evaluation) 2 | This git repo (Flux) contains the components for the paper "Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts". The main code of Comet is located in the `src` directory. Specifically, the implementation of MoE layer0 is in `src/moe_ag_scatter` and the implementation of MoE layer1 is in `src/moe_gather_rs`. 3 | 4 | 5 | ## Quick installation and test 6 | Hardware requirements - A single server with 8 Nvidia GPUs (Hopper/Ada Lovelace/Ampere). We recommend using H100/H800. 7 | Software requirements - Please prepare the environment with the following steps: 8 | 9 | ```bash 10 | 11 | # Quick installation 12 | conda create -n comet_ae python=3.11 -y 13 | conda activate comet_ae 14 | pip3 install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 15 | pip install byte-flux==1.1.1 16 | 17 | # Quick test 18 | git clone https://github.com/bytedance/flux.git && cd flux/examples 19 | bash run_moe.sh 20 | ``` 21 | 22 | If the above commands run successfully, the code is installed and working. 23 | 24 | ## Measure the MoE layer and E2E model latency 25 | Next, we provide a guide to measure the latency of the MoE layer and the end-to-end (E2E) model using Comet with Megatron-LM on a node with 8 GPUs.
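Before preparing the environment, a quick sanity check of the node can save time. This is a minimal sketch, not part of the original artifact-evaluation steps; it assumes `nvidia-smi` and the PyTorch install from the previous section are available:

```bash
# Expect the driver to list 8 GPUs
nvidia-smi -L

# Expect PyTorch to see the same 8 devices
python -c "import torch; print(torch.cuda.device_count())"
```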
26 | 27 | ### Prepare the environment 28 | ```bash 29 | # Under your workspace 30 | # Install some basic dependencies 31 | 32 | pip3 install flash-attn --no-build-isolation 33 | pip3 install transformer_engine[pytorch] 34 | git clone https://github.com/NVIDIA/apex 35 | pushd apex 36 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 37 | popd 38 | pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0 39 | pip install regex six pyyaml 40 | 41 | # Megatron-LM 42 | git clone https://github.com/ZSL98/Megatron-LM.git 43 | # fastmoe (baseline1) 44 | git clone https://github.com/ZSL98/fastmoe.git && cd fastmoe && python setup.py install && pip install dm-tree && cd .. 45 | # tutel (baseline2) 46 | git clone https://github.com/ZSL98/tutel && cd tutel && python setup.py install && cd .. 47 | 48 | ``` 49 | 50 | ### Run the tests 51 | 52 | ```bash 53 | cd Megatron-LM 54 | bash ./grid_test.sh # Record the single MoE layer results to timing_results.csv 55 | bash ./e2e_grid_test.sh # Record the e2e model results to e2e_timing_results.csv 56 | ``` 57 | You can modify the parameters such as `NUM_TOKENS` and `EXPERT_NUM` to see the results under different configurations. The scripts' output can be found in `Megatron-LM/timing_results.csv` and `Megatron-LM/e2e_timing_results.csv`. The feasibility of the scripts has been tested on both 8 L20 GPUs and 8 H100 GPUs. 58 | -------------------------------------------------------------------------------- /test/unit/test_cuda_common.cu: -------------------------------------------------------------------------------- 1 | //===- test_cuda_common.cu ---------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | #include "cutlass/util/device_memory.h" 18 | #include "flux/flux.h" 19 | #include "flux/cuda/cuda_common.h" 20 | 21 | namespace bytedance::flux { 22 | 23 | void 24 | test_copy_continous_aligned(int elems, bool alignment = true) { 25 | std::vector src(elems), dst(elems); 26 | int32_t *src_ptr = src.data(), *dst_ptr = dst.data(); 27 | 28 | constexpr int kAlignment = sizeof(uint4); 29 | 30 | cutlass::DeviceAllocation src_d(elems + (alignment ? 
0 : 1)), dst_d(elems); 31 | int32_t *src_dptr = src_d.get(), *dst_dptr = dst_d.get(); 32 | // suppose allocation is aligned 33 | FLUX_CHECK_DIV(intptr_t(src_dptr), kAlignment); 34 | FLUX_CHECK_DIV(intptr_t(dst_dptr), kAlignment); 35 | if (!alignment) { 36 | src_dptr += 1; // force not aligned 37 | } 38 | 39 | cudaStream_t stream = nullptr; 40 | for (int i = 0; i < elems; i++) { 41 | src_ptr[i] = i + 1; 42 | } 43 | CUDA_CHECK( 44 | cudaMemcpyAsync(src_dptr, src_ptr, elems * sizeof(int32_t), cudaMemcpyHostToDevice, stream)); 45 | 46 | copy_continous_aligned(dst_dptr, src_dptr, elems * sizeof(int32_t), 1, 512, stream); 47 | 48 | CUDA_CHECK( 49 | cudaMemcpyAsync(dst_ptr, dst_dptr, elems * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); 50 | 51 | for (int i = 0; i < elems; i++) { 52 | FLUX_CHECK_EQ(dst_ptr[i], src_ptr[i]) << " at index " << i; 53 | } 54 | std::cerr << "check passed\n"; 55 | } 56 | } // namespace bytedance::flux 57 | 58 | int 59 | main() { 60 | cudaFree(0); 61 | bytedance::flux::test_copy_continous_aligned(128, true); // copy by uint4 62 | bytedance::flux::test_copy_continous_aligned(128 + 2, true); // copy by uint2 63 | bytedance::flux::test_copy_continous_aligned(128 + 1, true); // copy by int 64 | bytedance::flux::test_copy_continous_aligned(1024, false); 65 | bytedance::flux::test_copy_continous_aligned(1024 + 2, false); 66 | return 0; 67 | } 68 | -------------------------------------------------------------------------------- /test/python/util/test_bitwise_check.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | ################################################################################ 17 | 18 | import argparse 19 | 20 | import torch 21 | import torch.distributed 22 | 23 | from flux.testing import DTYPE_MAP 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("-M", type=int, default=4096) 29 | parser.add_argument("-N", type=int, default=12288) 30 | parser.add_argument("-B", type=int, default=8) 31 | parser.add_argument("--dtype", default="bfloat16", type=str, help="data type") 32 | 33 | return parser.parse_args() 34 | 35 | 36 | def recover_from_bcsr(ori_t, block_h, block_w): 37 | assert ori_t.size(0) % block_h == 0 38 | assert ori_t.size(1) % block_w == 0 39 | M = ori_t.size(0) 40 | N = ori_t.size(1) 41 | BM = block_h 42 | BN = block_w 43 | new_out = torch.empty_like(ori_t) 44 | # tmp_t = torch.em 45 | tmp_t = ori_t.flatten().view(M // BM, N // BN, BM, BN) 46 | # if(RANK==0): 47 | # import pdb; pdb.set_trace() 48 | for bmi in range(ori_t.size(0) // block_h): 49 | for bni in range(ori_t.size(1) // block_w): 50 | m_start = bmi * BM 51 | m_end = m_start + BM 52 | n_start = bni * BN 53 | n_end = n_start + BN 54 | new_out[m_start:m_end, n_start:n_end] = tmp_t[bmi, bni] 55 | return new_out 56 | 57 | 58 | if __name__ == "__main__": 59 | args = parse_args() 60 | dtype = DTYPE_MAP[args.dtype] 61 | input = torch.rand(args.B, args.M, args.N).cuda().to(dtype) 62 | # input[:, 0,:] = 1 63 | output = torch.zeros(args.M, args.N).cuda().to(dtype) 64 | # red_output = torch.zeros(args.M, args.N).cuda().to(dtype) 65 | ref_out = torch.sum(input, dim=0) 66 | ref_out = recover_from_bcsr(ref_out, 128, 128) 67 | import flux_ths_pybind as ths 68 | 69 | ths.bsr_reduce(input, output, 128, 128) 70 | ret = ths.bitwise_check(ref_out, output) 71 | assert ret == True 72 | print("Bitwise check passed!") 73 | -------------------------------------------------------------------------------- /include/flux/cuda/nvml_stub.h: -------------------------------------------------------------------------------- 1 | //===- nvml_stub.h ----------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | 19 | /*! 
20 | * \file nvml.h 21 | * \brief CUDA stub to avoid direct CUDA driver call 22 | */ 23 | #pragma once 24 | 25 | #include 26 | 27 | namespace bytedance::flux { 28 | 29 | #define FLUX_FORALL_NVML(_) \ 30 | _(nvmlDeviceGetCount) \ 31 | _(nvmlDeviceGetMemoryAffinity) \ 32 | _(nvmlDeviceGetCudaComputeCapability) \ 33 | _(nvmlDeviceGetNvLinkRemoteDeviceType) \ 34 | _(nvmlDeviceGetFieldValues) \ 35 | _(nvmlDeviceGetHandleByIndex) \ 36 | _(nvmlDeviceGetHandleByPciBusId) \ 37 | _(nvmlDeviceGetIndex) \ 38 | _(nvmlDeviceGetMaxPcieLinkGeneration) \ 39 | _(nvmlDeviceGetName) \ 40 | _(nvmlDeviceGetNvLinkCapability) \ 41 | _(nvmlDeviceGetNvLinkRemotePciInfo) \ 42 | _(nvmlDeviceGetNvLinkState) \ 43 | _(nvmlDeviceGetNvLinkVersion) \ 44 | _(nvmlDeviceGetP2PStatus) \ 45 | _(nvmlErrorString) \ 46 | _(nvmlInit) \ 47 | _(nvmlShutdown) 48 | 49 | extern "C" { 50 | typedef struct NVML { 51 | #define CREATE_MEMBER(name) decltype(&(name)) name; 52 | FLUX_FORALL_NVML(CREATE_MEMBER) 53 | #undef CREATE_MEMBER 54 | } NVML; 55 | } 56 | 57 | NVML &nvml_stub(); 58 | 59 | #define NVML_CHECK(expr) \ 60 | do { \ 61 | nvmlReturn_t rtn = expr; \ 62 | FLUX_CHECK(rtn == NVML_SUCCESS) \ 63 | << "Got bad nvml status: " << nvml_stub().nvmlErrorString(rtn) << "(" << rtn << ") at " \ 64 | << #expr << "\n"; \ 65 | } while (0) 66 | 67 | } // namespace bytedance::flux 68 | -------------------------------------------------------------------------------- /src/gemm_rs/ths_op/gemm_reduce_scatter.h: -------------------------------------------------------------------------------- 1 | //===- gemm_reduce_scatter.h -------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include "flux/ths_op/ths_op.h" 22 | #include "coll/ths_op/reduce_scatter_op.h" 23 | 24 | namespace bytedance::flux::ths_op { 25 | class GemmRS { 26 | public: 27 | GemmRS( 28 | std::shared_ptr tp_group, 29 | int32_t nnodes, 30 | int32_t max_m, 31 | int32_t n_dim, 32 | c10::ScalarType input_dtype, 33 | c10::ScalarType output_dtype, 34 | bool transpose_weight, 35 | bool fuse_reduction, 36 | bool ring_reduction); 37 | ~GemmRS(); 38 | void zero_buffers(); 39 | torch::Tensor forward( 40 | torch::Tensor input, 41 | torch::Tensor weight, 42 | c10::optional bias, 43 | c10::optional input_scale, 44 | c10::optional weight_scale, 45 | c10::optional output_scale, 46 | bool fast_accum, 47 | const ReduceScatterOptionWithOptional &reduce_scatter_option); 48 | torch::Tensor profiling( 49 | torch::Tensor input, 50 | torch::Tensor weight, 51 | c10::optional bias, 52 | c10::optional input_scale, 53 | c10::optional weight_scale, 54 | c10::optional output_scale, 55 | bool fast_accum, 56 | c10::intrusive_ptr opt_ctx, 57 | const ReduceScatterOptionWithOptional &reduce_scatter_option); 58 | void forward_barrier( 59 | torch::Tensor input, torch::Tensor weight, c10::optional bias); 60 | torch::Tensor forward_reduce_scatter( 61 | torch::Tensor input, 62 | torch::Tensor weight, 63 | c10::optional bias, 64 | c10::optional input_scale, 65 | c10::optional weight_scale); 66 | 67 | private: 68 | class GemmRSImpl; 69 | GemmRSImpl *impl_ = nullptr; 70 | }; 71 | } // namespace bytedance::flux::ths_op 72 | -------------------------------------------------------------------------------- /src/gemm_rs/reduce_scatter_topos.hpp: -------------------------------------------------------------------------------- 1 | //===- reduce_scatter_topos.hpp ----------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #pragma once 19 | 20 | namespace bytedance::flux { 21 | constexpr static int kLocalWorldSize = 8; 22 | constexpr static int kStages = 4; 23 | struct Topology { 24 | int rank_from[4][8]; 25 | int rank_to[4][8]; 26 | int unused_segments_push[8]; 27 | int segments[4][2]; 28 | int rank_index[2][8]; 29 | }; 30 | /* 31 | ring mode: topo 0 32 | 1rd stage: 4 -> [0 -> 1 -> 2 -> 3] -> [7 -> 6 -> 5 -> 4] -> 0 33 | 2rd stage: 5 -> [1 -> 2 -> 3 -> 0] -> [4 -> 7 -> 6 -> 5] -> 1 34 | 3nd stage: 6 -> [2 -> 3 -> 0 -> 1] -> [5 -> 4 -> 7 -> 6] -> 2 35 | 4st stage: 7 -> [3 -> 0 -> 1 -> 2] -> [6 -> 5 -> 4 -> 7] -> 3 36 | 37 | no ring mode: topo 1 38 | 1rd stage: 4 -> [0 -> 1 -> 2 -> 3] -> [7 -> 6 -> 5 -> 4] -> 0 39 | 2rd stage: 5 -> [1 -> 0 -> 3 -> 2] -> [6 -> 7 -> 4 -> 5] -> 1 40 | 3nd stage: 6 -> [2 -> 3 -> 0 -> 1] -> [5 -> 4 -> 7 -> 6] -> 2 41 | 4st stage: 7 -> [3 -> 2 -> 1 -> 0] -> [4 -> 5 -> 6 -> 7] -> 3 42 | 43 | */ 44 | constexpr static __device__ Topology kTopologys[] = { 45 | // topo 0 46 | {{{4, 0, 1, 2, 5, 6, 7, 3}, 47 | {3, 5, 1, 2, 0, 6, 7, 4}, 48 | {3, 0, 6, 2, 5, 1, 7, 4}, 49 | {3, 0, 1, 7, 5, 6, 2, 4}}, 50 | {{1, 2, 3, 7, 0, 4, 5, 6}, 51 | {4, 2, 3, 0, 7, 1, 5, 6}, 52 | {1, 5, 3, 0, 7, 4, 2, 6}, 53 | {1, 2, 6, 0, 7, 4, 5, 3}}, 54 | {3, 0, 1, 2, 5, 6, 7, 4}, 55 | {{3, 4}, {0, 5}, {1, 6}, {2, 7}}, 56 | { 57 | {7, 3, 4, 0, 5, 1, 6, 2}, // numa node 0 58 | {0, 4, 1, 5, 2, 6, 3, 7}, // numa node 1 59 | }}, 60 | // topo 1 61 | {{{4, 0, 1, 2, 5, 6, 7, 3}, 62 | {1, 5, 3, 0, 7, 4, 2, 6}, 63 | {3, 0, 6, 2, 5, 1, 7, 4}, 64 | {1, 2, 3, 7, 0, 4, 5, 6}}, 65 | {{1, 2, 3, 7, 0, 4, 5, 6}, 66 | {3, 0, 6, 2, 5, 1, 7, 4}, 67 | {1, 5, 3, 0, 7, 4, 2, 6}, 68 | {4, 0, 1, 2, 5, 6, 7, 3}}, 69 | {3, 2, 1, 0, 7, 6, 5, 4}, 70 | {{3, 4}, {2, 5}, {1, 6}, {0, 7}}, 71 | { 72 | {7, 3, 6, 2, 5, 1, 4, 0}, 73 | {0, 4, 1, 5, 2, 6, 3, 7}, 74 | }}}; 75 | } // namespace bytedance::flux 76 | -------------------------------------------------------------------------------- /test/python/util/test_bsr_reduce.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | ################################################################################ 17 | 18 | import argparse 19 | 20 | import torch 21 | import torch.distributed 22 | 23 | import flux 24 | from flux.testing import DTYPE_MAP 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("-M", type=int, default=4096) 30 | parser.add_argument("-N", type=int, default=12288) 31 | parser.add_argument("-B", type=int, default=8) 32 | parser.add_argument("--dtype", default="bfloat16", type=str, help="data type") 33 | 34 | return parser.parse_args() 35 | 36 | 37 | def recover_from_bcsr(ori_t, block_h, block_w): 38 | assert ori_t.size(0) % block_h == 0 39 | assert ori_t.size(1) % block_w == 0 40 | M = ori_t.size(0) 41 | N = ori_t.size(1) 42 | BM = block_h 43 | BN = block_w 44 | new_out = torch.empty_like(ori_t) 45 | # tmp_t = torch.em 46 | tmp_t = ori_t.flatten().view(M // BM, N // BN, BM, BN) 47 | # if(RANK==0): 48 | # import pdb; pdb.set_trace() 49 | for bmi in range(ori_t.size(0) // block_h): 50 | for bni in range(ori_t.size(1) // block_w): 51 | m_start = bmi * BM 52 | m_end = m_start + BM 53 | n_start = bni * BN 54 | n_end = n_start + BN 55 | new_out[m_start:m_end, n_start:n_end] = tmp_t[bmi, bni] 56 | return new_out 57 | 58 | 59 | if __name__ == "__main__": 60 | args = parse_args() 61 | dtype = DTYPE_MAP[args.dtype] 62 | input = torch.rand(args.B, args.M, args.N).cuda().to(dtype) 63 | # input[:, 0,:] = 1 64 | output = torch.zeros(args.M, args.N).cuda().to(dtype) 65 | # red_output = torch.zeros(args.M, args.N).cuda().to(dtype) 66 | ref_out = torch.sum(input, dim=0) 67 | ref_out = recover_from_bcsr(ref_out, 128, 128) 68 | import flux_ths_pybind as ths 69 | 70 | ths.bsr_reduce(input, output, 128, 128) 71 | print(output) 72 | print(ref_out) 73 | # print(torch.sum(ref_out)) 74 | # print(torch.sum(output)) 75 | flux.torch_allclose(output, ref_out, 1e-4, 1e-4) 76 | for i in range(1000): 77 | ths.bsr_reduce(input, output, 128, 128) 78 | ref_out = torch.sum(input, dim=0) 79 | # assert(torch.allclose(output, ref_out, 1e-2, 1e-2)) 80 | -------------------------------------------------------------------------------- /src/gemm_rs/tuning_config/config_gemm_rs_sm80_A100_tp4_nnodes1.cu: -------------------------------------------------------------------------------- 1 | //===- config_gemm_rs_sm80_A100_tp4_nnodes1.cu ------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | // clang-format off 18 | #include "flux/op_registry.h" 19 | namespace bytedance::flux { 20 | using namespace cute; 21 | 22 | static int config_gemm_rs_sm80_a100_tp4_nnodes1 = []() { 23 | auto &inst = TuningConfigRegistry::instance(); 24 | /// PCIE 25 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}()),_Sm80{}(),_A100{}(),_ReduceScatter{}(),_RRR{}(),_GemmV2{}(),None{},make_reduce_scatter_meta(false,_IntraNodePcie{}())),make_runtime_config(8192,12288,12288,make_reduce_scatter_runtime_config(4,1)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l)),None{},cute::make_tuple(128l,256l,32l),_GemmStreamK{}(),3,_RasterHeuristic{}())); 26 | inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}()),_Sm80{}(),_A100{}(),_ReduceScatter{}(),_RRR{}(),_GemmV2{}(),None{},make_reduce_scatter_meta(false,_IntraNodePcie{}())),make_runtime_config(8192,12288,12288,make_reduce_scatter_runtime_config(4,1)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l)),None{},cute::make_tuple(128l,256l,32l),_GemmStreamK{}(),4,_RasterHeuristic{}())); 27 | /// NVLink 28 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}()),_Sm80{}(),_A100{}(),_ReduceScatter{}(),_RCR{}(),_GemmV2{}(),None{},make_reduce_scatter_meta(false,_IntraNode{}())),make_runtime_config(8192,12288,12288,make_reduce_scatter_runtime_config(4,1)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l)),None{},cute::make_tuple(128l,256l,32l),_GemmStreamK{}(),3,_RasterHeuristic{}())); 29 | inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}()),_Sm80{}(),_A100{}(),_ReduceScatter{}(),_RRR{}(),_GemmV2{}(),None{},make_reduce_scatter_meta(false,_IntraNode{}())),make_runtime_config(8192,12288,12288,make_reduce_scatter_runtime_config(4,1)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l)),None{},cute::make_tuple(128l,256l,32l),_GemmStreamK{}(),3,_RasterHeuristic{}())); 30 | return 0; 31 | }(); 32 | } 33 | // clang-format on 34 | -------------------------------------------------------------------------------- /test/python/inplace_cast/test_inplace_cast.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | ################################################################################ 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | import flux 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--sizes", 29 | nargs="+", 30 | type=int, 31 | default=[1048580, 1048572, 1048573, 1048575, 1048571, 8388604, 8388603, 8388605], 32 | ) 33 | 34 | return parser.parse_args() 35 | 36 | 37 | if __name__ == "__main__": 38 | args = parse_args() 39 | import flux_ths_pybind as ths 40 | 41 | # Test passed on (1048576 1048580 1048572 1048573 1048575 1048571 8388604 8388603 8388605) 42 | 43 | # Get the maximum data size from args.sizes 44 | max_data_size = max(args.sizes) 45 | 46 | ## Test inplace_cast_fp32_to_bf16 API 47 | for data_size in args.sizes: 48 | test_input = torch.rand(data_size, dtype=torch.float32).cuda() 49 | ref_output = test_input.to(torch.bfloat16) 50 | ths.inplace_cast_fp32_to_bf16(test_input) 51 | test_output = test_input.view(torch.bfloat16) 52 | test_output = torch.narrow(test_output, 0, 0, data_size) 53 | flux.torch_allclose(test_output, ref_output, 1e-5, 1e-8) 54 | 55 | ## Test InplaceCast class: create object outside loop 56 | inplace_cast_op = flux.InplaceCast(max_data_size) 57 | for data_size in args.sizes: 58 | test_data = torch.rand(data_size, dtype=torch.float32).cuda() 59 | golden = test_data.to(torch.bfloat16) 60 | inplace_cast_op.from_fp32_to_bf16(test_data) 61 | test_data_bf16 = test_data.view(torch.bfloat16) 62 | test_data_bf16 = torch.narrow(test_data_bf16, 0, 0, data_size) 63 | flux.torch_allclose(test_data_bf16, golden, 1e-5, 1e-8) 64 | 65 | ## Test InplaceCast class: create object inside loop 66 | for data_size in args.sizes: 67 | inplace_cast_op = flux.InplaceCast(data_size) 68 | test_data = torch.rand(data_size, dtype=torch.float32).cuda() 69 | golden = test_data.to(torch.bfloat16) 70 | inplace_cast_op.from_fp32_to_bf16(test_data) 71 | test_data_bf16 = test_data.view(torch.bfloat16) 72 | test_data_bf16 = torch.narrow(test_data_bf16, 0, 0, data_size) 73 | flux.torch_allclose(test_data_bf16, golden, 1e-5, 1e-8) 74 | -------------------------------------------------------------------------------- /src/cuda/cuda_common.cu: -------------------------------------------------------------------------------- 1 | //===- cuda_common.cu ------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | #include "flux/cuda/cuda_common.h" 18 | #include "flux/cuda/cuda_common_device.hpp" 19 | 20 | namespace bytedance::flux { 21 | void 22 | copy_continous_aligned( 23 | void *dst, 24 | const void *src, 25 | size_t nbytes, 26 | int threadblock_count, 27 | int thread_count, 28 | cudaStream_t stream) { 29 | dim3 grid(threadblock_count); 30 | dim3 block(thread_count); 31 | { // copy by uint4 32 | using PackT = uint4; 33 | constexpr int kPackSize = sizeof(PackT); 34 | if (intptr_t(dst) % sizeof(PackT) == 0 && intptr_t(src) % sizeof(PackT) == 0 && 35 | nbytes % kPackSize == 0) { 36 | copy_continous_aligned_kernel<<>>(dst, src, nbytes); 37 | CUTE_CHECK_ERROR(cudaGetLastError()); 38 | return; 39 | } 40 | } 41 | { // copy by uint2 42 | using PackT = uint2; 43 | constexpr int kPackSize = sizeof(PackT); 44 | if (intptr_t(dst) % sizeof(PackT) == 0 && intptr_t(src) % sizeof(PackT) == 0 && 45 | nbytes % kPackSize == 0) { 46 | copy_continous_aligned_kernel<<>>(dst, src, nbytes); 47 | CUTE_CHECK_ERROR(cudaGetLastError()); 48 | return; 49 | } 50 | } 51 | { // copy by uint 52 | using PackT = uint; 53 | constexpr int kPackSize = sizeof(PackT); 54 | if (intptr_t(dst) % sizeof(PackT) == 0 && intptr_t(src) % sizeof(PackT) == 0 && 55 | nbytes % kPackSize == 0) { 56 | copy_continous_aligned_kernel<<>>(dst, src, nbytes); 57 | CUTE_CHECK_ERROR(cudaGetLastError()); 58 | return; 59 | } 60 | } 61 | { // copy by int16_t 62 | using PackT = int16_t; 63 | constexpr int kPackSize = sizeof(PackT); 64 | if (intptr_t(dst) % sizeof(PackT) == 0 && intptr_t(src) % sizeof(PackT) == 0 && 65 | nbytes % kPackSize == 0) { 66 | copy_continous_aligned_kernel<<>>(dst, src, nbytes); 67 | CUTE_CHECK_ERROR(cudaGetLastError()); 68 | return; 69 | } 70 | } 71 | { // copy by int8_t 72 | using PackT = int8_t; 73 | copy_continous_aligned_kernel<<>>(dst, src, nbytes); 74 | CUTE_CHECK_ERROR(cudaGetLastError()); 75 | return; 76 | } 77 | } 78 | 79 | } // namespace bytedance::flux 80 | -------------------------------------------------------------------------------- /src/cuda/moe_utils.cu: -------------------------------------------------------------------------------- 1 | //===- moe_utils.cu ---------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "flux/cuda/reduce_utils.cuh" 19 | #include "flux/cuda/cuda_common.h" 20 | #include "flux/cuda/moe_utils.h" 21 | namespace bytedance::flux { 22 | 23 | __global__ void 24 | calc_scatter_index_kernel( 25 | const int *rank, const int *count, int *scatter_index, const int total_num) { 26 | constexpr unsigned FULL_MASK = 0xffffffff; 27 | __shared__ int s_offset[1024]; 28 | const int expert_rank = blockIdx.x; 29 | const int expert_num = expert_rank + 1; 30 | if (threadIdx.x < 32) { 31 | int cur_offset = 0; 32 | int expert_num_pad = ((expert_num + 31) >> 5) << 5; 33 | for (int i = threadIdx.x; i < expert_num_pad; i += 32) { 34 | int len = i < expert_num ? count[i] : 0; 35 | int temp_offset = warp_prefix_sum(threadIdx.x, len); 36 | if (i < expert_num) 37 | s_offset[i] = cur_offset + temp_offset - len; 38 | cur_offset += __shfl_sync(FULL_MASK, temp_offset, 31); 39 | } 40 | } 41 | __syncthreads(); 42 | 43 | const int warp_tid = threadIdx.x & 0x1F; 44 | const unsigned int t_mask = (1 << warp_tid) - 1; 45 | 46 | int *s_expert_offset = s_offset + blockIdx.x; 47 | int total_num_pad = ((total_num + blockDim.x - 1) / blockDim.x) * blockDim.x; 48 | for (int tid = threadIdx.x; tid < total_num_pad; tid += blockDim.x) { 49 | int rank_id = tid < total_num ? __ldg(&rank[tid]) : -1; 50 | const bool match = (rank_id == expert_rank); 51 | int active_mask = __ballot_sync(FULL_MASK, match); 52 | 53 | int warp_expert_offset = 0; 54 | if (warp_tid == 0) 55 | warp_expert_offset = atomicAdd(s_expert_offset, __popc(active_mask)); 56 | warp_expert_offset = __shfl_sync(FULL_MASK, warp_expert_offset, 0); 57 | 58 | int warp_offset = __popc(active_mask & t_mask); 59 | if (match) 60 | scatter_index[tid] = warp_expert_offset + warp_offset; 61 | } 62 | } 63 | 64 | void 65 | calc_scatter_index( 66 | const int *choosed_experts, // of total_num 67 | const int *count, // of expert_num 68 | int *scatter_index, // of total_num 69 | const int total_num, // topk * ntokens 70 | int expert_num, 71 | cudaStream_t stream) { 72 | calc_scatter_index_kernel<<>>( 73 | choosed_experts, count, scatter_index, total_num); 74 | CUDA_CHECK(cudaGetLastError()); 75 | } 76 | 77 | } // namespace bytedance::flux 78 | -------------------------------------------------------------------------------- /src/ag_gemm/ths_op/all_gather_gemm_op.h: -------------------------------------------------------------------------------- 1 | //===- all_gather_gemm_op.h --------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "flux/ths_op/ths_op.h" 24 | #include "coll/ths_op/all_gather_types.h" 25 | 26 | namespace bytedance::flux::ths_op { 27 | 28 | class AllGatherGemmOp { 29 | public: 30 | AllGatherGemmOp( 31 | std::shared_ptr tp_group, 32 | int32_t nnodes, 33 | int32_t full_m, 34 | int32_t n_dim, 35 | int32_t k_dim, 36 | c10::ScalarType input_dtype, 37 | c10::ScalarType output_dtype, 38 | bool use_pdl); 39 | 40 | ~AllGatherGemmOp(); 41 | 42 | torch::Tensor forward( 43 | torch::Tensor input, 44 | torch::Tensor weight, 45 | c10::optional bias, 46 | c10::optional output, 47 | c10::optional input_scale, 48 | c10::optional weight_scale, 49 | c10::optional output_scale, 50 | bool fast_accum, 51 | bool transpose_weight, 52 | AllGatherOptionWithOptional opt, 53 | c10::optional gathered_input); 54 | 55 | torch::Tensor gemm_only( 56 | torch::Tensor input, // this should be the full input 57 | torch::Tensor weight, 58 | c10::optional bias, 59 | c10::optional output, 60 | c10::optional input_scale, // this should be the full scale 61 | c10::optional weight_scale, 62 | c10::optional output_scale, 63 | bool fast_accum, 64 | bool transpose_weight); 65 | 66 | torch::Tensor profiling( 67 | torch::Tensor input, 68 | torch::Tensor weight, 69 | c10::optional bias, 70 | c10::optional output, 71 | c10::optional input_scale, 72 | c10::optional weight_scale, 73 | c10::optional output_scale, 74 | bool fast_accum, 75 | bool transpose_weight, 76 | AllGatherOptionWithOptional option_, 77 | c10::optional gathered_input, 78 | c10::intrusive_ptr opt_ctx); 79 | 80 | private: 81 | class AllGatherGemmOpImpl; 82 | AllGatherGemmOpImpl *impl_ = nullptr; 83 | }; 84 | } // namespace bytedance::flux::ths_op 85 | -------------------------------------------------------------------------------- /src/pybind/gemm_grouped_v3_ag_scatter.cc: -------------------------------------------------------------------------------- 1 | //===- gemm_grouped_v3_ag_scatter.cc ------------------------------ C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "moe_ag_scatter/ths_op/gemm_grouped_v3_ag_scatter.h" 19 | 20 | #include "flux/ths_op/ths_pybind.h" 21 | 22 | namespace bytedance::flux::ths_op { 23 | 24 | namespace py = pybind11; 25 | using GemmGroupedV3AGScatterOpCls = TorchClassWrapper; 26 | 27 | static int _register_gemm_only_ops [[maybe_unused]] = []() { 28 | ThsOpsInitRegistry::instance().register_one("gemm_grouped_v3_ag_scatter", [](py::module &m) { 29 | py::class_(m, "GemmGroupedV3AGScatter") 30 | .def(py::init(), py::arg("tp_env"), py::arg("moe_args")) 31 | .def("clear_buffers", &GemmGroupedV3AGScatterOpCls::clear_buffers) 32 | .def( 33 | "forward", 34 | &GemmGroupedV3AGScatterOpCls::forward, 35 | py::arg("inputs_shard"), 36 | py::arg("weights"), 37 | py::arg("splits_gpu"), 38 | py::arg("scatter_index"), 39 | py::arg("output_scale") = py::none(), 40 | py::arg("outputs_buf") = py::none(), 41 | py::arg("allgather_output") = py::none(), 42 | py::arg("fast_accum") = false, 43 | py::arg("sm_margin") = 0) 44 | .def( 45 | "forward_multiple_weights", 46 | &GemmGroupedV3AGScatterOpCls::forward_multiple_weights, 47 | py::arg("inputs_shard"), 48 | py::arg("weights"), 49 | py::arg("splits_gpu"), 50 | py::arg("scatter_index"), 51 | py::arg("output_scale") = py::none(), 52 | py::arg("outputs_buf") = py::none(), 53 | py::arg("allgather_output") = py::none(), 54 | py::arg("fast_accum") = false, 55 | py::arg("sm_margin") = 0) 56 | .def( 57 | "profiling", 58 | &GemmGroupedV3AGScatterOpCls::profiling, 59 | py::arg("inputs_shard"), 60 | py::arg("weights"), 61 | py::arg("splits_gpu"), 62 | py::arg("scatter_index"), 63 | py::arg("output_scale") = py::none(), 64 | py::arg("outputs_buf") = py::none(), 65 | py::arg("allgather_output") = py::none(), 66 | py::arg("fast_accum") = false, 67 | py::arg("sm_margin") = 0, 68 | py::arg("prof_ctx") = nullptr); 69 | }); 70 | return 0; 71 | }(); 72 | } // namespace bytedance::flux::ths_op 73 | -------------------------------------------------------------------------------- /src/cuda/cuda_common.cc: -------------------------------------------------------------------------------- 1 | //===- cuda_common.cc --------------------------------------------- C++ ---===// 2 | // 3 | // Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved. 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "flux/cuda/cuda_common.h" 19 | #include "flux/cuda/nvml_stub.h" 20 | 21 | namespace bytedance::flux { 22 | 23 | void 24 | ensure_nvml_init() { 25 | static bool inited = []() -> bool { 26 | NVML_CHECK(nvml_stub().nvmlInit()); // can be initialized many times. 27 | return true; 28 | }(); 29 | } 30 | 31 | // why not std::string? 
/src/cuda/cuda_common.cc:
--------------------------------------------------------------------------------
//===- cuda_common.cc --------------------------------------------- C++ ---===//
//
// Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "flux/cuda/cuda_common.h"
#include "flux/cuda/nvml_stub.h"

#include <array>
#include <vector>

namespace bytedance::flux {

void
ensure_nvml_init() {
  static bool inited = []() -> bool {
    NVML_CHECK(nvml_stub().nvmlInit());  // can be initialized many times.
    return true;
  }();
}

// why not std::string?
// flux/ths_op is compiled with -D_GLIBCXX_USE_CXX11_ABI=0 but flux/cuda is not
const char *
get_gpu_device_name(int devid) {
  ensure_nvml_init();
  constexpr int kMaxDevices = 32;
  static std::array<char[NVML_DEVICE_NAME_V2_BUFFER_SIZE], kMaxDevices> kDeviceNames = []() {
    std::array<char[NVML_DEVICE_NAME_V2_BUFFER_SIZE], kMaxDevices> device_names;
    int count = 0;
    CUDA_CHECK(cudaGetDeviceCount(&count));
    for (int i = 0; i < count; i++) {
      nvmlDevice_t device;
      NVML_CHECK(nvml_stub().nvmlDeviceGetHandleByIndex(i, &device));
      NVML_CHECK(
          nvml_stub().nvmlDeviceGetName(device, device_names[i], NVML_DEVICE_NAME_V2_BUFFER_SIZE));
    }
    return device_names;
  }();
  return kDeviceNames[devid];
}

unsigned
get_pcie_gen(int devid) {
  nvmlDevice_t device;
  NVML_CHECK(nvml_stub().nvmlDeviceGetHandleByIndex(devid, &device));
  unsigned int gen = 0;
  NVML_CHECK(nvml_stub().nvmlDeviceGetMaxPcieLinkGeneration(device, &gen));
  return gen;
}

int
get_sm_count(int device_id) {
  static std::vector<int> sms = []() {
    int device_count = 0;
    CUDA_CHECK(cudaGetDeviceCount(&device_count));
    FLUX_CHECK_GT(device_count, 0) << "No CUDA device found.";
    std::vector<int> sm_counts(device_count, 0);
    for (int i = 0; i < device_count; i++) {
      CUDA_CHECK(cudaDeviceGetAttribute(&sm_counts[i], cudaDevAttrMultiProcessorCount, i));
    }
    return sm_counts;
  }();

  if (device_id < 0) {
    CUDA_CHECK(cudaGetDevice(&device_id));
  }
  FLUX_CHECK_LT(device_id, sms.size());
  return sms[device_id];
}

int
get_highest_cuda_stream_priority() {
  static int priority = []() {
    int least_priority, greatest_priority;
    CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
    return greatest_priority;
  }();
  return priority;
}
}  // namespace bytedance::flux
--------------------------------------------------------------------------------
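`get_gpu_device_name`, `get_sm_count`, and `get_highest_cuda_stream_priority` above all cache their answers in a function-local static initialized by an immediately invoked lambda, so the NVML/CUDA queries run once per process and the initialization is thread-safe under C++11 rules. A minimal sketch of the same pattern is below; `query_warp_size_example` is a hypothetical helper written for illustration, not a flux API, and it drops the `CUDA_CHECK`/`FLUX_CHECK` error handling that the real code uses.

```cpp
// Hypothetical helper showing the cached-query idiom used in cuda_common.cc:
// the lambda runs exactly once and later calls just read the cached vector.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

static int query_warp_size_example(int device_id) {
  static std::vector<int> warp_sizes = []() {
    int count = 0;
    cudaGetDeviceCount(&count);  // a real build would wrap this in CUDA_CHECK
    std::vector<int> sizes(count, 0);
    for (int i = 0; i < count; ++i) {
      cudaDeviceGetAttribute(&sizes[i], cudaDevAttrWarpSize, i);
    }
    return sizes;
  }();
  if (device_id < 0) {
    cudaGetDevice(&device_id);  // default to the current device, like get_sm_count
  }
  return warp_sizes[device_id];
}

int main() {
  int count = 0;
  if (cudaGetDeviceCount(&count) != cudaSuccess || count == 0) {
    std::printf("no CUDA device visible; skipping\n");
    return 0;
  }
  std::printf("warp size on device 0: %d\n", query_warp_size_example(0));
  return 0;
}
```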
/src/moe_ag_scatter/ths_op/gemm_grouped_v3_ag_scatter.h:
--------------------------------------------------------------------------------
//===- gemm_grouped_v3_ag_scatter.h ------------------------------- C++ ---===//
//
// Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#pragma once
#include <tuple>
#include <vector>
#include <torch/all.h>
#include "coll/ths_op/all_gather_types.h"
#include "flux/ths_op/ths_op.h"

namespace bytedance::flux::ths_op {
std::tuple<
    int,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor>
prepare_moe_ag_scatter_args(
    torch::Tensor splits_gpu,
    torch::Tensor scatter_index,
    int ntokens,
    int topk,
    int num_weights_group,
    int ep_start,
    int ep_nexperts,
    int rank,
    int world_size,
    int tile_size_m,
    intptr_t stream_);

class GemmGroupedV3AGScatterOp {
 public:
  GemmGroupedV3AGScatterOp(DistEnvTPWithEP tp_env_, MoeArguments moe_args);
  ~GemmGroupedV3AGScatterOp();
  void clear_buffers();
  torch::Tensor forward(
      torch::Tensor inputs_shard,
      torch::Tensor weights,
      torch::Tensor splits_gpu,
      torch::Tensor scatter_index,
      c10::optional<torch::Tensor> output_scale,
      c10::optional<torch::Tensor> outputs_buf,
      c10::optional<torch::Tensor> allgather_output,
      bool fast_accum,
      int sm_margin);
  std::vector<torch::Tensor> forward_multiple_weights(
      torch::Tensor inputs_shard,
      std::vector<torch::Tensor> weights,
      torch::Tensor splits_gpu,
      torch::Tensor scatter_index,
      c10::optional<std::vector<torch::Tensor>> output_scale,
      c10::optional<std::vector<torch::Tensor>> outputs_buf,
      c10::optional<torch::Tensor> allgather_output,
      bool fast_accum,
      int sm_margin);
  std::vector<torch::Tensor> profiling(
      torch::Tensor inputs_shard,
      std::vector<torch::Tensor> weights,
      torch::Tensor splits_gpu,
      torch::Tensor scatter_index,
      c10::optional<std::vector<torch::Tensor>> output_scale,
      c10::optional<std::vector<torch::Tensor>> outputs_buf,
      c10::optional<torch::Tensor> allgather_output,
      bool fast_accum,
      int sm_margin,
      c10::intrusive_ptr<ProfilingContext> opt_ctx);

 private:
  class GemmGroupedV3AGScatterOpImpl;
  GemmGroupedV3AGScatterOpImpl *impl_ = nullptr;
};

}  // namespace bytedance::flux::ths_op
--------------------------------------------------------------------------------
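Note that `prepare_moe_ag_scatter_args` accepts its CUDA stream as an `intptr_t` rather than a `cudaStream_t`, which is a common way to let a stream handle cross the pybind11 boundary as a plain integer (for example, what `torch.cuda.current_stream().cuda_stream` exposes on the Python side). The round-trip sketch below shows the usual casts in both directions; both helper names are hypothetical and not flux APIs.

```cpp
// Illustrative only: passing a CUDA stream through an intptr_t parameter
// such as the `stream_` argument declared above.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

static intptr_t stream_to_handle_example(cudaStream_t s) {
  return reinterpret_cast<intptr_t>(s);  // what a caller would pass as `stream_`
}

static cudaStream_t handle_to_stream_example(intptr_t handle) {
  return reinterpret_cast<cudaStream_t>(handle);  // what the callee would recover
}

int main() {
  cudaStream_t stream = nullptr;
  if (cudaStreamCreate(&stream) != cudaSuccess) {
    std::printf("no CUDA device available; skipping\n");
    return 0;
  }
  intptr_t handle = stream_to_handle_example(stream);
  cudaStream_t back = handle_to_stream_example(handle);
  std::printf("round-trip ok: %d\n", static_cast<int>(back == stream));
  cudaStreamDestroy(stream);
  return 0;
}
```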