├── .gitignore ├── LICENSE ├── README.md ├── csrc ├── CMakeLists.txt ├── config.hpp ├── deep_ep.cpp ├── deep_ep.hpp ├── event.hpp └── kernels │ ├── CMakeLists.txt │ ├── api.cuh │ ├── buffer.cuh │ ├── configs.cuh │ ├── exception.cuh │ ├── ibgda_device.cuh │ ├── internode.cu │ ├── internode_ll.cu │ ├── intranode.cu │ ├── launch.cuh │ ├── layout.cu │ ├── runtime.cu │ └── utils.cuh ├── deep_ep ├── __init__.py ├── buffer.py └── utils.py ├── figures ├── low-latency.png └── normal.png ├── install.sh ├── setup.py ├── tests ├── test_internode.py ├── test_intranode.py ├── test_low_latency.py └── utils.py └── third-party ├── README.md └── nvshmem.patch /.gitignore: -------------------------------------------------------------------------------- 1 | compile_commands.json 2 | .idea 3 | .DS_Store 4 | *.pyc 5 | build/ 6 | .cache/ 7 | .vscode/ 8 | */cmake-build-*/ 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DeepSeek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepEP 2 | 3 | DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8. 4 | 5 | To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control. 6 | 7 | For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource. 
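As a preview of the hook-based overlapping method mentioned above, the sketch below shows the intended call pattern. It is a minimal sketch only: `overlapped_low_latency_forward` and `compute_other_micro_batch` are illustrative placeholders, and the full runnable examples are given in the *Interfaces and examples* section below.

```python
import torch
from typing import Callable
from deep_ep import Buffer


def overlapped_low_latency_forward(buffer: Buffer, hidden_states: torch.Tensor, topk_idx: torch.Tensor,
                                   num_max_dispatch_tokens_per_rank: int, num_experts: int,
                                   compute_other_micro_batch: Callable[[], None]):
    # Launch the dispatch with a receiving hook: the call returns immediately and the
    # RDMA traffic proceeds in the background without occupying any SM
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                    async_finish=False, return_recv_hook=True)

    # Overlap: run computation for the other micro-batch while the data is in flight
    compute_other_micro_batch()

    # The received tensors are only guaranteed to be ready after the hook is called
    hook()
    return recv_hidden_states, recv_expert_count, handle
```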
8 | 9 | Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper. 10 | 11 | ## Performance 12 | 13 | ### Normal kernels with NVLink and RDMA forwarding 14 | 15 | We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining). 16 | 17 | | Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth | 18 | |:---------:|:------------:|:--------------------:|:-----------:|:--------------------:| 19 | | Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) | 20 | | Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) | 21 | | Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) | 22 | | Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) | 23 | 24 | **News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution! 25 | 26 | ### Low-latency kernels with pure RDMA 27 | 28 | We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining). 29 | 30 | | Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth | 31 | |:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:| 32 | | 8 | 77 us | 98 GB/s | 8 | 114 us | 127 GB/s | 33 | | 16 | 118 us | 63 GB/s | 16 | 195 us | 74 GB/s | 34 | | 32 | 155 us | 48 GB/s | 32 | 273 us | 53 GB/s | 35 | | 64 | 173 us | 43 GB/s | 64 | 314 us | 46 GB/s | 36 | | 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s | 37 | | 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s | 38 | 39 | **News (2025.06.05)**: low-latency kernels now leverage NVLink as much as possible, see [#173](https://github.com/deepseek-ai/DeepEP/pull/173) for more details. Thanks for the contribution! 40 | 41 | ## Quick start 42 | 43 | ### Requirements 44 | 45 | - Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support 46 | - Python 3.8 and above 47 | - CUDA version 48 | - CUDA 11.0 and above for SM80 GPUs 49 | - CUDA 12.3 and above for SM90 GPUs 50 | - PyTorch 2.1 and above 51 | - NVLink for intranode communication 52 | - RDMA network for internode communication 53 | 54 | ### Download and install NVSHMEM dependency 55 | 56 | DeepEP also depends on NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions. 
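Before building, you can sanity-check the Python-visible requirements listed above with a short snippet like the one below. This is only a rough sketch: it assumes a CUDA-capable GPU is already visible to PyTorch, and it does not verify the NVSHMEM installation itself.

```python
import os
import torch

# PyTorch 2.1+ built with CUDA is required
assert tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (2, 1), 'PyTorch 2.1 or above is required'
assert torch.version.cuda is not None, 'a CUDA build of PyTorch is required'

# SM80 (Ampere) or newer GPUs are supported; SM90 features must be disabled on older architectures
major, minor = torch.cuda.get_device_capability()
assert (major, minor) >= (8, 0), 'an SM80 (Ampere) or newer GPU is required'
if (major, minor) < (9, 0):
    print('Hint: build with DISABLE_SM90_FEATURES=1 for non-SM90 GPUs')

# NVSHMEM is located through this variable at build time; internode and
# low-latency features are disabled when it is not set
print('NVSHMEM_DIR =', os.getenv('NVSHMEM_DIR', '<not set>'))
```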
57 | 58 | ### Development 59 | 60 | ```bash 61 | # Build and make symbolic links for SO files 62 | NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build 63 | # You may modify the specific SO names according to your own platform 64 | ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so 65 | 66 | # Run test cases 67 | # NOTES: you may modify the `init_dist` function in `tests/utils.py` 68 | # according to your own cluster settings, and launch on multiple nodes 69 | python tests/test_intranode.py 70 | python tests/test_internode.py 71 | python tests/test_low_latency.py 72 | ``` 73 | 74 | ### Installation 75 | 76 | ```bash 77 | NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install 78 | ``` 79 | 80 | #### Installation environment variables 81 | 82 | - `NVSHMEM_DIR`: the path to the NVSHMEM directory; all internode and low-latency features are disabled if it is not specified 83 | - `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; setting it to 1 is required for non-SM90 devices or CUDA 11 84 | - `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"` 85 | - `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details 86 | 87 | Then, import `deep_ep` in your Python project, and enjoy! 88 | 89 | ## Network configurations 90 | 91 | DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well. 92 | 93 | ### Traffic isolation 94 | 95 | Traffic isolation is supported by InfiniBand through Virtual Lanes (VL). 96 | 97 | To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows: 98 | 99 | - workloads using normal kernels 100 | - workloads using low-latency kernels 101 | - other workloads 102 | 103 | For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable. 104 | 105 | ### Adaptive routing 106 | 107 | Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance: 108 | 109 | - enable adaptive routing in environments with heavy network loads 110 | - use static routing in environments with light network loads 111 | 112 | ### Congestion control 113 | 114 | Congestion control is disabled, as we have not observed significant congestion in our production environment. 115 | 116 | ## Interfaces and examples 117 | 118 | ### Example use in model training or inference prefilling 119 | 120 | The normal kernels can be used in model training or in the inference prefilling phase (without the backward pass), as the example code below shows.
121 | 122 | ```python 123 | import torch 124 | import torch.distributed as dist 125 | from typing import List, Tuple, Optional, Union 126 | 127 | from deep_ep import Buffer, EventOverlap 128 | 129 | # Communication buffer (will allocate at runtime) 130 | _buffer: Optional[Buffer] = None 131 | 132 | # Set the number of SMs to use 133 | # NOTES: this is a static variable 134 | Buffer.set_num_sms(24) 135 | 136 | 137 | # You may call this function at the framework initialization 138 | def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer: 139 | global _buffer 140 | 141 | # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests 142 | num_nvl_bytes, num_rdma_bytes = 0, 0 143 | for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())): 144 | num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) 145 | num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) 146 | 147 | # Allocate a buffer if not existed or not enough buffer size 148 | if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes: 149 | _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) 150 | return _buffer 151 | 152 | 153 | def get_hidden_bytes(x: torch.Tensor) -> int: 154 | t = x[0] if isinstance(x, tuple) else x 155 | return t.size(1) * max(t.element_size(), 2) 156 | 157 | 158 | def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], 159 | topk_idx: torch.Tensor, topk_weights: torch.Tensor, 160 | num_experts: int, previous_event: Optional[EventOverlap] = None) -> \ 161 | Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]: 162 | # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency 163 | # of the dispatch kernel, it may be useful with communication-computation overlap. 
For more information, please 164 | # refer to the docs of `Buffer.dispatch` 165 | global _buffer 166 | 167 | # Calculate layout before actual dispatch 168 | num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \ 169 | _buffer.get_dispatch_layout(topk_idx, num_experts, 170 | previous_event=previous_event, async_finish=True, 171 | allocate_on_comm_stream=previous_event is not None) 172 | # Do MoE dispatch 173 | # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph 174 | # Unless you specify `num_worst_tokens`, but this flag is for intranode only 175 | # For more advanced usages, please refer to the docs of the `dispatch` function 176 | recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \ 177 | _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights, 178 | num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, 179 | is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert, 180 | previous_event=previous_event, async_finish=True, 181 | allocate_on_comm_stream=True) 182 | # For event management, please refer to the docs of the `EventOverlap` class 183 | return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event 184 | 185 | 186 | def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \ 187 | Tuple[torch.Tensor, torch.Tensor, EventOverlap]: 188 | global _buffer 189 | 190 | # The backward process of MoE dispatch is actually a combine 191 | # For more advanced usages, please refer to the docs of the `combine` function 192 | combined_grad_x, combined_grad_recv_topk_weights, event = \ 193 | _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True) 194 | 195 | # For event management, please refer to the docs of the `EventOverlap` class 196 | return combined_grad_x, combined_grad_recv_topk_weights, event 197 | 198 | 199 | def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \ 200 | Tuple[torch.Tensor, EventOverlap]: 201 | global _buffer 202 | 203 | # Do MoE combine 204 | # For more advanced usages, please refer to the docs of the `combine` function 205 | combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event, 206 | allocate_on_comm_stream=previous_event is not None) 207 | 208 | # For event management, please refer to the docs of the `EventOverlap` class 209 | return combined_x, event 210 | 211 | 212 | def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], 213 | handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \ 214 | Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]: 215 | global _buffer 216 | 217 | # The backward process of MoE combine is actually a dispatch 218 | # For more advanced usages, please refer to the docs of the `dispatch` function 219 | grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True, 220 | previous_event=previous_event, 221 | allocate_on_comm_stream=previous_event is not None) 222 | 223 | # For event management, please refer to the docs of the `EventOverlap` class 224 | return grad_x, event 225 | ``` 226 | 227 | Moreover, inside the dispatch function, we may not know how many tokens to receive for the current rank. 
So an implicit CPU wait for the GPU's received-count signal is involved, as the following figure shows. 228 | 229 | ![normal](figures/normal.png) 230 | 231 | ### Example use in inference decoding 232 | 233 | The low-latency kernels can be used in the inference decoding phase, as the example code below shows. 234 | 235 | ```python 236 | import torch 237 | import torch.distributed as dist 238 | from typing import Tuple, Optional 239 | 240 | from deep_ep import Buffer 241 | 242 | # Communication buffer (will allocate at runtime) 243 | # NOTES: there is no SM control API for the low-latency kernels 244 | _buffer: Optional[Buffer] = None 245 | 246 | 247 | # You may call this function at the framework initialization 248 | def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer: 249 | # NOTES: the low-latency mode will consume much more space than the normal mode 250 | # So we recommend keeping `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) below 256 251 | global _buffer 252 | num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts) 253 | 254 | # Allocate a buffer if not existed or not enough buffer size 255 | if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes: 256 | # NOTES: for the best performance, the QP number **must** be equal to the number of local experts 257 | assert num_experts % group.size() == 0 258 | _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size()) 259 | return _buffer 260 | 261 | 262 | def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int): 263 | global _buffer 264 | 265 | # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay) 266 | recv_hidden_states, recv_expert_count, handle, event, hook = \ 267 | _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, 268 | async_finish=False, return_recv_hook=True) 269 | 270 | # NOTES: the actual tensor will not be received until you call `hook()`, 271 | # which is useful for double-batch overlapping, but **without any SM occupation** 272 | # If you don't want to overlap, please set `return_recv_hook=False` 273 | # Later, you can use our GEMM library to do the computation with this specific format 274 | return recv_hidden_states, recv_expert_count, handle, event, hook 275 | 276 | 277 | def low_latency_combine(hidden_states: torch.Tensor, 278 | topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple): 279 | global _buffer 280 | 281 | # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay) 282 | combined_hidden_states, event_overlap, hook = \ 283 | _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle, 284 | async_finish=False, return_recv_hook=True) 285 | 286 | # NOTES: the same behavior as described for the dispatch kernel 287 | return combined_hidden_states, event_overlap, hook 288 | ``` 289 | 290 | For two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic happens in the background without taking any GPU SMs away from the computation part.
Note that the overlapped stages can be adjusted, i.e., the four stages of attention/dispatch/MoE/combine may not have exactly the same execution time. You may adjust the stage settings according to your workload. 291 | 292 | ![low-latency](figures/low-latency.png) 293 | 294 | ## Roadmap 295 | 296 | - [x] AR support 297 | - [x] Refactor low-latency mode AR code 298 | - [x] A100 support (intranode only) 299 | - [x] Support BF16 for the low-latency dispatch kernel 300 | - [x] Support NVLink protocol for intranode low-latency kernels 301 | - [ ] TMA copy instead of LD/ST 302 | - [x] Intranode kernels 303 | - [ ] Internode kernels 304 | - [ ] Low-latency kernels 305 | - [ ] SM-free kernels and refactors 306 | - [ ] Fully remove undefined-behavior PTX instructions 307 | 308 | ## Notices 309 | 310 | #### Easier potential overall design 311 | 312 | The current DeepEP implementation uses queues for communication buffers, which saves memory but introduces complexity and potential deadlocks. If you are implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39. 313 | 314 | #### Undefined-behavior PTX usage 315 | 316 | - For extreme performance, we discovered and use an undefined-behavior PTX pattern: the read-only load `ld.global.nc.L1::no_allocate.L2::256B` is used to **read volatile data**. The PTX modifier `.nc` indicates that a non-coherent cache is used. Correctness has nevertheless been tested to hold with `.L1::no_allocate` on Hopper architectures, and performance is much better. Our guess at the reason: the non-coherent cache is unified with L1, and the L1 modifier is not just a hint but a strong option, so correctness is guaranteed because no dirty data remains in L1. 317 | - Initially, because NVCC could not automatically unroll volatile-read PTX, we tried using `__ldg` (i.e., `ld.nc`). Even compared to manually unrolled volatile reads, it was significantly faster (likely due to additional compiler optimizations). However, the results could be incorrect or dirty. After consulting the PTX documentation, we discovered that the L1 and non-coherent caches are unified on Hopper architectures. We speculated that `.L1::no_allocate` might resolve the issue, which led to this discovery. 318 | - If you find the kernels not working on some other platform, you may set `DISABLE_AGGRESSIVE_PTX_INSTRS=1` when running `setup.py` to disable this, or file an issue. 319 | 320 | #### Auto-tuning on your cluster 321 | 322 | For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized for DeepSeek's internal cluster. 323 | 324 | ## License 325 | 326 | This code repository is released under [the MIT License](LICENSE), except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to the [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).
327 | 328 | ## Community Forks 329 | 330 | - [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-QP solution and dual-port NIC support in IBRC transport 331 | 332 | ## Citation 333 | 334 | If you use this codebase or otherwise find our work valuable, please cite: 335 | 336 | ```bibtex 337 | @misc{deepep2025, 338 | title={DeepEP: an efficient expert-parallel communication library}, 339 | author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao}, 340 | year={2025}, 341 | publisher = {GitHub}, 342 | howpublished = {\url{https://github.com/deepseek-ai/DeepEP}}, 343 | } 344 | ``` 345 | -------------------------------------------------------------------------------- /csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # NOTES: this CMake is only for debugging; for setup, please use Torch extension 2 | cmake_minimum_required(VERSION 3.10) 3 | project(deep_ep LANGUAGES CUDA CXX) 4 | set(CMAKE_VERBOSE_MAKEFILE ON) 5 | 6 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") 8 | set(CUDA_SEPARABLE_COMPILATION ON) 9 | list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") 10 | list(APPEND CUDA_NVCC_FLAGS "-O3") 11 | list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") 12 | 13 | set(USE_SYSTEM_NVTX on) 14 | set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile") 15 | set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") 16 | 17 | find_package(CUDAToolkit REQUIRED) 18 | find_package(pybind11 REQUIRED) 19 | find_package(Torch REQUIRED) 20 | find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem) 21 | 22 | add_library(nvshmem ALIAS nvshmem::nvshmem) 23 | add_library(nvshmem_host ALIAS nvshmem::nvshmem_host) 24 | add_library(nvshmem_device ALIAS nvshmem::nvshmem_device) 25 | 26 | set(CMAKE_CXX_STANDARD 17) 27 | set(CMAKE_CUDA_STANDARD 17) 28 | 29 | include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR}) 30 | link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR}) 31 | 32 | add_subdirectory(kernels) 33 | 34 | # Link CPP and CUDA together 35 | pybind11_add_module(deep_ep_cpp deep_ep.cpp) 36 | target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python) 37 | -------------------------------------------------------------------------------- /csrc/config.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "kernels/api.cuh" 4 | #include "kernels/exception.cuh" 5 | 6 | namespace deep_ep { 7 | 8 | template <typename dtype_t> 9 | dtype_t ceil_div(dtype_t a, dtype_t b) { 10 | return (a + b - 1) / b; 11 | } 12 | 13 | template <typename dtype_t> 14 | dtype_t align(dtype_t a, dtype_t b) { 15 | return ceil_div<dtype_t>(a, b) * b; 16 | } 17 | 18 | struct Config { 19 | int num_sms; 20 | int num_max_nvl_chunked_send_tokens; 21 | int num_max_nvl_chunked_recv_tokens; 22 | int num_max_rdma_chunked_send_tokens; 23 | int num_max_rdma_chunked_recv_tokens; 24 | 25 | Config(int num_sms, 26 | int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, 27 | int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) : 28 | num_sms(num_sms), 29 | 
num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), 30 | num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), 31 | num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), 32 | num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { 33 | EP_HOST_ASSERT(num_sms >= 0); 34 | EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); 35 | EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); 36 | EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); 37 | 38 | // Ceil up RDMA buffer size 39 | this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); 40 | EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); 41 | // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push 42 | EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); 43 | } 44 | 45 | size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { 46 | // Below are some assumptions 47 | // TODO: add assertions 48 | constexpr int kNumMaxTopK = 128; 49 | constexpr int kNumMaxScales = 128; 50 | EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); 51 | EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); 52 | const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); 53 | const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); 54 | const int num_channels = num_sms / 2; 55 | 56 | size_t num_bytes = 0; 57 | num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); 58 | num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; 59 | #ifndef DISABLE_NVSHMEM 60 | num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); 61 | #endif 62 | num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t); 63 | num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); 64 | num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); 65 | num_bytes = ((num_bytes + 127) / 128) * 128; 66 | return num_bytes; 67 | } 68 | 69 | size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { 70 | #ifndef DISABLE_NVSHMEM 71 | // Legacy mode 72 | if (num_ranks <= NUM_MAX_NVL_PEERS) 73 | return 0; 74 | 75 | // Below are some assumptions 76 | // TODO: add assertions 77 | constexpr int kNumMaxTopK = 128; 78 | constexpr int kNumMaxScales = 128; 79 | EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); 80 | EP_HOST_ASSERT(num_sms % 2 == 0); 81 | const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; 82 | const int num_channels = num_sms / 2; 83 | 84 | size_t num_bytes = 0; 85 | num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); 86 | num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; 87 | num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; 88 | num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2; 89 | num_bytes += num_channels * num_rdma_ranks * 
num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; 90 | num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; 91 | num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; 92 | num_bytes = ((num_bytes + 127) / 128) * 128; 93 | return num_bytes; 94 | #else 95 | EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); 96 | #endif 97 | } 98 | }; 99 | 100 | struct LowLatencyBuffer { 101 | int num_clean_int = 0; 102 | 103 | void* dispatch_rdma_send_buffer = nullptr; 104 | void* dispatch_rdma_recv_data_buffer = nullptr; 105 | int* dispatch_rdma_recv_count_buffer = nullptr; 106 | 107 | void* combine_rdma_send_buffer = nullptr; 108 | void* combine_rdma_recv_data_buffer = nullptr; 109 | int* combine_rdma_recv_flag_buffer = nullptr; 110 | 111 | void* combine_rdma_send_buffer_data_start = nullptr; 112 | size_t num_bytes_per_combine_msg = 0; 113 | 114 | std::pair<int*, int> clean_meta() { 115 | EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); 116 | return {dispatch_rdma_recv_count_buffer, num_clean_int}; 117 | } 118 | }; 119 | 120 | struct LowLatencyLayout { 121 | size_t total_bytes = 0; 122 | LowLatencyBuffer buffers[2]; 123 | 124 | template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*> 125 | out_ptr_t advance(const in_ptr_t& ptr, size_t count) { 126 | return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count); 127 | } 128 | 129 | LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { 130 | const int num_scales = hidden / 128; 131 | 132 | // Dispatch and combine layout: 133 | // - 2 symmetric odd/even send buffer 134 | // - 2 symmetric odd/even receive buffers 135 | // - 2 symmetric odd/even signaling buffers 136 | 137 | // Message sizes 138 | // NOTES: you should add a control `int4` for combine messages if you want to do data transformation 139 | EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); 140 | size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); 141 | size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16); 142 | 143 | // Send buffer 144 | size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; 145 | size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; 146 | size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); 147 | EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); 148 | total_bytes += send_buffer_bytes * 2; 149 | 150 | // Symmetric receive buffers 151 | // TODO: optimize memory usages 152 | size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; 153 | size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; 154 | size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); 155 | EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); 156 | total_bytes += recv_buffer_bytes * 2; 157 | 158 | // Symmetric signaling buffers 159 | size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); 160 | size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; 161 | size_t signaling_buffer_bytes = 
std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); 162 | total_bytes += signaling_buffer_bytes * 2; 163 | 164 | // Assign pointers 165 | // NOTES: we still leave some space for distinguishing dispatch/combine buffer, 166 | // so you may see some parameters are duplicated 167 | for (int i = 0; i < 2; ++ i) { 168 | buffers[i] = { 169 | static_cast<int>(signaling_buffer_bytes / sizeof(int)), 170 | advance(rdma_buffer, send_buffer_bytes * i), 171 | advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), 172 | advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), 173 | advance(rdma_buffer, send_buffer_bytes * i), 174 | advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), 175 | advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), 176 | advance(rdma_buffer, send_buffer_bytes * i), 177 | num_bytes_per_combine_msg 178 | }; 179 | } 180 | } 181 | }; 182 | 183 | size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { 184 | auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; 185 | return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; 186 | } 187 | 188 | } // namespace deep_ep 189 | -------------------------------------------------------------------------------- /csrc/deep_ep.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Forcibly disable NDEBUG 4 | #ifdef NDEBUG 5 | #undef NDEBUG 6 | #endif 7 | 8 | #include <pybind11/pybind11.h> 9 | #include <pybind11/pytypes.h> 10 | #include <torch/types.h> 11 | #include <tuple> 12 | #include <vector> 13 | 14 | #include "config.hpp" 15 | #include "event.hpp" 16 | #include "kernels/configs.cuh" 17 | #include "kernels/exception.cuh" 18 | 19 | #ifndef TORCH_EXTENSION_NAME 20 | #define TORCH_EXTENSION_NAME deep_ep_cpp 21 | #endif 22 | 23 | namespace deep_ep { 24 | 25 | struct Buffer { 26 | EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); 27 | 28 | private: 29 | // Low-latency mode buffer 30 | int low_latency_buffer_idx = 0; 31 | bool low_latency_mode = false; 32 | 33 | // NVLink Buffer 34 | int64_t num_nvl_bytes; 35 | void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; 36 | void** buffer_ptrs_gpu = nullptr; 37 | 38 | // NVSHMEM Buffer 39 | int64_t num_rdma_bytes; 40 | void* rdma_buffer_ptr = nullptr; 41 | 42 | // Device info and communication 43 | int device_id; 44 | int num_device_sms; 45 | int rank, rdma_rank, nvl_rank; 46 | int num_ranks, num_rdma_ranks, num_nvl_ranks; 47 | cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; 48 | 49 | // Stream for communication 50 | at::cuda::CUDAStream comm_stream; 51 | 52 | // After IPC/NVSHMEM synchronization, this flag will be true 53 | bool available = false; 54 | 55 | // Whether explicit `destroy()` is required. 
56 | bool explicitly_destroy; 57 | // After `destroy()` be called, this flag will be true 58 | bool destroyed = false; 59 | 60 | // Barrier signals 61 | int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; 62 | int** barrier_signal_ptrs_gpu = nullptr; 63 | 64 | // Workspace 65 | void* workspace = nullptr; 66 | 67 | // Host-side MoE info 68 | volatile int* moe_recv_counter = nullptr; 69 | int* moe_recv_counter_mapped = nullptr; 70 | 71 | // Host-side expert-level MoE info 72 | volatile int* moe_recv_expert_counter = nullptr; 73 | int* moe_recv_expert_counter_mapped = nullptr; 74 | 75 | // Host-side RDMA-level MoE info 76 | volatile int* moe_recv_rdma_counter = nullptr; 77 | int* moe_recv_rdma_counter_mapped = nullptr; 78 | 79 | public: 80 | Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy); 81 | 82 | ~Buffer() noexcept(false); 83 | 84 | bool is_available() const; 85 | 86 | bool is_internode_available() const; 87 | 88 | int get_num_rdma_ranks() const; 89 | 90 | int get_rdma_rank() const; 91 | 92 | int get_root_rdma_rank(bool global) const; 93 | 94 | int get_local_device_id() const; 95 | 96 | pybind11::bytearray get_local_ipc_handle() const; 97 | 98 | pybind11::bytearray get_local_nvshmem_unique_id() const; 99 | 100 | torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; 101 | 102 | torch::Stream get_comm_stream() const; 103 | 104 | void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt); 105 | 106 | void destroy(); 107 | 108 | std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>> 109 | get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event, 110 | bool async, bool allocate_on_comm_stream); 111 | 112 | std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>> 113 | intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales, 114 | const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights, 115 | const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert, 116 | int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix, 117 | int expert_alignment, int num_worst_tokens, const Config& config, 118 | std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); 119 | 120 | std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> 121 | intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights, 122 | const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1, 123 | const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, 124 | const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); 125 | 126 | std::tuple<torch::Tensor, 
std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>> 127 | internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales, 128 | const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights, 129 | const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank, 130 | const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert, 131 | int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, 132 | const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum, 133 | const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum, 134 | int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); 135 | 136 | std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> 137 | internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights, 138 | const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1, 139 | const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, 140 | const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, 141 | const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, 142 | const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); 143 | 144 | void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); 145 | 146 | std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> 147 | low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, 148 | const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats, 149 | int num_max_dispatch_tokens_per_rank, int num_experts, 150 | bool use_fp8, bool round_scale, bool use_ue8m0, 151 | bool async, bool return_recv_hook); 152 | 153 | std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> 154 | low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, 155 | const torch::Tensor& src_info, const torch::Tensor& layout_range, 156 | int num_max_dispatch_tokens_per_rank, int num_experts, 157 | bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, 158 | const std::optional<torch::Tensor>& out = std::nullopt); 159 | 160 | torch::Tensor 161 | get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; 162 | }; 163 | 164 | } // namespace deep_ep 165 | -------------------------------------------------------------------------------- /csrc/event.hpp: -------------------------------------------------------------------------------- 1 | #include <ATen/cuda/CUDAContext.h> 2 | #include <memory> 3 | 4 | #include 
"kernels/exception.cuh" 5 | 6 | namespace deep_ep { 7 | 8 | struct EventHandle { 9 | std::shared_ptr<torch::Event> event; 10 | 11 | EventHandle() { 12 | event = std::make_shared<torch::Event>(torch::kCUDA); 13 | event->record(at::cuda::getCurrentCUDAStream()); 14 | } 15 | 16 | explicit EventHandle(const at::cuda::CUDAStream& stream) { 17 | event = std::make_shared<torch::Event>(torch::kCUDA); 18 | event->record(stream); 19 | } 20 | 21 | EventHandle(const EventHandle& other) = default; 22 | 23 | void current_stream_wait() const { 24 | at::cuda::getCurrentCUDAStream().unwrap().wait(*event); 25 | } 26 | }; 27 | 28 | torch::Event create_event(const at::cuda::CUDAStream &s) { 29 | auto event = torch::Event(torch::kCUDA); 30 | event.record(s); 31 | return event; 32 | } 33 | 34 | void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { 35 | EP_HOST_ASSERT(s_0.id() != s_1.id()); 36 | s_0.unwrap().wait(create_event(s_1)); 37 | } 38 | 39 | void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { 40 | s.unwrap().wait(*event.event); 41 | } 42 | 43 | } // namespace deep_ep 44 | -------------------------------------------------------------------------------- /csrc/kernels/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(add_deep_ep_library target_name source_file) 2 | add_library(${target_name} STATIC ${source_file}) 3 | set_target_properties(${target_name} PROPERTIES 4 | POSITION_INDEPENDENT_CODE ON 5 | CXX_STANDARD_REQUIRED ON 6 | CUDA_STANDARD_REQUIRED ON 7 | CXX_STANDARD 17 8 | CUDA_STANDARD 17 9 | CUDA_SEPARABLE_COMPILATION ON 10 | ) 11 | target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5) 12 | endfunction() 13 | 14 | add_deep_ep_library(runtime_cuda runtime.cu) 15 | add_deep_ep_library(layout_cuda layout.cu) 16 | add_deep_ep_library(intranode_cuda intranode.cu) 17 | add_deep_ep_library(internode_cuda internode.cu) 18 | add_deep_ep_library(internode_ll_cuda internode_ll.cu) 19 | 20 | # Later, we should link all libraries in `EP_CUDA_LIBRARIES` 21 | set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) 22 | -------------------------------------------------------------------------------- /csrc/kernels/api.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <vector> 4 | 5 | namespace deep_ep { 6 | 7 | // Intranode runtime 8 | namespace intranode { 9 | 10 | void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); 11 | 12 | } // namespace intranode 13 | 14 | // Internode runtime 15 | namespace internode { 16 | 17 | std::vector<uint8_t> get_unique_id(); 18 | 19 | int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode); 20 | 21 | void *alloc(size_t size, size_t alignment); 22 | 23 | void free(void *ptr); 24 | 25 | void barrier(); 26 | 27 | void finalize(); 28 | 29 | } // namespace internode 30 | 31 | // Layout kernels 32 | namespace layout { 33 | 34 | void get_dispatch_layout(const int64_t* topk_idx, 35 | int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, 36 | int* num_tokens_per_expert, bool* is_token_in_rank, 37 | int num_tokens, int num_topk, int num_ranks, int num_experts, 38 | cudaStream_t stream); 39 | 40 | } // namespace layout 41 | 42 | // Intranode kernels 43 | namespace intranode { 44 | 45 | void notify_dispatch(const int* num_tokens_per_rank, int* 
moe_recv_counter_mapped, int num_ranks, 46 | const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, 47 | int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, 48 | int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, 49 | void** buffer_ptrs, int** barrier_signal_ptrs, int rank, 50 | cudaStream_t stream, int num_sms); 51 | 52 | void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, 53 | void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, 54 | cudaStream_t stream); 55 | 56 | void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, 57 | int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, 58 | const bool* is_token_in_rank, const int* channel_prefix_matrix, 59 | int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, 60 | int scale_token_stride, int scale_hidden_stride, 61 | void** buffer_ptrs, int rank, int num_ranks, 62 | cudaStream_t stream, int num_sms, 63 | int num_max_send_tokens, int num_recv_buffer_tokens); 64 | 65 | void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, 66 | int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); 67 | 68 | void combine(cudaDataType_t type, 69 | void* recv_x, float* recv_topk_weights, 70 | const void* x, const float* topk_weights, 71 | const void* bias_0, const void* bias_1, 72 | const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, 73 | int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, 74 | void** buffer_ptrs, int rank, int num_ranks, 75 | cudaStream_t stream, int num_sms, 76 | int num_max_send_tokens, int num_recv_buffer_tokens); 77 | 78 | } // namespace intranode 79 | 80 | // Internode kernels 81 | namespace internode { 82 | 83 | int get_source_meta_bytes(); 84 | 85 | void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, 86 | const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, 87 | const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, 88 | const bool* is_token_in_rank, int num_tokens, int num_channels, 89 | int hidden_int4, int num_scales, int num_topk, int expert_alignment, 90 | int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, 91 | int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, 92 | void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, 93 | void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, 94 | int** barrier_signal_ptrs, int rank, 95 | cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, 96 | bool low_latency_mode); 97 | 98 | void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, 99 | const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, 100 | int* send_rdma_head, int* send_nvl_head, 101 | int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, 102 | const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, 103 | const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, 104 | const bool* is_token_in_rank, 105 | int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, 106 | int 
scale_token_stride, int scale_hidden_stride, 107 | void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, 108 | void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, 109 | int rank, int num_ranks, bool is_cached_dispatch, 110 | cudaStream_t stream, int num_channels, bool low_latency_mode); 111 | 112 | void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, 113 | int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, 114 | const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, 115 | void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, 116 | void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, 117 | int** barrier_signal_ptrs, int rank, cudaStream_t stream, 118 | int64_t num_rdma_bytes, int64_t num_nvl_bytes, 119 | bool is_cached_dispatch, bool low_latency_mode); 120 | 121 | void combine(cudaDataType_t type, 122 | void* combined_x, float* combined_topk_weights, 123 | const bool* is_combined_token_in_rank, 124 | const void* x, const float* topk_weights, 125 | const void* bias_0, const void* bias_1, 126 | const int* combined_rdma_head, const int* combined_nvl_head, 127 | const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, 128 | int num_tokens, int num_combined_tokens, int hidden, int num_topk, 129 | void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, 130 | void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, 131 | int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); 132 | 133 | } // namespace internode 134 | 135 | // Internode low-latency kernels 136 | namespace internode_ll { 137 | 138 | void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, 139 | int* clean_1, int num_clean_int_1, 140 | cudaStream_t stream); 141 | 142 | void dispatch(void* packed_recv_x, void* packed_recv_x_scales, 143 | int* packed_recv_src_info, int64_t* packed_recv_layout_range, 144 | int* packed_recv_count, 145 | int* cumulative_local_expert_recv_stats, 146 | void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, 147 | const void* x, const int64_t* topk_idx, 148 | int* next_clean, int num_next_clean_int, 149 | int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, 150 | int num_topk, int num_experts, int rank, int num_ranks, 151 | bool use_fp8, bool round_scale, bool use_ue8m0, 152 | void* workspace, int num_device_sms, 153 | cudaStream_t stream, int phases); 154 | 155 | void combine(void* combined_x, 156 | void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, 157 | const void* x, const int64_t* topk_idx, const float* topk_weights, 158 | const int* src_info, const int64_t* layout_range, 159 | int* next_clean, int num_next_clean_int, 160 | int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, 161 | int num_topk, int num_experts, int rank, int num_ranks, 162 | bool use_logfmt, 163 | void* workspace, int num_device_sms, 164 | cudaStream_t stream, int phases, bool zero_copy); 165 | 166 | } // namespace internode_ll 167 | 168 | } // namespace deep_ep 169 | -------------------------------------------------------------------------------- /csrc/kernels/buffer.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 
"configs.cuh" 4 | #include "exception.cuh" 5 | 6 | namespace deep_ep { 7 | 8 | template <typename dtype_t> 9 | struct Buffer { 10 | private: 11 | uint8_t* ptr; 12 | 13 | public: 14 | int total_bytes; 15 | 16 | __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} 17 | 18 | __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) { 19 | total_bytes = num_elems * sizeof(dtype_t); 20 | ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t); 21 | gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; 22 | } 23 | 24 | __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) { 25 | gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; 26 | return *this; 27 | } 28 | 29 | __device__ __forceinline__ dtype_t* buffer() { 30 | return reinterpret_cast<dtype_t*>(ptr); 31 | } 32 | 33 | __device__ __forceinline__ dtype_t& operator[](int idx) { 34 | return buffer()[idx]; 35 | } 36 | }; 37 | 38 | template <typename dtype_t, int kNumRanks = 1> 39 | struct AsymBuffer { 40 | private: 41 | uint8_t* ptrs[kNumRanks]; 42 | int num_bytes; 43 | 44 | public: 45 | int total_bytes; 46 | 47 | __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, 48 | int sm_id = 0, int num_sms = 1, int offset = 0) { 49 | EP_STATIC_ASSERT(kNumRanks == 1, ""); 50 | num_bytes = num_elems * sizeof(dtype_t); 51 | 52 | int per_channel_bytes = num_bytes * num_ranks; 53 | total_bytes = per_channel_bytes * num_sms; 54 | ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; 55 | gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; 56 | } 57 | 58 | __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, 59 | int sm_id = 0, int num_sms = 1, int offset = 0) { 60 | EP_STATIC_ASSERT(kNumRanks > 1, ""); 61 | num_bytes = num_elems * sizeof(dtype_t); 62 | 63 | int per_channel_bytes = num_bytes * num_ranks; 64 | total_bytes = per_channel_bytes * num_sms; 65 | for (int i = 0; i < kNumRanks; ++ i) { 66 | ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; 67 | gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; 68 | } 69 | } 70 | 71 | __device__ __forceinline__ void advance(int shift) { 72 | #pragma unroll 73 | for (int i = 0; i < kNumRanks; ++ i) 74 | ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); 75 | } 76 | 77 | __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) { 78 | gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; 79 | return *this; 80 | } 81 | 82 | template<int kNumAlsoRanks> 83 | __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { 84 | for (int i = 0; i < kNumAlsoRanks; ++ i) 85 | gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; 86 | return *this; 87 | } 88 | 89 | __device__ __forceinline__ dtype_t* buffer(int idx = 0) { 90 | EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); 91 | return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx); 92 | } 93 | 94 | __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { 95 | EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); 96 | return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx); 97 | } 98 | }; 99 | 100 | template <typename dtype_t, bool kDecoupled = true> 101 | struct SymBuffer { 102 | private: 103 | // NOTES: for non-decoupled case, `recv_ptr` is not used 104 | uint8_t* send_ptr; 105 | uint8_t* 
recv_ptr; 106 | int num_bytes; 107 | 108 | public: 109 | int total_bytes; 110 | 111 | __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, 112 | int sm_id = 0, int num_sms = 1) { 113 | num_bytes = num_elems * sizeof(dtype_t); 114 | 115 | int per_channel_bytes = num_bytes * num_ranks; 116 | total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1); 117 | send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id; 118 | recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); 119 | gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; 120 | } 121 | 122 | __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { 123 | EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); 124 | return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx); 125 | } 126 | 127 | __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { 128 | EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); 129 | return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx); 130 | } 131 | 132 | __device__ __forceinline__ dtype_t* buffer(int idx = 0) { 133 | EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); 134 | return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx); 135 | } 136 | }; 137 | 138 | } // namespace deep_ep 139 | -------------------------------------------------------------------------------- /csrc/kernels/configs.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define NUM_MAX_NVL_PEERS 8 4 | #define NUM_MAX_RDMA_PEERS 20 5 | #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) 6 | #define NUM_MAX_LOCAL_EXPERTS 1024 7 | #define NUM_BUFFER_ALIGNMENT_BYTES 128 8 | 9 | #define FINISHED_SUM_TAG 1024 10 | #define NUM_WAIT_NANOSECONDS 500 11 | 12 | #ifndef ENABLE_FAST_DEBUG 13 | #define NUM_CPU_TIMEOUT_SECS 100 14 | #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s 15 | #else 16 | #define NUM_CPU_TIMEOUT_SECS 10 17 | #define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s 18 | #endif 19 | 20 | #define LOW_LATENCY_SEND_PHASE 1 21 | #define LOW_LATENCY_RECV_PHASE 2 22 | 23 | // Make CLion CUDA indexing work 24 | #ifdef __CLION_IDE__ 25 | #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) 26 | #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) 27 | #endif 28 | 29 | // Remove Torch restrictions 30 | #ifdef __CUDA_NO_HALF_CONVERSIONS__ 31 | #undef __CUDA_NO_HALF_CONVERSIONS__ 32 | #endif 33 | #ifdef __CUDA_NO_HALF_OPERATORS__ 34 | #undef __CUDA_NO_HALF_OPERATORS__ 35 | #endif 36 | #ifdef __CUDA_NO_HALF2_OPERATORS__ 37 | #undef __CUDA_NO_HALF2_OPERATORS__ 38 | #endif 39 | #ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__ 40 | #undef __CUDA_NO_BFLOAT16_CONVERSIONS__ 41 | #endif 42 | #ifdef __CUDA_NO_BFLOAT162_OPERATORS__ 43 | #undef __CUDA_NO_BFLOAT162_OPERATORS__ 44 | #endif 45 | 46 | #include <cstdint> 47 | #include <cuda_bf16.h> 48 | #include <cuda_runtime.h> 49 | 50 | #ifndef DISABLE_SM90_FEATURES 51 | #include <cuda_fp8.h> 52 | #else 53 | // Ampere does not support FP8 features 54 | #define __NV_E4M3 0 55 | #define __NV_E5M2 1 56 | typedef int __nv_fp8_interpretation_t; 57 | typedef int __nv_fp8x4_e4m3; 58 | typedef uint8_t __nv_fp8_storage_t; 59 | #endif 60 | 61 | #ifndef DISABLE_NVSHMEM 62 | #include <nvshmem.h> 63 | #include <nvshmemx.h> 64 | #include <infiniband/mlx5dv.h> 65 | #include 
<non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh> 66 | #include <device_host_transport/nvshmem_common_ibgda.h> 67 | #endif 68 | -------------------------------------------------------------------------------- /csrc/kernels/exception.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <string> 4 | #include <exception> 5 | 6 | #include "configs.cuh" 7 | 8 | #ifndef EP_STATIC_ASSERT 9 | #define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason) 10 | #endif 11 | 12 | class EPException: public std::exception { 13 | private: 14 | std::string message = {}; 15 | 16 | public: 17 | explicit EPException(const char *name, const char* file, const int line, const std::string& error) { 18 | message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; 19 | } 20 | 21 | const char *what() const noexcept override { return message.c_str(); } 22 | }; 23 | 24 | #ifndef CUDA_CHECK 25 | #define CUDA_CHECK(cmd) \ 26 | do { \ 27 | cudaError_t e = (cmd); \ 28 | if (e != cudaSuccess) { \ 29 | throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ 30 | } \ 31 | } while (0) 32 | #endif 33 | 34 | #ifndef EP_HOST_ASSERT 35 | #define EP_HOST_ASSERT(cond) \ 36 | do { \ 37 | if (not (cond)) { \ 38 | throw EPException("Assertion", __FILE__, __LINE__, #cond); \ 39 | } \ 40 | } while (0) 41 | #endif 42 | 43 | #ifndef EP_DEVICE_ASSERT 44 | #define EP_DEVICE_ASSERT(cond) \ 45 | do { \ 46 | if (not (cond)) { \ 47 | printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ 48 | asm("trap;"); \ 49 | } \ 50 | } while (0) 51 | #endif 52 | -------------------------------------------------------------------------------- /csrc/kernels/ibgda_device.cuh: -------------------------------------------------------------------------------- 1 | // Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem) 2 | // Copyright (c) NVIDIA Corporation. 3 | // Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019). 
4 | // See full license at: https://docs.nvidia.com/nvshmem/api/sla.html 5 | // 6 | // Modified from original source: 7 | // - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh 8 | #pragma once 9 | 10 | #include "configs.cuh" 11 | #include "exception.cuh" 12 | #include "utils.cuh" 13 | 14 | namespace deep_ep { 15 | 16 | EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth"); 17 | 18 | __device__ static __forceinline__ 19 | uint64_t HtoBE64(uint64_t x) { 20 | uint64_t ret; 21 | asm("{\n\t" 22 | ".reg .b32 ign;\n\t" 23 | ".reg .b32 lo;\n\t" 24 | ".reg .b32 hi;\n\t" 25 | ".reg .b32 new_lo;\n\t" 26 | ".reg .b32 new_hi;\n\t" 27 | "mov.b64 {lo,hi}, %1;\n\t" 28 | "prmt.b32 new_hi, lo, ign, 0x0123;\n\t" 29 | "prmt.b32 new_lo, hi, ign, 0x0123;\n\t" 30 | "mov.b64 %0, {new_lo,new_hi};\n\t" 31 | "}" : "=l"(ret) : "l"(x)); 32 | return ret; 33 | } 34 | 35 | __device__ static __forceinline__ 36 | uint32_t HtoBE32(uint32_t x) { 37 | uint32_t ret; 38 | asm("{\n\t" 39 | ".reg .b32 ign;\n\t" 40 | "prmt.b32 %0, %1, ign, 0x0123;\n\t" 41 | "}" : "=r"(ret) : "r"(x)); 42 | return ret; 43 | } 44 | 45 | __device__ static __forceinline__ 46 | uint16_t HtoBE16(uint16_t x) { 47 | // TODO: simplify PTX using 16-bit instructions 48 | auto a = static_cast<uint32_t>(x); 49 | uint32_t d; 50 | asm volatile( 51 | "{\n\t" 52 | ".reg .b32 mask;\n\t" 53 | ".reg .b32 ign;\n\t" 54 | "mov.b32 mask, 0x4401;\n\t" 55 | "mov.b32 ign, 0x0;\n\t" 56 | "prmt.b32 %0, %1, ign, mask;\n\t" 57 | "}" 58 | : "=r"(d) 59 | : "r"(a)); 60 | return static_cast<uint16_t>(d); 61 | } 62 | 63 | typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t; 64 | 65 | typedef struct { 66 | uint32_t add_data; 67 | uint32_t field_boundary; 68 | uint64_t reserved; 69 | } __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t; 70 | 71 | __device__ static __forceinline__ 72 | nvshmemi_ibgda_device_state_t* ibgda_get_state() { 73 | return &nvshmemi_ibgda_device_state_d; 74 | } 75 | 76 | __device__ static __forceinline__ 77 | nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) { 78 | auto state = ibgda_get_state(); 79 | const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe; 80 | return &state->globalmem.rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)]; 81 | } 82 | 83 | __device__ static __forceinline__ 84 | void ibgda_lock_acquire(int *lock) { 85 | while (atomicCAS(lock, 0, 1) == 1); 86 | 87 | // Prevent reordering before the lock is acquired 88 | memory_fence_cta(); 89 | } 90 | 91 | __device__ static __forceinline__ 92 | void ibgda_lock_release(int *lock) { 93 | memory_fence_cta(); 94 | 95 | // Prevent reordering before lock is released 96 | st_na_relaxed(lock, 0); 97 | } 98 | 99 | __device__ static __forceinline__ 100 | void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) { 101 | // `DBREC` contains the index of the next empty `WQEBB` 102 | __be32 dbrec_val; 103 | __be32 *dbrec_ptr = qp->tx_wq.dbrec; 104 | 105 | // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))` 106 | asm("{\n\t" 107 | ".reg .b32 dbrec_head_16b;\n\t" 108 | ".reg .b32 ign;\n\t" 109 | "and.b32 dbrec_head_16b, %1, 0xffff;\n\t" 110 | "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t" 111 | "}" 112 | : "=r"(dbrec_val) 113 | : "r"(dbrec_head)); 114 | st_na_release(dbrec_ptr, dbrec_val); 115 | } 116 | 117 | __device__ static __forceinline__ 118 | void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) { 119 | 
auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf); 120 | ibgda_ctrl_seg_t ctrl_seg = { 121 | .opmod_idx_opcode = HtoBE32(prod_idx << 8), 122 | .qpn_ds = HtoBE32(qp->qpn << 8) 123 | }; 124 | 125 | EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), ""); 126 | st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg))); 127 | } 128 | 129 | __device__ static __forceinline__ 130 | void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) { 131 | nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars; 132 | uint64_t old_prod_idx; 133 | 134 | // Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence 135 | ibgda_lock_acquire(&mvars->post_send_lock); 136 | 137 | old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx); 138 | if (new_prod_idx > old_prod_idx) { 139 | ibgda_update_dbr(qp, new_prod_idx); 140 | ibgda_ring_db(qp, new_prod_idx); 141 | } 142 | ibgda_lock_release(&mvars->post_send_lock); 143 | } 144 | 145 | template <bool kAlwaysDoPostSend> 146 | __device__ static __forceinline__ 147 | void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx, 148 | uint32_t num_wqes, int message_idx = 0) { 149 | auto state = ibgda_get_state(); 150 | nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars; 151 | uint64_t new_wqe_idx = base_wqe_idx + num_wqes; 152 | 153 | // WQE writes must be finished first 154 | __threadfence(); 155 | 156 | unsigned long long int *ready_idx = 157 | (unsigned long long int *)(state->use_async_postsend ? qp->tx_wq.prod_idx 158 | : &mvars->tx_wq.ready_head); 159 | 160 | // Wait for prior WQE slots to be filled first 161 | while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx); 162 | 163 | // Always post, not in batch 164 | if (!state->use_async_postsend) { 165 | constexpr int kNumRequestInBatch = 4; 166 | if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0) 167 | ibgda_post_send(qp, new_wqe_idx); 168 | } 169 | } 170 | 171 | __device__ static __forceinline__ void 172 | ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr, 173 | __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) { 174 | ibgda_ctrl_seg_t ctrl_seg; 175 | struct mlx5_wqe_raddr_seg raddr_seg; 176 | struct mlx5_wqe_inl_data_seg inl_seg; 177 | 178 | auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]); 179 | auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); 180 | auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); 181 | auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr)); 182 | 183 | raddr_seg.raddr = HtoBE64(raddr); 184 | raddr_seg.rkey = rkey; 185 | raddr_seg.reserved = 0; 186 | 187 | inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG); 188 | 189 | // `imm == std::numeric_limits<uint32_t>::max()` means no imm writes 190 | ctrl_seg = {0}; 191 | ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); 192 | ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; 193 | ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? 
MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE)); 194 | if (imm != std::numeric_limits<uint32_t>::max()) 195 | ctrl_seg.imm = HtoBE32(imm); 196 | 197 | EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); 198 | EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); 199 | EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4"); 200 | st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg)); 201 | st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg)); 202 | st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg)); 203 | st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val)); 204 | } 205 | 206 | __device__ static __forceinline__ 207 | uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey, 208 | uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) { 209 | auto state = ibgda_get_state(); 210 | auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base); 211 | auto log2_cumem_granularity = state->log2_cumem_granularity; 212 | 213 | // Local key 214 | uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx; 215 | auto device_key = state->constmem.lkeys[idx]; 216 | auto lchunk_size = device_key.next_addr - laddr; 217 | *lkey = device_key.key; 218 | 219 | // Remote key 220 | uint64_t roffset = raddr - heap_start; 221 | 222 | idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized 223 | + dst_pe * state->num_devices_initialized + dev_idx; 224 | if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { 225 | device_key = state->constmem.rkeys[idx]; 226 | } else { 227 | device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; 228 | } 229 | *out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; 230 | *out_rkey = device_key.key; 231 | 232 | // Return the minimum of local and remote chunk sizes 233 | auto rchunk_size = device_key.next_addr - roffset; 234 | return min(lchunk_size, rchunk_size); 235 | } 236 | 237 | __device__ static __forceinline__ void 238 | ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) { 239 | auto state = ibgda_get_state(); 240 | auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base); 241 | 242 | uint64_t roffset = addr - heap_start; 243 | uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized) 244 | + dst_pe * state->num_devices_initialized + dev_idx; 245 | nvshmemi_ibgda_device_key_t device_key; 246 | if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) 247 | device_key = state->constmem.rkeys[idx]; 248 | else 249 | device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; 250 | *out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; 251 | *out_rkey = device_key.key; 252 | } 253 | 254 | __device__ static __forceinline__ uint64_t 255 | ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) { 256 | auto mvars = &qp->mvars; 257 | return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes)); 258 | } 259 | 260 | __device__ static __forceinline__ void* 261 | 
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) { 262 | uint16_t cnt = qp->tx_wq.nwqes; 263 | uint16_t idx = wqe_idx & (cnt - 1); 264 | return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT)); 265 | } 266 | 267 | __device__ static __forceinline__ void 268 | nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) { 269 | // Get rkey 270 | // NOTES: the `p` operation will not cross multiple remote chunks 271 | __be32 rkey; 272 | uint64_t raddr; 273 | auto qp = ibgda_get_rc(dst_pe, qp_id); 274 | ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey, qp->dev_idx); 275 | 276 | // Write WQEs 277 | uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); 278 | void *wqe_ptrs; 279 | wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx); 280 | ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm); 281 | 282 | // Submit requests 283 | ibgda_submit_requests<true>(qp, base_wqe_idx, 1); 284 | } 285 | 286 | __device__ static __forceinline__ void 287 | ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey, 288 | uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx, 289 | void** out_wqes) { 290 | ibgda_ctrl_seg_t ctrl_seg; 291 | struct mlx5_wqe_raddr_seg raddr_seg; 292 | struct mlx5_wqe_data_seg data_seg; 293 | 294 | auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]); 295 | void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); 296 | struct mlx5_wqe_raddr_seg *raddr_seg_ptr; 297 | struct mlx5_wqe_data_seg *data_seg_ptr; 298 | 299 | raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr)); 300 | data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); 301 | 302 | raddr_seg.raddr = HtoBE64(raddr); 303 | raddr_seg.rkey = rkey; 304 | raddr_seg.reserved = 0; 305 | 306 | data_seg.byte_count = HtoBE32(bytes); 307 | data_seg.lkey = lkey; 308 | data_seg.addr = HtoBE64(laddr); 309 | 310 | ctrl_seg = {0}; 311 | ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); 312 | ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; 313 | ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE); 314 | 315 | EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); 316 | EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); 317 | EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16"); 318 | st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg)); 319 | st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg)); 320 | st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg)); 321 | } 322 | 323 | __device__ static __forceinline__ void 324 | ibgda_write_empty_recv_wqe(void *out_wqe) { 325 | auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe); 326 | struct mlx5_wqe_data_seg data_seg; 327 | 328 | // Make the first segment in the WQE invalid, then the entire list will be invalid 329 | data_seg.byte_count = 0; 330 | data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY); 331 | data_seg.addr = 0; 332 | 333 | EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length"); 334 | 
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg)); 335 | } 336 | 337 | template <bool kAlwaysDoPostSend = false> 338 | __device__ static __forceinline__ void 339 | nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) { 340 | // Get lkey and rkey, store them into lanes 341 | uint32_t num_wqes = 0; 342 | __be32 my_lkey = 0; 343 | uint64_t my_laddr = 0; 344 | __be32 my_rkey = 0; 345 | uint64_t my_raddr = 0; 346 | uint64_t my_chunk_size = 0; 347 | 348 | auto qp = ibgda_get_rc(dst_pe, qp_id); 349 | 350 | // Decide how many messages (theoretically 3 for maximum) 351 | auto remaining_bytes = bytes; 352 | while (remaining_bytes > 0) { 353 | if (lane_id == num_wqes) { 354 | my_chunk_size = min(remaining_bytes, 355 | ibgda_get_lkey_and_rkey(my_laddr = req_lptr, 356 | &my_lkey, 357 | req_rptr, 358 | dst_pe, 359 | &my_raddr, 360 | &my_rkey, 361 | qp->dev_idx)); 362 | } 363 | 364 | // Move one more message 365 | auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes)); 366 | remaining_bytes -= chunk_size; 367 | req_lptr += chunk_size; 368 | req_rptr += chunk_size; 369 | ++ num_wqes; 370 | } 371 | EP_DEVICE_ASSERT(num_wqes <= 32); 372 | 373 | // Process WQE 374 | uint64_t base_wqe_idx = 0; 375 | if (lane_id == 0) 376 | base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); 377 | base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); 378 | if (lane_id < num_wqes) { 379 | auto wqe_idx = base_wqe_idx + lane_id; 380 | auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx); 381 | ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size, 382 | wqe_idx, &wqe_ptr); 383 | } 384 | __syncwarp(); 385 | 386 | // Submit 387 | if (lane_id == 0) 388 | ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx); 389 | __syncwarp(); 390 | } 391 | 392 | __device__ static __forceinline__ void ibgda_write_amo_add_wqe( 393 | nvshmemi_ibgda_device_qp_t *qp, const int &value, 394 | uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey, 395 | uint16_t wqe_idx, void** out_wqes) { 396 | ibgda_ctrl_seg_t ctrl_seg = {0}; 397 | struct mlx5_wqe_raddr_seg raddr_seg; 398 | struct mlx5_wqe_atomic_seg atomic_seg_1; 399 | struct mlx5_wqe_data_seg data_seg; 400 | 401 | auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]); 402 | auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); 403 | auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); 404 | auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr)); 405 | 406 | raddr_seg.raddr = HtoBE64(raddr); 407 | raddr_seg.rkey = rkey; 408 | raddr_seg.reserved = 0; 409 | 410 | // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD` 411 | ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000); 412 | auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1); 413 | atomic_32_masked_fa_seg->add_data = HtoBE32(value); 414 | atomic_32_masked_fa_seg->field_boundary = 0; 415 | 416 | ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4); 417 | ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; 418 | 419 | data_seg.byte_count = HtoBE32(sizeof(int)); 420 | data_seg.lkey = lkey; 421 | data_seg.addr = HtoBE64(laddr); 
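// NOTES: the four 16-byte segments prepared above (ctrl + raddr + atomic + data) fill exactly one
// 64-byte WQE basic block; the low byte of `qpn_ds` (the DS count) is 4 for the same reason, and the
// asserts below require each segment to be `int4`-sized so it can be published with one vectorized store.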
422 | 423 | EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization"); 424 | EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization"); 425 | EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization"); 426 | EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization"); 427 | st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<int4*>(&ctrl_seg)); 428 | st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<int4*>(&raddr_seg)); 429 | st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<int4*>(&atomic_seg_1)); 430 | st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg)); 431 | } 432 | 433 | __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) { 434 | if (is_local_copy) { 435 | atomicAdd(static_cast<unsigned long long*>(rptr), value); 436 | } else { 437 | nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); 438 | 439 | __be32 rkey; 440 | uint64_t raddr; 441 | ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx); 442 | 443 | uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); 444 | void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); 445 | 446 | ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf), 447 | qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs); 448 | 449 | ibgda_submit_requests<true>(qp, my_wqe_idx, 1); 450 | } 451 | } 452 | 453 | __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) { 454 | // Local rank, no need for mapping 455 | if (rank == dst_rank) 456 | return ptr; 457 | auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank); 458 | 459 | // RDMA connected 460 | if (peer_base == 0) 461 | return 0; 462 | 463 | // NVLink P2P is enabled 464 | return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base)); 465 | } 466 | 467 | // This is a simplified version of NVSHMEM's `ibgda_poll_cq`. 468 | // Note that this implementation does not guarantee thread safety, 469 | // so we must ensure that no other threads are concurrently using the same QP. 470 | __device__ static __forceinline__ void 471 | ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) { 472 | const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe); 473 | const uint32_t ncqes = cq->ncqes; 474 | memory_fence_cta(); 475 | 476 | // NOTES: this while loop is part of do-while below. 477 | // `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`. 478 | // To be able to compare with the index, we need to use `wqe_counter + 1`. 479 | // Because `wqe_counter` is `uint16_t`, it may be overflow. Still, we know for 480 | // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1 is less than 481 | // idx, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1` 482 | // That's why we use `- 2` here to make this case overflow. 
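// Illustrative example with hypothetical numbers: for `ncqes = 64` and `idx = 100`, a completed
// `wqe_counter` of 99 gives `100 - 99 - 2 = -1`, which wraps to 0xffff as `uint16_t` and is not
// less than `ncqes`, so polling stops; a lagging `wqe_counter` of 90 gives 8 < 64, so we keep spinning.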
483 | uint16_t wqe_counter; 484 | do { 485 | wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)); 486 | } while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes)); 487 | *cq->cons_idx = idx; 488 | 489 | // Prevent reordering of this function and later instructions 490 | memory_fence_cta(); 491 | } 492 | 493 | // Wait until wqe `idx - 1` is completed. 494 | __device__ static __forceinline__ void 495 | nvshmemi_ibgda_quiet(int dst_pe, int qp_id) { 496 | auto qp = ibgda_get_rc(dst_pe, qp_id); 497 | auto state = ibgda_get_state(); 498 | uint64_t prod_idx = state->use_async_postsend ? ld_na_relaxed(qp->tx_wq.prod_idx) : ld_na_relaxed(&qp->mvars.tx_wq.ready_head); 499 | ibgda_poll_cq(qp->tx_wq.cq, prod_idx); 500 | } 501 | 502 | } // namespace deep_ep 503 | -------------------------------------------------------------------------------- /csrc/kernels/launch.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "configs.cuh" 4 | #include "exception.cuh" 5 | 6 | #ifndef SETUP_LAUNCH_CONFIG 7 | #ifndef DISABLE_SM90_FEATURES 8 | #define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \ 9 | cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \ 10 | cudaLaunchAttribute attr[2]; \ 11 | attr[0].id = cudaLaunchAttributeCooperative; \ 12 | attr[0].val.cooperative = 1; \ 13 | attr[1].id = cudaLaunchAttributeClusterDimension; \ 14 | attr[1].val.clusterDim.x = (num_sms % 2 == 0 ? 2 : 1); \ 15 | attr[1].val.clusterDim.y = 1; \ 16 | attr[1].val.clusterDim.z = 1; \ 17 | cfg.attrs = attr; \ 18 | cfg.numAttrs = 2 19 | #else 20 | #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ 21 | int __num_sms = (sms); \ 22 | int __num_threads = (threads); \ 23 | auto __stream = (stream) 24 | #endif 25 | #endif 26 | 27 | #ifndef LAUNCH_KERNEL 28 | #ifndef DISABLE_SM90_FEATURES 29 | #define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__)) 30 | #else 31 | #define LAUNCH_KERNEL(config, kernel, ...) 
\ 32 | do { \ 33 | kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__); \ 34 | cudaError_t e = cudaGetLastError(); \ 35 | if (e != cudaSuccess) { \ 36 | EPException cuda_exception("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ 37 | fprintf(stderr, "%s\n", cuda_exception.what()); \ 38 | throw cuda_exception; \ 39 | } \ 40 | } while (0) 41 | #endif 42 | #endif 43 | 44 | #ifndef SET_SHARED_MEMORY_FOR_TMA 45 | #ifndef DISABLE_SM90_FEATURES 46 | #define SET_SHARED_MEMORY_FOR_TMA(kernel) \ 47 | EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \ 48 | cfg.dynamicSmemBytes = smem_size; 49 | #else 50 | #define SET_SHARED_MEMORY_FOR_TMA(kernel) void() 51 | #endif 52 | #endif 53 | 54 | #define SWITCH_RANKS(case_macro) \ 55 | switch (num_ranks) { \ 56 | case 2: case_macro(2); \ 57 | case 4: case_macro(4); \ 58 | case 8: case_macro(8); \ 59 | default: EP_HOST_ASSERT(false and "Unsupported ranks"); \ 60 | } while (false) 61 | 62 | #define SWITCH_RDMA_RANKS(case_macro) \ 63 | switch (num_ranks / NUM_MAX_NVL_PEERS) { \ 64 | case 2: case_macro(2); \ 65 | case 4: case_macro(4); \ 66 | case 8: case_macro(8); \ 67 | case 16: case_macro(16); \ 68 | default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ 69 | } while (false) 70 | 71 | #define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \ 72 | switch (num_ranks) { \ 73 | case 2: case_macro(dtype, 2); \ 74 | case 4: case_macro(dtype, 4); \ 75 | case 8: case_macro(dtype, 8); \ 76 | default: EP_HOST_ASSERT(false and "Unsupported ranks"); \ 77 | } while (false) 78 | 79 | #define SWITCH_TYPES(case_macro) \ 80 | switch (type) { \ 81 | case CUDA_R_16BF: case_macro(nv_bfloat16); \ 82 | default: EP_HOST_ASSERT(false and "Unsupported type"); \ 83 | } while (false) 84 | 85 | #define SWITCH_HIDDEN(case_macro) \ 86 | switch (hidden) { \ 87 | case 2048: case_macro(2048); \ 88 | case 2560: case_macro(2560); \ 89 | case 4096: case_macro(4096); \ 90 | case 5120: case_macro(5120); \ 91 | case 7168: case_macro(7168); \ 92 | case 8192: case_macro(8192); \ 93 | default: EP_HOST_ASSERT(false and "Unsupported hidden"); \ 94 | } while (false) 95 | -------------------------------------------------------------------------------- /csrc/kernels/layout.cu: -------------------------------------------------------------------------------- 1 | #include "configs.cuh" 2 | #include "exception.cuh" 3 | #include "launch.cuh" 4 | 5 | namespace deep_ep { 6 | 7 | namespace layout { 8 | 9 | template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM> 10 | __global__ void get_dispatch_layout(const int64_t* topk_idx, 11 | int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, 12 | int* num_tokens_per_expert, bool* is_token_in_rank, 13 | int num_tokens, int num_topk, int num_ranks, int num_experts) { 14 | auto sm_id = static_cast<int>(blockIdx.x); 15 | auto thread_id = static_cast<int>(threadIdx.x); 16 | 17 | // Count expert statistics 18 | __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; 19 | int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); 20 | if (expert_begin_idx < expert_end_idx) { 21 | // Per-thread count 22 | #pragma unroll 23 | for (int i = 0; i < kNumExpertsPerSM; ++ i) 24 | num_tokens_per_expert_per_thread[thread_id][i] = 0; 25 | #pragma unroll 26 | for (int i = thread_id; i < num_tokens; i += kNumThreads) { 27 | auto shifted_topk_idx = topk_idx + i * num_topk; 28 | #pragma unroll 29 | for (int j 
= 0, expert_idx; j < num_topk; ++ j) { 30 | expert_idx = static_cast<int>(shifted_topk_idx[j]); 31 | if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) 32 | ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; 33 | } 34 | } 35 | __syncthreads(); 36 | 37 | // Sum up 38 | EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); 39 | if (expert_begin_idx + thread_id < expert_end_idx) { 40 | int sum = 0; 41 | #pragma unroll 42 | for (int i = 0; i < kNumThreads; ++ i) 43 | sum += num_tokens_per_expert_per_thread[i][thread_id]; 44 | num_tokens_per_expert[expert_begin_idx + thread_id] = sum; 45 | } 46 | return; 47 | } 48 | 49 | if (num_tokens_per_rdma_rank != nullptr) 50 | EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); 51 | 52 | // Count rank statistics 53 | constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; 54 | __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; 55 | __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; 56 | auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; 57 | int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); 58 | int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; 59 | if (rank_begin_idx < rank_end_idx) { 60 | const auto num_expert_per_rank = num_experts / num_ranks; 61 | auto expert_begin = rank_begin_idx * num_expert_per_rank; 62 | auto expert_end = rank_end_idx * num_expert_per_rank; 63 | 64 | // Per-thread count 65 | #pragma unroll 66 | for (int i = 0; i < kNumRanksPerSM; ++ i) 67 | num_tokens_per_rank_per_thread[thread_id][i] = 0; 68 | #pragma unroll 69 | for (int i = 0; i < kNumRDMARanksPerSM; ++ i) 70 | num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; 71 | #pragma unroll 72 | for (int i = thread_id; i < num_tokens; i += kNumThreads) { 73 | auto shifted_topk_idx = topk_idx + i * num_topk; 74 | int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; 75 | #pragma unroll 76 | for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { 77 | expert_idx = static_cast<int>(shifted_topk_idx[j]); 78 | if (expert_begin <= expert_idx and expert_idx < expert_end) { 79 | // Count single rank 80 | rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; 81 | is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++; 82 | } 83 | } 84 | 85 | auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; 86 | #pragma unroll 87 | for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) { 88 | shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); 89 | num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); 90 | } 91 | 92 | #pragma unroll 93 | for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j) 94 | num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); 95 | } 96 | __syncthreads(); 97 | 98 | // Sum up 99 | EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); 100 | if (rank_begin_idx + thread_id < rank_end_idx) { 101 | int sum = 0; 102 | #pragma unroll 103 | for (int i = 0; i < kNumThreads; ++ i) 104 | sum += num_tokens_per_rank_per_thread[i][thread_id]; 105 | num_tokens_per_rank[rank_begin_idx + thread_id] = sum; 106 | } 107 | 108 | if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < 
rdma_rank_end_idx) { 109 | int sum = 0; 110 | #pragma unroll 111 | for (int i = 0; i < kNumThreads; ++ i) 112 | sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; 113 | num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; 114 | } 115 | } 116 | } 117 | 118 | void get_dispatch_layout(const int64_t* topk_idx, 119 | int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, 120 | int* num_tokens_per_expert, bool* is_token_in_rank, 121 | int num_tokens, int num_topk, int num_ranks, int num_experts, 122 | cudaStream_t stream) { 123 | constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8; 124 | int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; 125 | EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM"); 126 | 127 | SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); 128 | LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>), 129 | topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, 130 | num_tokens, num_topk, num_ranks, num_experts); 131 | } 132 | 133 | } // namespace layout 134 | 135 | } // namespace deep_ep 136 | -------------------------------------------------------------------------------- /csrc/kernels/runtime.cu: -------------------------------------------------------------------------------- 1 | #include <vector> 2 | #include <cstring> 3 | 4 | #include "configs.cuh" 5 | #include "exception.cuh" 6 | #include "launch.cuh" 7 | #include "utils.cuh" 8 | 9 | #ifndef DISABLE_NVSHMEM 10 | #include "nvshmem.h" 11 | #include "ibgda_device.cuh" 12 | #endif 13 | 14 | namespace deep_ep { 15 | 16 | namespace intranode { 17 | 18 | template<int kNumRanks> 19 | __global__ void barrier(int** barrier_signal_ptrs, int rank) { 20 | barrier_block<kNumRanks>(barrier_signal_ptrs, rank); 21 | } 22 | 23 | void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) { 24 | #define BARRIER_LAUNCH_CASE(ranks) \ 25 | LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \ 26 | break 27 | 28 | SETUP_LAUNCH_CONFIG(1, 32, stream); 29 | SWITCH_RANKS(BARRIER_LAUNCH_CASE); 30 | #undef BARRIER_LAUNCH_CASE 31 | } 32 | 33 | } // namespace intranode 34 | 35 | namespace internode { 36 | 37 | #ifndef DISABLE_NVSHMEM 38 | nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID; 39 | nvshmem_team_config_t cpu_rdma_team_config; 40 | 41 | std::vector<uint8_t> get_unique_id() { 42 | nvshmemx_uniqueid_t unique_id; 43 | nvshmemx_get_uniqueid(&unique_id); 44 | std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t)); 45 | std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t)); 46 | return result; 47 | } 48 | 49 | int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) { 50 | nvshmemx_uniqueid_t root_unique_id; 51 | nvshmemx_init_attr_t attr; 52 | std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t)); 53 | nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr); 54 | nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); 55 | 56 | // Create sub-RDMA teams 57 | // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used 58 | if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) { 59 | EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID); 60 | EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); 61 | 
EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS, 62 | num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0); 63 | EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID); 64 | } 65 | 66 | nvshmem_barrier_all(); 67 | return nvshmem_my_pe(); 68 | } 69 | 70 | void* alloc(size_t size, size_t alignment) { 71 | return nvshmem_align(alignment, size); 72 | } 73 | 74 | void free(void* ptr) { 75 | nvshmem_free(ptr); 76 | } 77 | 78 | void barrier() { 79 | nvshmem_barrier_all(); 80 | } 81 | 82 | void finalize() { 83 | if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) { 84 | nvshmem_team_destroy(cpu_rdma_team); 85 | cpu_rdma_team = NVSHMEM_TEAM_INVALID; 86 | } 87 | nvshmem_finalize(); 88 | } 89 | #endif 90 | 91 | } // namespace internode 92 | 93 | } // namespace deep_ep 94 | -------------------------------------------------------------------------------- /csrc/kernels/utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "exception.cuh" 4 | 5 | #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ 6 | { \ 7 | constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \ 8 | typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \ 9 | auto __src = (SRC); \ 10 | auto __dst = (DST); \ 11 | for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \ 12 | _Pragma("unroll") \ 13 | for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ 14 | unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \ 15 | _Pragma("unroll") \ 16 | for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ 17 | ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \ 18 | } \ 19 | for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \ 20 | ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \ 21 | } 22 | 23 | namespace deep_ep { 24 | 25 | template <int kBytes> 26 | struct VecInt {}; 27 | template<> struct VecInt<1> { using vec_t = int8_t; }; 28 | template<> struct VecInt<2> { using vec_t = int16_t; }; 29 | template<> struct VecInt<4> { using vec_t = int; }; 30 | template<> struct VecInt<8> { using vec_t = int64_t; }; 31 | template<> struct VecInt<16> { using vec_t = int4; }; 32 | 33 | template <typename FuncT> 34 | struct PatternVisitor { 35 | FuncT func; 36 | 37 | __device__ __host__ 38 | explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {} 39 | 40 | __device__ __host__ 41 | auto operator [](const uint32_t& i) { 42 | return func(i); 43 | } 44 | }; 45 | 46 | __device__ __forceinline__ void trap() { 47 | asm("trap;"); 48 | } 49 | 50 | __device__ __forceinline__ void memory_fence() { 51 | asm volatile("fence.acq_rel.sys;":: : "memory"); 52 | } 53 | 54 | __device__ __forceinline__ void memory_fence_gpu() { 55 | asm volatile("fence.acq_rel.gpu;":: : "memory"); 56 | } 57 | 58 | __device__ __forceinline__ void memory_fence_cta() { 59 | asm volatile("fence.acq_rel.cta;":: : "memory"); 60 | } 61 | 62 | __device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) { 63 | asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory"); 64 | } 65 | 66 | __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) { 67 | asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory"); 68 | } 69 | 70 | __device__ __forceinline__ void st_release_cta(const int *ptr, int val) { 71 | asm 
volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory"); 72 | } 73 | 74 | __device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) { 75 | int ret; 76 | asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 77 | return ret; 78 | } 79 | 80 | __device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) { 81 | uint64_t ret; 82 | asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); 83 | return ret; 84 | } 85 | 86 | __device__ __forceinline__ int ld_acquire_global(const int *ptr) { 87 | int ret; 88 | asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 89 | return ret; 90 | } 91 | 92 | __device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) { 93 | int ret; 94 | asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); 95 | return ret; 96 | } 97 | 98 | __device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) { 99 | int ret; 100 | asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); 101 | return ret; 102 | } 103 | 104 | __device__ __forceinline__ int ld_acquire_cta(const int *ptr) { 105 | int ret; 106 | asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 107 | return ret; 108 | } 109 | 110 | __device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) { 111 | uint16_t ret; 112 | asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr)); 113 | return static_cast<uint8_t>(ret); 114 | } 115 | 116 | __device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) { 117 | uint16_t ret; 118 | asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr)); 119 | return ret; 120 | } 121 | 122 | __device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) { 123 | uint32_t ret; 124 | asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 125 | return ret; 126 | } 127 | 128 | __device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) { 129 | uint64_t ret; 130 | asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); 131 | return ret; 132 | } 133 | 134 | __device__ __forceinline__ int ld_volatile_global(const int *ptr) { 135 | int ret; 136 | asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 137 | return ret; 138 | } 139 | 140 | __device__ __forceinline__ float ld_volatile_global(const float *ptr) { 141 | float ret; 142 | asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); 143 | return ret; 144 | } 145 | 146 | __device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) { 147 | int64_t ret; 148 | asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr)); 149 | return ret; 150 | } 151 | 152 | __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) { 153 | int64_t ret; 154 | asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); 155 | return ret; 156 | } 157 | 158 | #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS 159 | #define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B" 160 | #else 161 | #define LD_NC_FUNC "ld.volatile.global.L2::256B" 162 | #endif 163 | 164 | // `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS 165 | template <typename dtype_t> 166 | __device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) { 167 | auto 
ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr)); 168 | return *reinterpret_cast<dtype_t*>(&ret); 169 | } 170 | 171 | template <> 172 | __device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) { 173 | uint16_t ret; 174 | // NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit) 175 | asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr)); 176 | return static_cast<uint8_t>(ret); 177 | } 178 | 179 | template <> 180 | __device__ __forceinline__ int ld_nc_global(const int *ptr) { 181 | int ret; 182 | asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); 183 | return ret; 184 | } 185 | 186 | template <> 187 | __device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) { 188 | int64_t ret; 189 | asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr)); 190 | return ret; 191 | } 192 | 193 | template <> 194 | __device__ __forceinline__ float ld_nc_global(const float *ptr) { 195 | float ret; 196 | asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); 197 | return ret; 198 | } 199 | 200 | template <> 201 | __device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) { 202 | int2 ret; 203 | asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr)); 204 | return ret; 205 | } 206 | 207 | template <> 208 | __device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) { 209 | int4 ret; 210 | asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];" 211 | : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); 212 | return ret; 213 | } 214 | 215 | __device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) { 216 | asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val))); 217 | } 218 | 219 | __device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) { 220 | asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val)); 221 | } 222 | 223 | __device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) { 224 | asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val)); 225 | } 226 | 227 | __device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) { 228 | asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val)); 229 | } 230 | 231 | __device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) { 232 | asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};" 233 | : : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); 234 | } 235 | 236 | __device__ __forceinline__ void st_na_release(const int *ptr, int val) { 237 | asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val)); 238 | } 239 | 240 | __device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) { 241 | asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val)); 242 | } 243 | 244 | __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) { 245 | asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val)); 246 | } 247 | 248 | // `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS 249 | #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS 250 | #define ST_NA_FUNC "st.global.L1::no_allocate" 251 | #else 252 | #define ST_NA_FUNC "st.global" 253 
| #endif 254 | 255 | template <typename dtype_t> 256 | __device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) { 257 | st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr), 258 | *reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value)); 259 | } 260 | 261 | template <> 262 | __device__ __forceinline__ void st_na_global(const int *ptr, const int& value) { 263 | asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value)); 264 | } 265 | 266 | template <> 267 | __device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) { 268 | asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value)); 269 | } 270 | 271 | template <> 272 | __device__ __forceinline__ void st_na_global(const float *ptr, const float& value) { 273 | asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value)); 274 | } 275 | 276 | template <> 277 | __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) { 278 | asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};" 279 | ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); 280 | } 281 | 282 | __device__ __forceinline__ float log2f_approx(const float &x) { 283 | float ret; 284 | asm volatile("lg2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); 285 | return ret; 286 | } 287 | 288 | __device__ __forceinline__ float exp2f_approx(const float &x) { 289 | float ret; 290 | asm volatile("ex2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); 291 | return ret; 292 | } 293 | 294 | // TMA PTX instructions 295 | #ifndef DISABLE_SM90_FEATURES 296 | 297 | __device__ __forceinline__ uint32_t elect_one_sync(int lane_id) { 298 | uint32_t pred = 0; 299 | asm volatile( 300 | "{\n" 301 | ".reg .b32 %%rx;\n" 302 | ".reg .pred %%px;\n" 303 | " elect.sync %%rx|%%px, %2;\n" 304 | "@%%px mov.s32 %1, 1;\n" 305 | " mov.s32 %0, %%rx;\n" 306 | "}\n" 307 | : "+r"(lane_id), "+r"(pred) 308 | : "r"(0xffffffff)); 309 | return pred; 310 | } 311 | 312 | __device__ __forceinline__ void fence_view_async_shared() { 313 | asm volatile("fence.proxy.async.shared::cta; \n" :: ); 314 | } 315 | 316 | __device__ __forceinline__ void fence_barrier_init() { 317 | asm volatile("fence.mbarrier_init.release.cluster; \n" :: ); 318 | } 319 | 320 | __device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) { 321 | auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr)); 322 | asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr)); 323 | } 324 | 325 | __device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) { 326 | auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr)); 327 | asm volatile("{\n\t" 328 | ".reg .pred P1; \n\t" 329 | "LAB_WAIT: \n\t" 330 | "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" 331 | "@P1 bra DONE; \n\t" 332 | "bra LAB_WAIT; \n\t" 333 | "DONE: \n\t" 334 | "}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680)); 335 | phase ^= 1; 336 | } 337 | 338 | __device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) { 339 | auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr)); 340 | asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr)); 341 | } 342 | 343 | __device__ __forceinline__ void tma_store_fence() { 344 | asm volatile ("fence.proxy.async.shared::cta;"); 345 | } 346 | 347 | constexpr uint64_t 
kEvictFirst = 0x12f0000000000000; 348 | constexpr uint64_t kEvictNormal = 0x1000000000000000; 349 | 350 | __device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes, 351 | bool evict_first = true) { 352 | auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr)); 353 | auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); 354 | const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal; 355 | asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" 356 | :: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory"); 357 | } 358 | 359 | __device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes, 360 | bool evict_first = true) { 361 | auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); 362 | const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal; 363 | asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" 364 | :: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory"); 365 | asm volatile("cp.async.bulk.commit_group;"); 366 | } 367 | 368 | template <int N = 0> 369 | __device__ __forceinline__ void tma_store_wait() { 370 | asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory"); 371 | } 372 | 373 | #endif 374 | 375 | template <typename dtype_t> 376 | __host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) { 377 | return (a + b - 1) / b; 378 | } 379 | 380 | template <typename dtype_t> 381 | __host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) { 382 | return ceil_div<dtype_t>(a, b) * b; 383 | } 384 | 385 | __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id, 386 | int& token_start_idx, int& token_end_idx) { 387 | int num_tokens_per_sm = ceil_div(num_tokens, num_sms); 388 | token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens); 389 | token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens); 390 | } 391 | 392 | template <typename dtype_a_t, typename dtype_b_t> 393 | __device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) { 394 | EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes"); 395 | dtype_b_t packed; 396 | auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed); 397 | unpacked_ptr[0] = x, unpacked_ptr[1] = y; 398 | return packed; 399 | } 400 | 401 | template <typename dtype_a_t, typename dtype_b_t> 402 | __device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) { 403 | EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes"); 404 | auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed); 405 | x = unpacked_ptr[0], y = unpacked_ptr[1]; 406 | } 407 | 408 | template <typename dtype_t> 409 | __device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) { 410 | EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, ""); 411 | auto send_int_values = reinterpret_cast<int*>(&ptr); 412 | int recv_int_values[sizeof(dtype_t) / sizeof(int)]; 413 | #pragma unroll 414 | for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i) 415 | recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx); 416 | return *reinterpret_cast<dtype_t*>(recv_int_values); 417 | } 418 | 419 | __forceinline__ __device__ int 
get_lane_id() { 420 | int lane_id; 421 | asm("mov.s32 %0, %laneid;" : "=r"(lane_id)); 422 | return lane_id; 423 | } 424 | 425 | constexpr float kFP8Margin = 1e-4; 426 | constexpr float kFinfoAmaxE4M3 = 448.0f; 427 | constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f; 428 | 429 | __forceinline__ __device__ float fast_pow2(int x) { 430 | // We can ensure `-126 <= x and x <= 127` 431 | uint32_t bits_x = (x + 127) << 23; 432 | return *reinterpret_cast<float*>(&bits_x); 433 | } 434 | 435 | __forceinline__ __device__ int fast_log2_ceil(float x) { 436 | auto bits_x = *reinterpret_cast<uint32_t*>(&x); 437 | auto exp_x = (bits_x >> 23) & 0xff; 438 | auto man_bits = bits_x & ((1 << 23) - 1); 439 | return exp_x - 127 + (man_bits != 0); 440 | } 441 | 442 | __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { 443 | if (round_scale) { 444 | auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); 445 | scale = fast_pow2(-exp_scale_inv); 446 | scale_inv = fast_pow2(exp_scale_inv); 447 | } else { 448 | scale_inv = amax * kFinfoAmaxInvE4M3; 449 | scale = kFinfoAmaxE4M3 / amax; 450 | } 451 | } 452 | 453 | template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>> 454 | __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) { 455 | if constexpr (kIsUE8M0) { 456 | return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23); 457 | } else { 458 | return value; 459 | } 460 | } 461 | 462 | template <int kNumRanks, bool kSyncOnly = false> 463 | __forceinline__ __device__ void 464 | barrier_block(int** barrier_signal_ptrs, int rank) { 465 | auto thread_id = static_cast<int>(threadIdx.x); 466 | 467 | // For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope 468 | if constexpr (not kSyncOnly) { 469 | memory_fence(); 470 | __syncthreads(); 471 | } 472 | 473 | // Add self-ranks, sub other ranks 474 | if (thread_id < kNumRanks) { 475 | atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG); 476 | atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG); 477 | } 478 | EP_DEVICE_ASSERT(kNumRanks <= blockDim.x); 479 | 480 | // Check timeout 481 | auto start_time = clock64(); 482 | while (true) { 483 | auto value = thread_id < kNumRanks ? 
ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0; 484 | if (__all_sync(0xffffffff, value <= 0)) 485 | break; 486 | 487 | if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) { 488 | printf("DeepEP timeout check failed: rank = %d, thread = %d, value = %d)\n", rank, thread_id, value); 489 | trap(); 490 | } 491 | } 492 | __syncthreads(); 493 | } 494 | 495 | __forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) { 496 | int ret; 497 | asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(addr), "r"(x), "r"(y) : "memory"); 498 | return ret; 499 | } 500 | 501 | __forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) { 502 | int ret; 503 | asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(x) : "memory"); 504 | return ret; 505 | } 506 | 507 | __forceinline__ __device__ void acquire_lock(int* mutex) { 508 | // To make later memory operations valid, we must use `acquire` for memory semantics 509 | while (atomic_cas_cta_acquire(mutex, 0, 1) != 0); 510 | } 511 | 512 | __forceinline__ __device__ void release_lock(int* mutex) { 513 | // To make previous memory operations visible to other threads, we must use `release` for memory semantics 514 | atomic_exch_cta_release(mutex, 0); 515 | } 516 | 517 | // Operation functors 518 | template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; 519 | template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; 520 | template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; 521 | 522 | // Unified reduction function 523 | template <uint32_t kNumLanes, typename T, typename Op> 524 | __forceinline__ __device__ T warp_reduce(T value, Op op) { 525 | EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or 526 | kNumLanes == 4 or kNumLanes == 2 or kNumLanes == 1, 527 | "Invalid number of lanes"); 528 | 529 | if constexpr (kNumLanes >= 32) value = op(value, __shfl_xor_sync(0xffffffff, value, 16)); 530 | if constexpr (kNumLanes >= 16) value = op(value, __shfl_xor_sync(0xffffffff, value, 8)); 531 | if constexpr (kNumLanes >= 8) value = op(value, __shfl_xor_sync(0xffffffff, value, 4)); 532 | if constexpr (kNumLanes >= 4) value = op(value, __shfl_xor_sync(0xffffffff, value, 2)); 533 | if constexpr (kNumLanes >= 2) value = op(value, __shfl_xor_sync(0xffffffff, value, 1)); 534 | return value; 535 | } 536 | 537 | // Convenience aliases 538 | template < uint32_t kNumLanes = 32, typename T> 539 | __forceinline__ __device__ T warp_reduce_sum(T value) { 540 | return warp_reduce<kNumLanes, T>(value, ReduceSum<T>{}); 541 | } 542 | 543 | template <uint32_t kNumLanes = 32, typename T> 544 | __forceinline__ __device__ T warp_reduce_max(T value) { 545 | return warp_reduce<kNumLanes, T>(value, ReduceMax<T>{}); 546 | } 547 | 548 | template <uint32_t kNumLanes = 32, typename T> 549 | __forceinline__ __device__ T warp_reduce_min(T value) { 550 | return warp_reduce<kNumLanes, T>(value, ReduceMin<T>{}); 551 | } 552 | 553 | } // namespace deep_ep 554 | -------------------------------------------------------------------------------- /deep_ep/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .utils import EventOverlap 4 | from .buffer import Buffer 5 | 6 | # noinspection PyUnresolvedReferences 7 | from deep_ep_cpp 
import Config 8 | -------------------------------------------------------------------------------- /deep_ep/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import torch 4 | import torch.distributed as dist 5 | from typing import Any, Optional, Tuple 6 | 7 | # noinspection PyUnresolvedReferences 8 | from deep_ep_cpp import Config, EventHandle 9 | 10 | 11 | class EventOverlap: 12 | """ 13 | A wrapper class to manage CUDA events, also for better overlapping convenience. 14 | 15 | Attributes: 16 | event: the CUDA event captured. 17 | extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph. 18 | """ 19 | 20 | def __init__(self, event: Optional[EventHandle] = None, 21 | extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None: 22 | """ 23 | Initialize the class. 24 | 25 | Arguments: 26 | event: the CUDA event captured. 27 | extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph. 28 | """ 29 | self.event = event 30 | 31 | # NOTES: we use extra tensors to achieve stream recording, otherwise, 32 | # stream recording will be incompatible with CUDA graph. 33 | self.extra_tensors = extra_tensors 34 | 35 | def current_stream_wait(self) -> None: 36 | """ 37 | The current stream `torch.cuda.current_stream()` waits for the event to be finished. 38 | """ 39 | assert self.event is not None 40 | self.event.current_stream_wait() 41 | 42 | def __enter__(self) -> Any: 43 | """ 44 | Utility for overlapping and Python `with` syntax. 45 | 46 | You can overlap the kernels on the current stream with the following example: 47 | ```python 48 | event_overlap = event_after_all_to_all_kernels() 49 | with event_overlap: 50 | do_something_on_current_stream() 51 | # After exiting the `with` scope, the current stream will wait for the event to be finished. 52 | ``` 53 | """ 54 | return self 55 | 56 | def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: 57 | """ 58 | Utility for overlapping and Python `with` syntax. 59 | 60 | Please follow the example in the `__enter__` function. 61 | """ 62 | if self.event is not None: 63 | self.event.current_stream_wait() 64 | 65 | 66 | def check_nvlink_connections(group: dist.ProcessGroup): 67 | """ 68 | Check NVLink connection between every pair of GPUs. 69 | 70 | Arguments: 71 | group: the communication group.
72 | """ 73 | # Check NVLink connection 74 | # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2 75 | # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink 76 | if 'PCIE' in torch.cuda.get_device_name(): 77 | assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections' 78 | 79 | # noinspection PyUnresolvedReferences 80 | import pynvml 81 | pynvml.nvmlInit() 82 | 83 | # noinspection PyTypeChecker 84 | devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',') 85 | physical_device_idx = int(devices[torch.cuda.current_device()]) 86 | physical_device_indices = [0, ] * group.size() 87 | dist.all_gather_object(physical_device_indices, physical_device_idx, group) 88 | 89 | # Check whether they are all connected via NVLink 90 | # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438 91 | handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices] 92 | for i, handle in enumerate(handles): 93 | for j, peer_handle in enumerate(handles): 94 | if i >= j: 95 | continue 96 | status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) 97 | assert status == pynvml.NVML_P2P_STATUS_OK,\ 98 | f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink' 99 | 100 | # Close NVML 101 | pynvml.nvmlShutdown() 102 | -------------------------------------------------------------------------------- /figures/low-latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepEP/0eee87b8ca19e255b70b5631b8cc42ca1b88c8d7/figures/low-latency.png -------------------------------------------------------------------------------- /figures/normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepEP/0eee87b8ca19e255b70b5631b8cc42ca1b88c8d7/figures/normal.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Change current directory into project root 2 | original_dir=$(pwd) 3 | script_dir=$(dirname "$0") 4 | cd "$script_dir" 5 | 6 | # Remove old dist file, build, and install 7 | rm -rf dist 8 | python setup.py bdist_wheel 9 | pip install dist/*.whl 10 | 11 | # Open users' original directory 12 | cd "$original_dir" 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import setuptools 4 | import importlib 5 | import importlib.resources 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # Wheel specific: The wheels only include the soname of the host library (libnvshmem_host.so.X) 9 | def get_nvshmem_host_lib_name(): 10 | for path in importlib.resources.files('nvidia.nvshmem').iterdir(): 11 | for file in path.rglob('libnvshmem_host.so.*'): 12 | return file.name 13 | raise ModuleNotFoundError('libnvshmem_host.so not found') 14 | 15 | if __name__ == '__main__': 16 | disable_nvshmem = False 17 | nvshmem_dir = os.getenv('NVSHMEM_DIR', None) 18 | nvshmem_host_lib = 'libnvshmem_host.so' 19 | if nvshmem_dir is None: 20 | try: 21 | nvshmem_dir = 
importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0] 22 | nvshmem_host_lib = get_nvshmem_host_lib_name() 23 | import nvidia.nvshmem as nvshmem 24 | except (ModuleNotFoundError, AttributeError, IndexError): 25 | print('Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n') 26 | disable_nvshmem = True 27 | else: 28 | disable_nvshmem = False 29 | 30 | if not disable_nvshmem: 31 | assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}' 32 | 33 | cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', 34 | '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes'] 35 | nvcc_flags = ['-O3', '-Xcompiler', '-O3'] 36 | sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu'] 37 | include_dirs = ['csrc/'] 38 | library_dirs = [] 39 | nvcc_dlink = [] 40 | extra_link_args = [] 41 | 42 | # NVSHMEM flags 43 | if disable_nvshmem: 44 | cxx_flags.append('-DDISABLE_NVSHMEM') 45 | nvcc_flags.append('-DDISABLE_NVSHMEM') 46 | else: 47 | sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']) 48 | include_dirs.extend([f'{nvshmem_dir}/include']) 49 | library_dirs.extend([f'{nvshmem_dir}/lib']) 50 | nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device']) 51 | extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib']) 52 | 53 | if int(os.getenv('DISABLE_SM90_FEATURES', 0)): 54 | # Prefer A100 55 | os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0') 56 | 57 | # Disable some SM90 features: FP8, launch methods, and TMA 58 | cxx_flags.append('-DDISABLE_SM90_FEATURES') 59 | nvcc_flags.append('-DDISABLE_SM90_FEATURES') 60 | 61 | # Disable internode and low-latency kernels 62 | assert disable_nvshmem 63 | else: 64 | # Prefer H800 series 65 | os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0') 66 | 67 | # CUDA 12 flags 68 | nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10']) 69 | 70 | # Disable LD/ST tricks, as some CUDA version does not support `.L1::no_allocate` 71 | if os.environ['TORCH_CUDA_ARCH_LIST'].strip() != '9.0': 72 | assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1 73 | os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1' 74 | 75 | # Disable aggressive PTX instructions 76 | if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')): 77 | cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') 78 | nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') 79 | 80 | # Put them together 81 | extra_compile_args = { 82 | 'cxx': cxx_flags, 83 | 'nvcc': nvcc_flags, 84 | } 85 | if len(nvcc_dlink) > 0: 86 | extra_compile_args['nvcc_dlink'] = nvcc_dlink 87 | 88 | # Summary 89 | print(f'Build summary:') 90 | print(f' > Sources: {sources}') 91 | print(f' > Includes: {include_dirs}') 92 | print(f' > Libraries: {library_dirs}') 93 | print(f' > Compilation flags: {extra_compile_args}') 94 | print(f' > Link flags: {extra_link_args}') 95 | print(f' > Arch list: {os.environ["TORCH_CUDA_ARCH_LIST"]}') 96 | print(f' > NVSHMEM path: {nvshmem_dir}') 97 | print() 98 | 99 | # noinspection PyBroadException 100 | try: 101 | cmd = ['git', 'rev-parse', '--short', 'HEAD'] 102 | revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() 103 | except Exception as _: 104 | revision = '' 105 | 106 | setuptools.setup( 107 | name='deep_ep', 108 | 
version='1.1.0' + revision, 109 | packages=setuptools.find_packages( 110 | include=['deep_ep'] 111 | ), 112 | ext_modules=[ 113 | CUDAExtension( 114 | name='deep_ep_cpp', 115 | include_dirs=include_dirs, 116 | library_dirs=library_dirs, 117 | sources=sources, 118 | extra_compile_args=extra_compile_args, 119 | extra_link_args=extra_link_args 120 | ) 121 | ], 122 | cmdclass={ 123 | 'build_ext': BuildExtension 124 | } 125 | ) 126 | -------------------------------------------------------------------------------- /tests/test_internode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import torch 5 | import torch.distributed as dist 6 | 7 | # noinspection PyUnresolvedReferences 8 | import deep_ep 9 | from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back 10 | 11 | # Test compatibility with low latency functions 12 | import test_low_latency 13 | 14 | 15 | # noinspection PyShadowingNames 16 | def test_main(args: argparse.Namespace, num_sms: int, 17 | local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, 18 | buffer: deep_ep.Buffer, group: dist.ProcessGroup): 19 | # Settings 20 | num_tokens, hidden = args.num_tokens, args.hidden 21 | num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts 22 | 23 | assert num_experts % num_ranks == 0 and num_local_ranks == 8 24 | if local_rank == 0: 25 | print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) 26 | 27 | # Random data 28 | x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank 29 | x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') 30 | x_e4m3 = per_token_cast_to_fp8(x) 31 | x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) 32 | scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 33 | group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) 34 | group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices 35 | masked_scores = create_grouped_scores(scores, group_idx, num_nodes) 36 | topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] 37 | topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank 38 | topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') 39 | rank_idx = topk_idx // (num_experts // num_ranks) 40 | rank_idx.masked_fill_(topk_idx == -1, -1) 41 | inplace_unique(rank_idx, num_ranks) 42 | rdma_rank_idx = rank_idx // num_local_ranks 43 | rdma_rank_idx.masked_fill_(rank_idx == -1, -1) 44 | inplace_unique(rdma_rank_idx, num_nodes) 45 | 46 | # RDMA dispatch counts 47 | rdma_idx = topk_idx // (num_experts // num_nodes) 48 | rdma_idx.masked_fill_(topk_idx == -1, -1) 49 | inplace_unique(rdma_idx, num_nodes) 50 | num_rdma_token_sent = rdma_idx.ne(-1).sum().item() 51 | 52 | # Expert meta 53 | num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda') 54 | for i in range(num_experts): 55 | num_tokens_per_expert[i] = (topk_idx == i).sum() 56 | gbl_num_tokens_per_expert = num_tokens_per_expert.clone() 57 | dist.all_reduce(gbl_num_tokens_per_expert, group=group) 58 | 59 | # Rank layout meta 60 | num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda') 
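# NOTES: the loop below builds a reference dispatch layout in plain PyTorch: `num_tokens_per_rank[i]`
# counts the tokens routed to rank `i`, and `token_idx_in_rank[i][t]` records the position of token `t`
# inside rank `i`'s receive buffer (or -1 if the token is not sent there). The kernel output of
# `get_dispatch_layout` is validated against this reference further below.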
61 | num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda') 62 | token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda') 63 | for i in range(num_ranks): 64 | num_tokens_per_rank[i] = (rank_idx == i).sum() 65 | token_sel = (rank_idx == i).max(dim=-1)[0] 66 | count = token_sel.sum().item() 67 | tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] 68 | tokens[:count] = torch.sort(tokens[:count])[0] 69 | token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda') 70 | for i in range(num_nodes): 71 | num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() 72 | token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) 73 | is_token_in_rank = token_idx_in_rank >= 0 74 | gbl_num_tokens_per_rank = num_tokens_per_rank.clone() 75 | dist.all_reduce(gbl_num_tokens_per_rank, group=group) 76 | 77 | ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \ 78 | buffer.get_dispatch_layout(topk_idx, num_experts) 79 | assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) 80 | assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank) 81 | assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) 82 | assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) 83 | t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] 84 | if local_rank == 0: 85 | print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) 86 | print('', flush=True) 87 | group.barrier() 88 | time.sleep(1) 89 | 90 | # Config 91 | rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) 92 | config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) 93 | 94 | # Test dispatch 95 | # noinspection PyShadowingNames 96 | def check_data(check_x, recv_gbl_rank_prefix_sum): 97 | assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1)) 98 | check_start = 0 99 | for i in range(num_ranks): 100 | check_end = recv_gbl_rank_prefix_sum[i].item() 101 | assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 102 | check_start = check_end 103 | 104 | for previous_mode in (False, True): 105 | for async_mode in (False, True): 106 | for current_x in (x_pure_rand, x, x_e4m3): 107 | for with_topk in (False, True): 108 | if local_rank == 0: 109 | print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') 110 | dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, 111 | 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} 112 | if with_topk: 113 | dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) 114 | if previous_mode: 115 | dispatch_args.update({'previous_event': buffer.capture()}) 116 | recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) 117 | event.current_stream_wait() if async_mode else () 118 | recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x 119 | 120 | # Checks 121 | recv_gbl_rank_prefix_sum = handle[-4] 122 | assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), 
f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' 123 | assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list 124 | if current_x is not x_pure_rand: 125 | check_data(recv_x, recv_gbl_rank_prefix_sum) 126 | if with_topk: 127 | # Check `topk_idx` 128 | assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() 129 | for i, count in enumerate(recv_num_tokens_per_expert_list): 130 | assert recv_topk_idx.eq(i).sum().item() == count 131 | 132 | # Check `topk_weights` 133 | if current_x is not x_pure_rand: 134 | recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] 135 | check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) 136 | 137 | # Test cached dispatch (must without top-k staffs) 138 | if not with_topk: 139 | dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} 140 | if previous_mode: 141 | dispatch_args.update({'previous_event': buffer.capture()}) 142 | recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) 143 | event.current_stream_wait() if async_mode else () 144 | recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x 145 | if current_x is not x_pure_rand: 146 | check_data(recv_x, recv_gbl_rank_prefix_sum) 147 | 148 | # Test combine 149 | bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') 150 | bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') 151 | combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode} 152 | if with_topk: 153 | combine_args.update({'topk_weights': recv_topk_weights}) 154 | if previous_mode: 155 | combine_args.update({'previous_event': buffer.capture()}) 156 | combined_x, combined_topk_weights, event = buffer.combine(**combine_args) 157 | event.current_stream_wait() if async_mode else () 158 | check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1) 159 | ref_x = x_pure_rand if current_x is x_pure_rand else x 160 | assert calc_diff(check_x, ref_x) < 5e-6 161 | if with_topk: 162 | check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) 163 | ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights 164 | assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 165 | 166 | # For later tuning 167 | dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 168 | dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 169 | combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes 170 | combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes 171 | 172 | if local_rank == 0: 173 | print(' passed', flush=True) 174 | if local_rank == 0: 175 | print('', flush=True) 176 | 177 | # Tune dispatch performance 178 | best_dispatch_results = None 179 | fp8_factor = (1 + 4 / 128) / 2 180 | for current_x in (x_e4m3, x): 181 | best_time, best_results = 1e10, None 182 | rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes 183 | nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes 184 | for nvl_chunk_size in range(4, 45, 4): 185 | for rdma_chunk_size 
in range(4, 33, 4): 186 | config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) 187 | tune_args = {'x': current_x, 'handle': handle, 'config': config} 188 | t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify')) 189 | if t < best_time: 190 | best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) 191 | if local_rank == 0: 192 | print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) 193 | if local_rank == 0: 194 | print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) 195 | print('', flush=True) 196 | 197 | if isinstance(current_x, tuple): 198 | # Gather FP8 the best config from rank 0 199 | best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda') 200 | all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] 201 | dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) 202 | best_dispatch_results = all_best_fp8_results_list[0].tolist() 203 | dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size) 204 | 205 | dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 206 | 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, 207 | 'config': dispatch_config if dispatch_config is not None else config} 208 | recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) 209 | 210 | # Tune combine performance 211 | best_time, best_results = 1e10, None 212 | for nvl_chunk_size in range(1, 13, 1): 213 | for rdma_chunk_size in range(8, 33, 4): 214 | config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) 215 | tune_args = {'x': recv_x, 'handle': handle, 'config': config} 216 | t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify')) 217 | if local_rank == 0: 218 | print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) 219 | if t < best_time: 220 | best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) 221 | 222 | if local_rank == 0: 223 | print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) 224 | print('', flush=True) 225 | 226 | 227 | # noinspection PyUnboundLocalVariable,PyShadowingNames 228 | def test_loop(local_rank: int, num_local_ranks: int, args: 
argparse.Namespace): 229 | num_nodes = int(os.getenv('WORLD_SIZE', 1)) 230 | rank, num_ranks, group = init_dist(local_rank, num_local_ranks) 231 | if args.test_ll_compatibility: 232 | ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 233 | 234 | num_sms = 24 235 | num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) 236 | 237 | buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility, 238 | num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True) 239 | assert num_local_ranks == 8 and num_ranks > 8 240 | torch.manual_seed(rank) 241 | 242 | for i in (num_sms, ): 243 | test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group) 244 | if local_rank == 0: 245 | print('', flush=True) 246 | 247 | # Test compatibility with low latency functions 248 | if args.test_ll_compatibility: 249 | buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) 250 | test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) 251 | 252 | # Destroy the buffer runtime and communication group 253 | buffer.destroy() 254 | dist.barrier() 255 | dist.destroy_process_group() 256 | 257 | 258 | if __name__ == '__main__': 259 | parser = argparse.ArgumentParser(description='Test internode EP kernels') 260 | parser.add_argument('--num-processes', type=int, default=8, 261 | help='Number of processes to spawn (default: 8)') 262 | parser.add_argument('--num-tokens', type=int, default=4096, 263 | help='Number of tokens (default: 4096)') 264 | parser.add_argument('--hidden', type=int, default=7168, 265 | help='Hidden dimension size (default: 7168)') 266 | parser.add_argument('--num-topk-groups', type=int, default=None, 267 | help='Number of top-k groups (default: `min(num_nodes, 4)`)') 268 | parser.add_argument('--num-topk', type=int, default=8, 269 | help='Number of top-k experts (default: 8)') 270 | parser.add_argument('--num-experts', type=int, default=256, 271 | help='Number of experts (default: 256') 272 | parser.add_argument('--test-ll-compatibility', action='store_true', 273 | help='whether to test compatibility with low-latency kernels') 274 | args = parser.parse_args() 275 | 276 | # Set default `num_topk_groups` if not provided 277 | if args.num_topk_groups is None: 278 | num_nodes = int(os.getenv('WORLD_SIZE', 1)) 279 | args.num_topk_groups = min(num_nodes, 4) 280 | 281 | num_processes = args.num_processes 282 | torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) 283 | -------------------------------------------------------------------------------- /tests/test_intranode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | import torch.distributed as dist 5 | 6 | # noinspection PyUnresolvedReferences 7 | import deep_ep 8 | from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back 9 | 10 | # Test compatibility with low latency functions 11 | import test_low_latency 12 | 13 | 14 | # noinspection PyShadowingNames 15 | def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int, 16 | buffer: deep_ep.Buffer, group: dist.ProcessGroup): 17 | # Settings 18 | num_tokens, hidden = args.num_tokens, args.hidden 19 | num_topk, num_experts = args.num_topk, args.num_experts 20 | 21 | assert num_experts % num_ranks == 
0 22 | if local_rank == 0: 23 | print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True) 24 | 25 | # Random data 26 | x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank 27 | x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') 28 | x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None 29 | x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None 30 | scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 31 | topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] 32 | topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank 33 | topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') 34 | rank_idx = topk_idx // (num_experts // num_ranks) 35 | rank_idx.masked_fill_(topk_idx == -1, -1) 36 | inplace_unique(rank_idx, num_ranks) 37 | 38 | # Expert meta 39 | num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda') 40 | for i in range(num_experts): 41 | num_tokens_per_expert[i] = (topk_idx == i).sum() 42 | gbl_num_tokens_per_expert = num_tokens_per_expert.clone() 43 | dist.all_reduce(gbl_num_tokens_per_expert, group=group) 44 | 45 | # Rank layout meta 46 | num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda') 47 | token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda') 48 | for i in range(num_ranks): 49 | num_tokens_per_rank[i] = (rank_idx == i).sum() 50 | token_sel = (rank_idx == i).max(dim=-1)[0] 51 | count = token_sel.sum().item() 52 | tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] 53 | tokens[:count] = torch.sort(tokens[:count])[0] 54 | token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda') 55 | token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) 56 | is_token_in_rank = token_idx_in_rank >= 0 57 | gbl_num_tokens_per_rank = num_tokens_per_rank.clone() 58 | dist.all_reduce(gbl_num_tokens_per_rank, group=group) 59 | 60 | ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \ 61 | buffer.get_dispatch_layout(topk_idx, num_experts) 62 | assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) 63 | assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) 64 | assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) 65 | t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] 66 | if local_rank == 0: 67 | print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) 68 | print('', flush=True) 69 | group.barrier() 70 | time.sleep(1) 71 | 72 | # Config 73 | nvl_buffer_size = 256 74 | config = deep_ep.Config(num_sms, 8, nvl_buffer_size) 75 | 76 | # Test dispatch 77 | # noinspection PyShadowingNames 78 | def check_data(check_x, rank_prefix_matrix): 79 | assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1)) 80 | check_start = 0 81 | for i in range(num_ranks): 82 | check_end = rank_prefix_matrix[i][rank].item() 83 | assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 84 | check_start = check_end 85 | 86 | for previous_mode in (False, True): 87 | for async_mode in (False, True): 88 | for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)): 89 | for with_topk in (False, True): 90 | if local_rank == 0: 91 | print(f'[testing] Running 
with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') 92 | dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank, 93 | 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} 94 | if with_topk: 95 | dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) 96 | if previous_mode: 97 | dispatch_args.update({'previous_event': buffer.capture()}) 98 | recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) 99 | event.current_stream_wait() if async_mode else () 100 | recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x 101 | 102 | # Checks 103 | rank_prefix_matrix = handle[0] 104 | assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' 105 | assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list 106 | if current_x is not x_pure_rand: 107 | check_data(recv_x, rank_prefix_matrix) 108 | recv_topk_weights_clone = None 109 | if with_topk: 110 | # Check `topk_idx` 111 | assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() 112 | for i, count in enumerate(recv_num_tokens_per_expert_list): 113 | assert recv_topk_idx.eq(i).sum().item() == count 114 | 115 | # Check `topk_weights` 116 | recv_topk_weights_clone = recv_topk_weights.clone() 117 | if current_x is not x_pure_rand: 118 | recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] 119 | check_data(recv_topk_weights, rank_prefix_matrix) 120 | 121 | # Test `num_worst_tokens != 0` 122 | if with_topk: 123 | num_worst_tokens = num_tokens * num_ranks 124 | dispatch_args.update({'num_worst_tokens': num_worst_tokens}) 125 | recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args) 126 | event.current_stream_wait() if async_mode else () 127 | recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x 128 | assert len(empty_list) == 0 129 | assert num_worst_tokens == recv_worst_x.size(0) 130 | assert num_worst_tokens == recv_worst_topk_idx.size(0) 131 | assert num_worst_tokens == recv_worst_topk_weights.size(0) 132 | assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)]) 133 | assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)]) 134 | assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)]) 135 | assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item() 136 | 137 | # Test cached dispatch (must without top-k staffs) 138 | if not with_topk: 139 | dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} 140 | if previous_mode: 141 | dispatch_args.update({'previous_event': buffer.capture()}) 142 | recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) 143 | event.current_stream_wait() if async_mode else () 144 | recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x 145 | if current_x is not x_pure_rand: 146 | check_data(recv_x, rank_prefix_matrix) 147 | 148 | # Test combine 149 
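# NOTES: combine sums every token's copies across all ranks it was dispatched to, so dividing
# `combined_x` by `is_token_in_rank.sum(dim=1)` (the per-token duplication count) should recover the
# original `x` up to low-precision rounding, which is what the `calc_diff` assertion below verifies.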
| combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} 150 | if with_topk: 151 | combine_args.update({'topk_weights': recv_topk_weights}) 152 | if previous_mode: 153 | combine_args.update({'previous_event': buffer.capture()}) 154 | combined_x, combined_topk_weights, event = buffer.combine(**combine_args) 155 | event.current_stream_wait() if async_mode else () 156 | check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) 157 | ref_x = x_pure_rand if current_x is x_pure_rand else x 158 | assert calc_diff(check_x, ref_x) < 5e-6 159 | if with_topk: 160 | check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) 161 | ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights 162 | assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 163 | 164 | # For later tuning 165 | dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 166 | combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes 167 | 168 | if local_rank == 0: 169 | print(' passed', flush=True) 170 | if local_rank == 0: 171 | print('', flush=True) 172 | 173 | # Tune dispatch performance 174 | best_dispatch_results = None 175 | fp8_factor = (1 + 4 / 128) / 2 176 | for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)): 177 | best_time, best_results = 1e10, None 178 | nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes 179 | for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ): 180 | if nvl_chunk_size > 0: 181 | config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) 182 | else: 183 | # Test default config as well 184 | deep_ep.Buffer.set_num_sms(num_sms) 185 | config = deep_ep.Buffer.get_dispatch_config(num_ranks) 186 | tune_args = {'x': current_x, 'handle': handle, 'config': config} 187 | t = bench(lambda: buffer.dispatch(**tune_args))[0] 188 | if t < best_time and nvl_chunk_size > 0: 189 | best_time, best_results = t, (num_sms, nvl_chunk_size) 190 | if local_rank == 0: 191 | print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' 192 | f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True) 193 | if local_rank == 0: 194 | print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) 195 | print('', flush=True) 196 | 197 | # Gather the best config from rank 0 and the first test setting 198 | if best_dispatch_results is None: 199 | best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda') 200 | all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] 201 | dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) 202 | best_dispatch_results = all_best_fp8_results_list[0].tolist() 203 | dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size) 204 | 205 | dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 206 | 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, 207 | 'config': dispatch_config if dispatch_config is not None else config} 208 | recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) 209 | 210 
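# NOTES: the dispatch above is re-run with the best config gathered from rank 0 so that every rank
# uses identical chunk sizes; its `recv_x` and `handle` are then reused as inputs for the combine
# tuning sweep below.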
| # Tune combine performance 211 | best_time, best_results = 1e10, None 212 | for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ): 213 | if nvl_chunk_size > 0: 214 | config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) 215 | else: 216 | # Test default config as well 217 | deep_ep.Buffer.set_num_sms(num_sms) 218 | config = deep_ep.Buffer.get_combine_config(num_ranks) 219 | tune_args = {'x': recv_x, 'handle': handle, 'config': config} 220 | t = bench(lambda: buffer.combine(**tune_args))[0] 221 | if local_rank == 0: 222 | print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' 223 | f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True) 224 | if t < best_time and nvl_chunk_size > 0: 225 | best_time, best_results = t, (num_sms, nvl_chunk_size) 226 | 227 | if local_rank == 0: 228 | print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) 229 | print('', flush=True) 230 | 231 | 232 | # noinspection PyUnboundLocalVariable,PyShadowingNames 233 | def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): 234 | rank, num_ranks, group = init_dist(local_rank, num_local_ranks) 235 | test_ll_compatibility, num_rdma_bytes = False, 0 236 | if test_ll_compatibility: 237 | ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 238 | num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) 239 | 240 | buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, 241 | num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True) 242 | torch.manual_seed(rank) 243 | 244 | for i in (24, ): 245 | test_main(args, i, local_rank, num_ranks, rank, buffer, group) 246 | if local_rank == 0: 247 | print('', flush=True) 248 | 249 | # Test compatibility with low latency functions 250 | if test_ll_compatibility: 251 | buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) 252 | test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) 253 | 254 | # Destroy the buffer runtime and communication group 255 | buffer.destroy() 256 | dist.barrier() 257 | dist.destroy_process_group() 258 | 259 | 260 | if __name__ == '__main__': 261 | parser = argparse.ArgumentParser(description='Test intranode EP kernels') 262 | parser.add_argument('--num-processes', type=int, default=8, 263 | help='Number of processes to spawn (default: 8)') 264 | parser.add_argument('--num-tokens', type=int, default=4096, 265 | help='Number of tokens (default: 4096)') 266 | parser.add_argument('--hidden', type=int, default=7168, 267 | help='Hidden dimension size (default: 7168)') 268 | parser.add_argument('--num-topk', type=int, default=8, 269 | help='Number of top-k experts (default: 8)') 270 | parser.add_argument('--num-experts', type=int, default=256, 271 | help='Number of experts (default: 256)') 272 | args = parser.parse_args() 273 | 274 | num_processes = args.num_processes 275 | torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) 276 | -------------------------------------------------------------------------------- /tests/test_low_latency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 
| import random 3 | import torch 4 | import torch.distributed as dist 5 | from functools import partial 6 | 7 | import deep_ep 8 | from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back 9 | 10 | 11 | def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, 12 | rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, 13 | use_logfmt: bool = False, seed: int = 0): 14 | torch.manual_seed(seed + rank) 15 | random.seed(seed + rank) 16 | 17 | assert num_experts % num_ranks == 0 18 | num_local_experts = num_experts // num_ranks 19 | 20 | # NOTES: the integers greater than 256 exceed the BF16 precision limit 21 | rank_offset = 128 22 | assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' 23 | 24 | x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) 25 | x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) 26 | x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1 27 | scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 28 | topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] 29 | topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() 30 | 31 | # Randomly mask some positions 32 | for i in range(10): 33 | topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 34 | 35 | # Check dispatch correctness 36 | do_check = True 37 | hash_value, num_times = 0, 0 38 | for current_x in (x, x_pure_rand): 39 | for return_recv_hook in (False, True): 40 | for dispatch_use_fp8 in (False, True): 41 | for round_scale in (False, True) if dispatch_use_fp8 else (False, ): 42 | for use_ue8m0 in (False, True) if round_scale else (False, ): 43 | num_times += 1 44 | for i in range((num_times % 2) + 1): 45 | cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') 46 | packed_recv_x, packed_recv_count, handle, event, hook = \ 47 | buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, 48 | use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, 49 | cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, 50 | async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) 51 | hook() if return_recv_hook else event.current_stream_wait() 52 | packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x 53 | simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ 54 | if dispatch_use_fp8 else packed_recv_x.clone() 55 | all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') 56 | dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) 57 | for i in range(num_local_experts if do_check else 0): 58 | expert_id = rank * num_local_experts + i 59 | recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] 60 | recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] 61 | 62 | # Check expert indices 63 | int_mask = (2 ** 32) - 1 64 | num_valid_tokens = recv_count.item() 65 | assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != 
{num_valid_tokens}' 66 | assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' 67 | assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' 68 | 69 | # Check received data 70 | if current_x is not x_pure_rand: 71 | recv_x = recv_x[:num_valid_tokens] 72 | recv_x_amin = recv_x[:, :-128].amin(dim=-1) 73 | recv_src_info = recv_src_info[:num_valid_tokens] 74 | assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) 75 | if round_scale: 76 | assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 77 | else: 78 | assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 79 | for j in range(num_ranks): 80 | begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() 81 | if not round_scale: 82 | assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() 83 | assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0 84 | if dispatch_use_fp8: 85 | hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) 86 | hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) 87 | else: 88 | hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) 89 | 90 | # Check combine correctness 91 | for zero_copy in (False, ) if use_logfmt else (False, True): 92 | if zero_copy: 93 | buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x 94 | out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') 95 | combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, 96 | use_logfmt=use_logfmt, 97 | async_finish=not return_recv_hook, zero_copy=zero_copy, 98 | return_recv_hook=return_recv_hook, out=out) 99 | hook() if return_recv_hook else event.current_stream_wait() 100 | if do_check: 101 | diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) 102 | assert torch.isnan(combined_x).sum().item() == 0 103 | assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}' 104 | hash_value ^= hash_tensor(combined_x) 105 | 106 | # noinspection PyShadowingNames 107 | def large_gemm_with_hook(hook): 108 | mat_0 = torch.randn((8192, 8192), dtype=torch.float) 109 | mat_1 = torch.randn((8192, 8192), dtype=torch.float) 110 | mat_0 @ mat_1 111 | hook() 112 | 113 | # noinspection PyShadowingNames 114 | def test_func(return_recv_hook: bool): 115 | recv_x, recv_count, handle, event, hook = \ 116 | buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts, 117 | cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, 118 | use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) 119 | large_gemm_with_hook(hook) if return_recv_hook else None 120 | combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, 121 | use_logfmt=use_logfmt, return_recv_hook=return_recv_hook) 122 | large_gemm_with_hook(hook) if return_recv_hook else None 123 | 124 | # Calculate bandwidth 125 | num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 126 | num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 127 | for i in range(num_tokens): 128 | num_selections = (topk_idx[i] != -1).sum().item() 129 | num_dispatch_comm_bytes += num_fp8_bytes * num_selections 130 | 
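# NOTES: with the default `hidden=7168`, `num_fp8_bytes = 7168 + 7168 / 128 * 4 + 16 = 7408`
# (FP8 payload, one float scale per 128 channels, plus 16 bytes of per-token overhead) and
# `num_bf16_bytes = 7168 * 2 = 14336`; both are counted once per selected expert of every token.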
num_combine_comm_bytes += num_bf16_bytes * num_selections 131 | 132 | # Dispatch + combine testing 133 | avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False)) 134 | print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' 135 | f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) 136 | 137 | # Separate profiling 138 | for return_recv_hook in (False, True): 139 | group.barrier() 140 | dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook), 141 | kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, 142 | suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1) 143 | if not return_recv_hook: 144 | print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' 145 | f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) 146 | else: 147 | print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' 148 | f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) 149 | return hash_value 150 | 151 | 152 | # noinspection PyUnboundLocalVariable,PyShadowingNames 153 | def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): 154 | rank, num_ranks, group = init_dist(local_rank, num_local_ranks) 155 | num_tokens, hidden = args.num_tokens, args.hidden 156 | num_topk, num_experts = args.num_topk, args.num_experts 157 | 158 | num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) 159 | if local_rank == 0: 160 | print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) 161 | buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, 162 | num_qps_per_rank=num_experts // num_ranks, 163 | allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True) 164 | test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, 165 | use_logfmt=args.use_logfmt, seed=1) 166 | 167 | do_pressure_test = False 168 | for seed in range(int(1e9) if do_pressure_test else 0): 169 | if local_rank == 0: 170 | print(f'Testing with seed {seed} ...', flush=True) 171 | ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, 172 | use_logfmt=args.use_logfmt, seed=seed) 173 | for i in range(20): 174 | assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, 175 | use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}' 176 | 177 | # Destroy the buffer runtime and communication group 178 | buffer.destroy() 179 | dist.barrier() 180 | dist.destroy_process_group() 181 | 182 | 183 | if __name__ == '__main__': 184 | # TODO: you may modify NUMA binding for less CPU overhead 185 | # TODO: buggy with `num_tokens=512` 186 | parser = argparse.ArgumentParser(description='Test low-latency EP kernels') 187 | parser.add_argument('--num-processes', type=int, default=8, 188 | help='Number of processes to spawn (default: 8)') 189 | parser.add_argument('--num-tokens', type=int, default=128, 190 | help='Number of tokens (default: 128)') 191 | parser.add_argument('--hidden', type=int, default=7168, 192 | help='Hidden dimension size (default: 7168)') 193 | 
parser.add_argument('--num-topk', type=int, default=8, 194 | help='Number of top-k experts (default: 8)') 195 | parser.add_argument('--num-experts', type=int, default=288, 196 | help='Number of experts (default: 288)') 197 | parser.add_argument('--disable-nvlink', action='store_true', 198 | help='Whether to disable NVLink for testing') 199 | parser.add_argument('--use-logfmt', action='store_true', 200 | help='Whether to test LogFMT combine') 201 | args = parser.parse_args() 202 | 203 | num_processes = args.num_processes 204 | torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) 205 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import os 8 | import sys 9 | import torch 10 | import torch.distributed as dist 11 | from typing import Optional, Union 12 | 13 | 14 | def init_dist(local_rank: int, num_local_ranks: int): 15 | # NOTES: you may rewrite this function with your own cluster settings 16 | ip = os.getenv('MASTER_ADDR', '127.0.0.1') 17 | port = int(os.getenv('MASTER_PORT', '8361')) 18 | num_nodes = int(os.getenv('WORLD_SIZE', 1)) 19 | node_rank = int(os.getenv('RANK', 0)) 20 | 21 | sig = inspect.signature(dist.init_process_group) 22 | params = { 23 | 'backend': 'nccl', 24 | 'init_method': f'tcp://{ip}:{port}', 25 | 'world_size': num_nodes * num_local_ranks, 26 | 'rank': node_rank * num_local_ranks + local_rank, 27 | } 28 | if 'device_id' in sig.parameters: 29 | # noinspection PyTypeChecker 30 | params['device_id'] = torch.device(f'cuda:{local_rank}') 31 | dist.init_process_group(**params) 32 | torch.set_default_dtype(torch.bfloat16) 33 | torch.set_default_device('cuda') 34 | torch.cuda.set_device(local_rank) 35 | 36 | return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) 37 | 38 | 39 | def calc_diff(x: torch.Tensor, y: torch.Tensor): 40 | x, y = x.double() + 1, y.double() + 1 41 | denominator = (x * x + y * y).sum() 42 | sim = 2 * (x * y).sum() / denominator 43 | return (1 - sim).item() 44 | 45 | 46 | def per_token_cast_to_fp8(x: torch.Tensor): 47 | assert x.dim() == 2 and x.size(1) % 128 == 0 48 | m, n = x.shape 49 | x_view = x.view(m, -1, 128) 50 | x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) 51 | return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) 52 | 53 | 54 | def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): 55 | if x_scales.dtype == torch.int: 56 | x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23 57 | x_scales = x_scales.view(dtype=torch.float) 58 | x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) 59 | x_scales = x_scales.view(x_fp8.size(0), -1, 1) 60 | return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) 61 | 62 | 63 | def inplace_unique(x: torch.Tensor, num_slots: int): 64 | assert x.dim() == 2 65 | mask = x < 0 66 | x_padded = x.masked_fill(mask, num_slots) 67 | bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device) 68 | bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded)) 69 | bin_count = bin_count[:, :num_slots] 70 | sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True) 71 | sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1) 72 | 
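# NOTES: at this point each row of `sorted_bin_idx` holds the distinct non-negative values of the
# corresponding row of `x` (slots that never appeared are masked to -1); the sort below re-orders
# them by value in descending order, keeping the -1 padding last, before they are written back into `x`.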
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values 73 | x[:, :].fill_(-1) 74 | valid_len = min(num_slots, x.size(1)) 75 | x[:, :valid_len] = sorted_bin_idx[:, :valid_len] 76 | 77 | 78 | def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int): 79 | num_tokens, num_experts = scores.shape 80 | scores = scores.view(num_tokens, num_groups, -1) 81 | mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device) 82 | mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores) 83 | return (scores * mask).view(num_tokens, num_experts) 84 | 85 | 86 | def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): 87 | # Flush L2 cache with 256 MB data 88 | torch.cuda.synchronize() 89 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') 90 | 91 | # Warmup 92 | for _ in range(num_warmups): 93 | fn() 94 | 95 | # Flush L2 96 | cache.zero_() 97 | 98 | # Testing 99 | start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] 100 | end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] 101 | for i in range(num_tests): 102 | # Record 103 | start_events[i].record() 104 | fn() 105 | end_events[i].record() 106 | if post_fn is not None: 107 | post_fn() 108 | torch.cuda.synchronize() 109 | 110 | times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:] 111 | return np.average(times), np.min(times), np.max(times) 112 | 113 | 114 | class empty_suppress: 115 | def __enter__(self): 116 | return self 117 | 118 | def __exit__(self, *_): 119 | pass 120 | 121 | 122 | class suppress_stdout_stderr: 123 | def __enter__(self): 124 | self.outnull_file = open(os.devnull, 'w') 125 | self.errnull_file = open(os.devnull, 'w') 126 | 127 | self.old_stdout_fileno_undup = sys.stdout.fileno() 128 | self.old_stderr_fileno_undup = sys.stderr.fileno() 129 | 130 | self.old_stdout_fileno = os.dup(sys.stdout.fileno()) 131 | self.old_stderr_fileno = os.dup(sys.stderr.fileno()) 132 | 133 | self.old_stdout = sys.stdout 134 | self.old_stderr = sys.stderr 135 | 136 | os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) 137 | os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) 138 | 139 | sys.stdout = self.outnull_file 140 | sys.stderr = self.errnull_file 141 | return self 142 | 143 | def __exit__(self, *_): 144 | sys.stdout = self.old_stdout 145 | sys.stderr = self.old_stderr 146 | 147 | os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) 148 | os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) 149 | 150 | os.close(self.old_stdout_fileno) 151 | os.close(self.old_stderr_fileno) 152 | 153 | self.outnull_file.close() 154 | self.errnull_file.close() 155 | 156 | 157 | def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False, 158 | trace_path: Optional[str] = None, barrier_comm_profiling: bool = False, 159 | num_kernels_per_period: int = 1): 160 | # Profile 161 | suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress 162 | with suppress(): 163 | schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) 164 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof: 165 | for i in range(2): 166 | # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead 167 | if barrier_comm_profiling: 168 | lhs = torch.randn((8192, 8192), 
dtype=torch.float, device='cuda') 169 | rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') 170 | lhs @ rhs 171 | dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) 172 | for _ in range(num_tests): 173 | fn() 174 | prof.step() 175 | 176 | # Parse the profiling table 177 | assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) 178 | is_tuple = isinstance(kernel_names, tuple) 179 | prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') 180 | kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names 181 | assert all([isinstance(name, str) for name in kernel_names]) 182 | for name in kernel_names: 183 | assert sum([name in line for line in prof_lines]) == 1, f'Kernel {name} must appear exactly once in the profiling table' 184 | 185 | # Save chrome traces 186 | if trace_path is not None: 187 | prof.export_chrome_trace(trace_path) 188 | 189 | # Return average kernel durations 190 | units = {'ms': 1e3, 'us': 1e6} 191 | kernel_durations = [] 192 | for name in kernel_names: 193 | for line in prof_lines: 194 | if name in line: 195 | time_str = line.split()[-2] 196 | for unit, scale in units.items(): 197 | if unit in time_str: 198 | kernel_durations.append(float(time_str.replace(unit, '')) / scale) 199 | break 200 | break 201 | 202 | # Split each kernel's duration into per-period averages 203 | if num_kernels_per_period > 1: 204 | with tempfile.NamedTemporaryFile(suffix='.json') as tmp: 205 | prof.export_chrome_trace(tmp.name) 206 | profile_data = json.loads(Path(tmp.name).read_text()) 207 | 208 | for i, kernel_name in enumerate(kernel_names): 209 | events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']] 210 | events = sorted(events, key=lambda event: event['ts']) 211 | durations = [event['dur'] / 1e6 for event in events] 212 | assert len(durations) % num_kernels_per_period == 0 213 | num_kernel_patterns = len(durations) // num_kernels_per_period 214 | kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns 215 | for j in range(num_kernels_per_period)] 216 | 217 | # Return execution durations 218 | return kernel_durations if is_tuple else kernel_durations[0] 219 | 220 | 221 | def hash_tensor(t: torch.Tensor): 222 | return t.view(torch.int64).sum().item() 223 | -------------------------------------------------------------------------------- /third-party/README.md: -------------------------------------------------------------------------------- 1 | # Install NVSHMEM 2 | 3 | ## Important notices 4 | 5 | **This project is neither sponsored nor supported by NVIDIA.** 6 | 7 | **Use of NVIDIA NVSHMEM is governed by the terms at [NVSHMEM Software License Agreement](https://docs.nvidia.com/nvshmem/api/sla.html).** 8 | 9 | ## Prerequisites 10 | 11 | Hardware requirements: 12 | - GPUs inside one node need to be connected by NVLink 13 | - GPUs across different nodes need to be connected by RDMA devices, see [GPUDirect RDMA Documentation](https://docs.nvidia.com/cuda/gpudirect-rdma/) 14 | - InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/) 15 | - For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements) 16 | 17 | Software requirements: 18 | - NVSHMEM v3.3.9 or later 19 | 20 | ## Installation 
procedure 21 | 22 | ### 1. Install NVSHMEM binaries 23 | 24 | NVSHMEM 3.3.9 binaries are available in several formats: 25 | - Tarballs for [x86_64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-x86_64/libnvshmem-linux-x86_64-3.3.9_cuda12-archive.tar.xz) and [aarch64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-sbsa/libnvshmem-linux-sbsa-3.3.9_cuda12-archive.tar.xz) 26 | - RPM and deb packages: instructions can be found on the [NVSHMEM installer page](https://developer.nvidia.com/nvshmem-downloads?target_os=Linux) 27 | - Conda packages through conda-forge 28 | - pip wheels through PyPI: `pip install nvidia-nvshmem-cu12` 29 | DeepEP is compatible with upstream NVSHMEM 3.3.9 and later. 30 | 31 | 32 | ### 2. Enable NVSHMEM IBGDA support 33 | 34 | NVSHMEM supports two IBGDA modes with different requirements; either of the following methods can be used to enable IBGDA support. 35 | 36 | #### 2.1 Configure NVIDIA driver 37 | 38 | This configuration enables traditional IBGDA support. 39 | 40 | Modify `/etc/modprobe.d/nvidia.conf`: 41 | 42 | ```bash 43 | options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;" 44 | ``` 45 | 46 | Update the initramfs and reboot for the changes to take effect: 47 | 48 | ```bash 49 | sudo update-initramfs -u 50 | sudo reboot 51 | ``` 52 | 53 | #### 2.2 Install GDRCopy and load the gdrdrv kernel module 54 | 55 | This configuration enables IBGDA through asynchronous post-send operations assisted by the CPU. More information about CPU-assisted IBGDA can be found in [this blog](https://developer.nvidia.com/blog/enhancing-application-portability-and-compatibility-across-new-platforms-using-nvidia-magnum-io-nvshmem-3-0/#cpu-assisted_infiniband_gpu_direct_async%C2%A0). 56 | It comes with a small performance penalty, but can be used when modifying the driver regkeys is not an option. 57 | 58 | Download GDRCopy: 59 | GDRCopy is available as prebuilt deb and rpm packages [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/), or as source code on the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy). 60 | 61 | Install GDRCopy and load the gdrdrv kernel module following the instructions on the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy?tab=readme-ov-file#build-and-installation). 62 | 63 | ## Post-installation configuration 64 | 65 | If you did not install NVSHMEM from RPM or deb packages, set the following environment variables in your shell configuration: 66 | 67 | ```bash 68 | export NVSHMEM_DIR=/path/to/your/dir/to/install # Use for DeepEP installation 69 | export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH" 70 | export PATH="${NVSHMEM_DIR}/bin:$PATH" 71 | ``` 72 | 73 | ## Verification 74 | 75 | ```bash 76 | nvshmem-info -a # Should display details of the NVSHMEM installation 77 | ``` 78 | -------------------------------------------------------------------------------- /third-party/nvshmem.patch: -------------------------------------------------------------------------------- 1 | From 9e6cc27cceb3130784e4ea7b61ea3171156365fd Mon Sep 17 00:00:00 2001 2 | From: Shangyan Zhou <sy.zhou@deepseek.com> 3 | Date: Fri, 20 Dec 2024 10:57:12 +0800 4 | Subject: [PATCH 1/4] Change QP creating order. 
5 | 6 | --- 7 | src/modules/transport/ibgda/ibgda.cpp | 13 ++++++++----- 8 | 1 file changed, 8 insertions(+), 5 deletions(-) 9 | 10 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp 11 | index ef325cd..286132e 100644 12 | --- a/src/modules/transport/ibgda/ibgda.cpp 13 | +++ b/src/modules/transport/ibgda/ibgda.cpp 14 | @@ -2936,17 +2936,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id 15 | INFO(ibgda_state->log_level, "Creating %d RC QPs", device->rc.num_eps_per_pe); 16 | for (int i = 0; i < num_rc_eps; ++i) { 17 | // Do not create loopback to self 18 | - if (i / device->rc.num_eps_per_pe == mype) { 19 | + int dst_pe = (i + 1 + mype) % n_pes; 20 | + int offset = i / n_pes; 21 | + int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset; 22 | + if (dst_pe == mype) { 23 | continue; 24 | } 25 | - status = ibgda_create_qp(&device->rc.eps[i], device, portid, i, 26 | + status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i, 27 | NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC); 28 | NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, 29 | - "ibgda_create_dci failed on RC #%d.", i); 30 | + "ibgda_create_dci failed on RC #%d.", mapped_i); 31 | 32 | - status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device); 33 | + status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device); 34 | NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, 35 | - "ibgda_get_rc_handle failed on RC #%d.", i); 36 | + "ibgda_get_rc_handle failed on RC #%d.", mapped_i); 37 | } 38 | 39 | if (num_rc_eps) { 40 | -- 41 | 2.25.1 42 | 43 | 44 | From b11d41e4f3727f2f6ccc00a8c852e59e2ee33c8a Mon Sep 17 00:00:00 2001 45 | From: Shangyan Zhou <sy.zhou@deepseek.com> 46 | Date: Fri, 10 Jan 2025 11:53:38 +0800 47 | Subject: [PATCH 2/4] Add recv queue and recv cq for rc qps. 48 | 49 | Let the ibgda rc qps use regular recv queue. 50 | 51 | Add recv queue to ibgda dev qp. 52 | 53 | IBGDA create recv cq 54 | 55 | Setup recv cq. 56 | 57 | fix recv queue. 58 | 59 | Remove some useless idx. 60 | 61 | Longer recv queue. 
62 | --- 63 | .../nvshmem_common_ibgda.h | 19 +++++- 64 | src/modules/transport/ibgda/ibgda.cpp | 65 ++++++++++++++++--- 65 | 2 files changed, 71 insertions(+), 13 deletions(-) 66 | 67 | diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h 68 | index 8b8a263..1be3dec 100644 69 | --- a/src/include/device_host_transport/nvshmem_common_ibgda.h 70 | +++ b/src/include/device_host_transport/nvshmem_common_ibgda.h 71 | @@ -168,14 +168,17 @@ typedef struct { 72 | uint64_t get_head; // last wqe idx + 1 with a "fetch" operation (g, get, amo_fetch) 73 | uint64_t get_tail; // last wqe idx + 1 polled with cst; get_tail > get_head is possible 74 | } tx_wq; 75 | + struct { 76 | + uint64_t resv_head; // last reserved wqe idx + 1 77 | + } rx_wq; 78 | struct { 79 | uint64_t head; 80 | uint64_t tail; 81 | } ibuf; 82 | char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING]; 83 | } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1; 84 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96, 85 | - "ibgda_device_qp_management_v1 must be 96 bytes."); 86 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104, 87 | + "ibgda_device_qp_management_v1 must be 104 bytes."); 88 | 89 | typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t; 90 | 91 | @@ -199,9 +202,19 @@ typedef struct nvshmemi_ibgda_device_qp { 92 | // May point to mvars.prod_idx or internal prod_idx 93 | uint64_t *prod_idx; 94 | } tx_wq; 95 | + struct { 96 | + uint16_t nwqes; 97 | + uint64_t tail; 98 | + void *wqe; 99 | + __be32 *dbrec; 100 | + void *bf; 101 | + nvshmemi_ibgda_device_cq_t *cq; 102 | + // May point to mvars.prod_idx or internal prod_idx 103 | + uint64_t *prod_idx; 104 | + } rx_wq; 105 | nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables 106 | } nvshmemi_ibgda_device_qp_v1; 107 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, "ibgda_device_qp_v1 must be 184 bytes."); 108 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes."); 109 | 110 | typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t; 111 | 112 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp 113 | index 286132e..e0b2d5c 100644 114 | --- a/src/modules/transport/ibgda/ibgda.cpp 115 | +++ b/src/modules/transport/ibgda/ibgda.cpp 116 | @@ -198,6 +198,7 @@ struct ibgda_ep { 117 | off_t dbr_offset; 118 | 119 | struct ibgda_cq *send_cq; 120 | + struct ibgda_cq *recv_cq; 121 | struct ibv_ah *ah; 122 | 123 | uint32_t user_index; 124 | @@ -1538,7 +1539,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state, 125 | 126 | struct ibv_context *context = device->context; 127 | 128 | - unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes; 129 | + // Each RC qp has one send CQ and one recv CQ. 130 | + unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2; 131 | 132 | assert(ibgda_qp_depth > 0); 133 | size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth); 134 | @@ -1701,7 +1703,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state, 135 | } 136 | 137 | // Allocate and map WQ buffer for all QPs. 138 | - wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB; // num_wqebb is always a power of 2 139 | + // Todo: reduce the size of wq buffer. 
140 | + wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2; // num_wqebb is always a power of 2 141 | wq_buf_size = wq_buf_size_per_qp * num_eps; 142 | status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE); 143 | NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "cannot allocate wq buf.\n"); 144 | @@ -1882,8 +1885,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device 145 | int cqe_version = 0; 146 | 147 | struct ibgda_cq *send_cq = NULL; 148 | + struct ibgda_cq *recv_cq = NULL; 149 | 150 | size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth); 151 | + size_t num_recv_wqe = ibgda_qp_depth; 152 | + size_t recv_wqe_size = 16; 153 | 154 | int status = 0; 155 | 156 | @@ -1911,6 +1917,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device 157 | status = ibgda_create_cq(&send_cq, device); 158 | NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n"); 159 | 160 | + if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) { 161 | + status = ibgda_create_cq(&recv_cq, device); 162 | + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n"); 163 | + } 164 | + 165 | ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep)); 166 | NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, 167 | "Unable to allocate mem for ep.\n"); 168 | @@ -1939,12 +1950,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device 169 | DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED); 170 | DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn); 171 | DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id); // BF register 172 | - DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue 173 | - DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn); 174 | DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn); 175 | - DEVX_SET(qpc, qp_context, cqn_rcv, device->qp_shared_object.rcqn); 176 | + DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn); 177 | DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb)); 178 | - DEVX_SET(qpc, qp_context, log_rq_size, 0); 179 | DEVX_SET(qpc, qp_context, cs_req, 0); // Disable CS Request 180 | DEVX_SET(qpc, qp_context, cs_res, 0); // Disable CS Response 181 | DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE); // Enable dbr_umem_id 182 | @@ -1953,6 +1961,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device 183 | DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id); // DBR buffer 184 | DEVX_SET(qpc, qp_context, user_index, qp_idx); 185 | DEVX_SET(qpc, qp_context, page_offset, 0); 186 | + if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC){ 187 | + DEVX_SET(qpc, qp_context, rq_type, 0); // Regular recv queue 188 | + DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe)); // 4 wqe 189 | + DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4); // max recv wqe size = 16B 190 | + } else { 191 | + DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue, DC must use this. 
192 | + DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn); 193 | + DEVX_SET(qpc, qp_context, log_rq_size, 0); 194 | + } 195 | 196 | ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out)); 197 | NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out, 198 | @@ -1962,9 +1979,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device 199 | ep->portid = portid; 200 | 201 | ep->sq_cnt = num_wqebb; 202 | - ep->sq_buf_offset = 0; 203 | + ep->sq_buf_offset = num_recv_wqe * recv_wqe_size; 204 | 205 | - ep->rq_cnt = 0; 206 | + ep->rq_cnt = num_recv_wqe; 207 | ep->rq_buf_offset = 0; 208 | 209 | ep->wq_mobject = device->qp_shared_object.wq_mobject; 210 | @@ -1978,6 +1995,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device 211 | ep->uar_mobject = uar_mobject; 212 | 213 | ep->send_cq = send_cq; 214 | + ep->recv_cq = recv_cq; 215 | 216 | ep->qp_type = qp_type; 217 | 218 | @@ -1989,6 +2007,7 @@ out: 219 | if (status) { 220 | if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject); 221 | if (send_cq) ibgda_destroy_cq(send_cq); 222 | + if (recv_cq) ibgda_destroy_cq(recv_cq); 223 | if (ep) free(ep); 224 | } 225 | 226 | @@ -2287,6 +2306,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) { 227 | ibgda_destroy_cq(ep->send_cq); 228 | } 229 | 230 | + if (ep->recv_cq) { 231 | + ibgda_destroy_cq(ep->recv_cq); 232 | + } 233 | + 234 | if (ep->ah) { 235 | ftable.destroy_ah(ep->ah); 236 | } 237 | @@ -2318,7 +2341,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda 238 | dev_qp->qpn = ep->qpn; 239 | 240 | assert(ep->wq_mobject->has_gpu_mapping); 241 | - dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset); 242 | + dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset); 243 | 244 | if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) { 245 | assert(ep->dbr_mobject->has_gpu_mapping); 246 | @@ -2330,6 +2353,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda 247 | } 248 | 249 | dev_qp->tx_wq.nwqes = ep->sq_cnt; 250 | + if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) { 251 | + dev_qp->rx_wq.nwqes = ep->rq_cnt; 252 | + dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset); 253 | + dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset); 254 | + dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr; 255 | + } 256 | 257 | ibuf_dci_start = (uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr; 258 | ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps); 259 | @@ -2379,6 +2408,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 260 | nvshmemi_ibgda_device_cq_t *cq_d = NULL; 261 | nvshmemi_ibgda_device_cq_t *cq_h = NULL; 262 | 263 | + nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL; 264 | + nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL; 265 | + 266 | uint8_t *qp_group_switches_d = NULL; 267 | 268 | const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars); 269 | @@ -2386,6 +2418,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 270 | const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx); 271 | const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head); 272 | const size_t wqe_t_offset = 
offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head); 273 | + const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head); 274 | 275 | nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID; 276 | nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID; 277 | @@ -2421,7 +2454,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 278 | num_dct_handles += device->dct.num_eps * n_pes; 279 | num_dci_handles += device->dci.num_eps; 280 | num_rc_handles += device->rc.num_eps_per_pe * n_pes; 281 | - num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1)); 282 | + num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2); 283 | num_shared_dci_handles += device->dci.num_shared_eps; 284 | } 285 | assert(num_dci_handles - num_shared_dci_handles >= 0); 286 | @@ -2456,6 +2489,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 287 | for (int i = 0; i < num_cq_handles; i++) { 288 | nvshmemi_init_ibgda_device_cq(cq_h[i]); 289 | } 290 | + 291 | + recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h)); 292 | + NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "recv_cq calloc err."); 293 | + nvshmemi_init_ibgda_device_cq(recv_cq_h[0]); 294 | /* allocate host memory for dct, rc, cq, dci end */ 295 | 296 | /* allocate device memory for dct, rc, cq, dci start */ 297 | @@ -2559,6 +2596,14 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 298 | } 299 | 300 | ++cq_idx; 301 | + 302 | + rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx]; 303 | + 304 | + ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq); 305 | + cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset); 306 | + cq_h[cq_idx].qpn = rc_h[arr_idx].qpn; 307 | + cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type; 308 | + ++cq_idx; 309 | } 310 | } 311 | } 312 | -- 313 | 2.25.1 314 | 315 | 316 | From af479f9f23103d4a1579fae38676d6b3022df887 Mon Sep 17 00:00:00 2001 317 | From: Shangyan Zhou <sy.zhou@deepseek.com> 318 | Date: Sat, 8 Feb 2025 18:02:39 +0800 319 | Subject: [PATCH 3/4] Maintain recv queue's cons_idx. 
320 | 321 | --- 322 | src/include/device_host_transport/nvshmem_common_ibgda.h | 5 +++-- 323 | src/modules/transport/ibgda/ibgda.cpp | 6 ++++-- 324 | 2 files changed, 7 insertions(+), 4 deletions(-) 325 | 326 | diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h 327 | index 1be3dec..ea1e284 100644 328 | --- a/src/include/device_host_transport/nvshmem_common_ibgda.h 329 | +++ b/src/include/device_host_transport/nvshmem_common_ibgda.h 330 | @@ -170,6 +170,7 @@ typedef struct { 331 | } tx_wq; 332 | struct { 333 | uint64_t resv_head; // last reserved wqe idx + 1 334 | + uint64_t cons_idx; // polled wqe idx + 1 (consumer index + 1) 335 | } rx_wq; 336 | struct { 337 | uint64_t head; 338 | @@ -177,7 +178,7 @@ typedef struct { 339 | } ibuf; 340 | char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING]; 341 | } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1; 342 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104, 343 | - "ibgda_device_qp_management_v1 must be 104 bytes."); 344 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112, 345 | + "ibgda_device_qp_management_v1 must be 112 bytes."); 346 | 347 | typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t; 348 | @@ -214,7 +215,7 @@ typedef struct nvshmemi_ibgda_device_qp { 349 | } rx_wq; 350 | nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables 351 | } nvshmemi_ibgda_device_qp_v1; 352 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes."); 353 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, "ibgda_device_qp_v1 must be 256 bytes."); 354 | 355 | typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t; 356 | 357 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp 358 | index e0b2d5c..bc339c5 100644 359 | --- a/src/modules/transport/ibgda/ibgda.cpp 360 | +++ b/src/modules/transport/ibgda/ibgda.cpp 361 | @@ -1067,7 +1067,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) { 362 | ibgda_host_mem_free(mobject); 363 | } 364 | 365 | -static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) { 366 | +static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) { 367 | int status = 0; 368 | 369 | struct ibgda_cq *gcq = NULL; 370 | @@ -1118,7 +1118,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) 371 | cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context); 372 | DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE); 373 | DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B); 374 | - DEVX_SET(cqc, cq_context, cc, 0x1); // Use collapsed CQ 375 | + DEVX_SET(cqc, cq_context, cc, cc); // Use collapsed CQ 376 | DEVX_SET(cqc, cq_context, oi, 0x1); // Allow overrun 377 | DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id); 378 | DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe)); 379 | @@ -2419,6 +2419,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 380 | const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head); 381 | const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head); 382 | const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head); 383 | + const size_t rx_cons_offset = 
offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx); 384 | 385 | nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID; 386 | nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID; 387 | @@ -2601,6 +2602,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) { 388 | 389 | ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq); 390 | cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset); 391 | + cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset); 392 | cq_h[cq_idx].qpn = rc_h[arr_idx].qpn; 393 | cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type; 394 | ++cq_idx; 395 | -- 396 | 2.25.1 397 | 398 | 399 | From e0ba3fa21b4b633b481c6684c3ad04f2670c8df4 Mon Sep 17 00:00:00 2001 400 | From: Shangyan Zhou <sy.zhou@deepseek.com> 401 | Date: Tue, 11 Feb 2025 11:00:57 +0800 402 | Subject: [PATCH 4/4] Init rx_wq counters. 403 | 404 | --- 405 | src/include/device_host_transport/nvshmem_common_ibgda.h | 2 ++ 406 | 1 file changed, 2 insertions(+) 407 | 408 | diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h 409 | index ea1e284..e6640d6 100644 410 | --- a/src/include/device_host_transport/nvshmem_common_ibgda.h 411 | +++ b/src/include/device_host_transport/nvshmem_common_ibgda.h 412 | @@ -46,6 +46,8 @@ 413 | qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 414 | qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 415 | qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 416 | + qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 417 | + qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 418 | qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 419 | qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \ 420 | } while (0); 421 | -- 422 | 2.25.1 423 | 424 | diff --git a/src/modules/transport/common/transport_ib_common.cpp b/src/modules/transport/common/transport_ib_common.cpp 425 | index c89f408..f99018a 100644 426 | --- a/src/modules/transport/common/transport_ib_common.cpp 427 | +++ b/src/modules/transport/common/transport_ib_common.cpp 428 | @@ -26,6 +26,9 @@ int nvshmemt_ib_common_nv_peer_mem_available() { 429 | if (access("/sys/kernel/mm/memory_peers/nvidia-peermem/version", F_OK) == 0) { 430 | return NVSHMEMX_SUCCESS; 431 | } 432 | + if (access("/sys/module/nvidia_peermem/version", F_OK) == 0) { 433 | + return NVSHMEMX_SUCCESS; 434 | + } 435 | 436 | return NVSHMEMX_ERROR_INTERNAL; 437 | } 438 | 439 | 440 | From 099f608fcd9a1d34c866ad75d0af5d02d2020374 Mon Sep 17 00:00:00 2001 441 | From: Kaichao You <youkaichao@gmail.com> 442 | Date: Tue, 10 Jun 2025 00:35:03 -0700 443 | Subject: [PATCH] remove gdrcopy dependency 444 | 445 | --- 446 | src/modules/transport/ibgda/ibgda.cpp | 6 ++++++ 447 | 1 file changed, 6 insertions(+) 448 | 449 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp 450 | index ef325cd..16ee09c 100644 451 | --- a/src/modules/transport/ibgda/ibgda.cpp 452 | +++ b/src/modules/transport/ibgda/ibgda.cpp 453 | @@ -406,6 +406,7 @@ static size_t ibgda_get_host_page_size() { 454 | return host_page_size; 455 | } 456 | 457 | +#ifdef NVSHMEM_USE_GDRCOPY 458 | int nvshmemt_ibgda_progress(nvshmem_transport_t t) { 459 | nvshmemt_ibgda_state_t *ibgda_state = (nvshmemt_ibgda_state_t *)t->state; 460 | int n_devs_selected = ibgda_state->n_devs_selected; 461 | @@ -459,6 +460,11 @@ int 
nvshmemt_ibgda_progress(nvshmem_transport_t t) { 462 | } 463 | return 0; 464 | } 465 | +#else 466 | +int nvshmemt_ibgda_progress(nvshmem_transport_t t) { 467 | + return NVSHMEMX_ERROR_NOT_SUPPORTED; 468 | +} 469 | +#endif 470 | 471 | int nvshmemt_ibgda_show_info(struct nvshmem_transport *transport, int style) { 472 | NVSHMEMI_ERROR_PRINT("ibgda show info not implemented"); 473 | -- 474 | 2.34.1 475 | --------------------------------------------------------------------------------