├── .gitignore
├── LICENSE
├── README.md
├── csrc
    ├── CMakeLists.txt
    ├── config.hpp
    ├── deep_ep.cpp
    ├── deep_ep.hpp
    ├── event.hpp
    └── kernels
    │   ├── CMakeLists.txt
    │   ├── api.cuh
    │   ├── buffer.cuh
    │   ├── configs.cuh
    │   ├── exception.cuh
    │   ├── ibgda_device.cuh
    │   ├── internode.cu
    │   ├── internode_ll.cu
    │   ├── intranode.cu
    │   ├── launch.cuh
    │   ├── layout.cu
    │   ├── runtime.cu
    │   └── utils.cuh
├── deep_ep
    ├── __init__.py
    ├── buffer.py
    └── utils.py
├── figures
    ├── low-latency.png
    └── normal.png
├── install.sh
├── setup.py
├── tests
    ├── test_internode.py
    ├── test_intranode.py
    ├── test_low_latency.py
    └── utils.py
└── third-party
    ├── README.md
    └── nvshmem.patch


/.gitignore:
--------------------------------------------------------------------------------
1 | compile_commands.json
2 | .idea
3 | .DS_Store
4 | *.pyc
5 | build/
6 | .cache/
7 | .vscode/
8 | */cmake-build-*/
9 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2025 DeepSeek
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # DeepEP
  2 | 
  3 | DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
  4 | 
  5 | To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control.
  6 | 
  7 | For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource.
  8 | 
  9 | Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.
 10 | 
 11 | ## Performance
 12 | 
 13 | ### Normal kernels with NVLink and RDMA forwarding
 14 | 
 15 | We test normal kernels on H800 GPUs (~160 GB/s NVLink maximum bandwidth), each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth), following the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).
 16 | 
 17 | |   Type    | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
 18 | |:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
 19 | | Intranode |      8       |  153 GB/s (NVLink)   |      8      |  158 GB/s (NVLink)   |
 20 | | Internode |      16      |    43 GB/s (RDMA)    |     16      |    43 GB/s (RDMA)    |
 21 | | Internode |      32      |    58 GB/s (RDMA)    |     32      |    57 GB/s (RDMA)    |
 22 | | Internode |      64      |    51 GB/s (RDMA)    |     64      |    50 GB/s (RDMA)    |
 23 | 
 24 | **News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
 25 | 
 26 | ### Low-latency kernels with pure RDMA
 27 | 
 28 | We test low-latency kernels on H800 GPUs, each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth), following a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).
 29 | 
 30 | | Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
 31 | |:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
 32 | |      8       |  77 us  |    98 GB/s     |      8      | 114 us  |    127 GB/s    |
 33 | |      16      | 118 us  |    63 GB/s     |     16      | 195 us  |    74 GB/s     |
 34 | |      32      | 155 us  |    48 GB/s     |     32      | 273 us  |    53 GB/s     |
 35 | |      64      | 173 us  |    43 GB/s     |     64      | 314 us  |    46 GB/s     |
 36 | |     128      | 192 us  |    39 GB/s     |     128     | 369 us  |    39 GB/s     |
 37 | |     256      | 194 us  |    39 GB/s     |     256     | 360 us  |    40 GB/s     |
 38 | 
 39 | **News (2025.06.05)**: low-latency kernels now leverage NVLink as much as possible, see [#173](https://github.com/deepseek-ai/DeepEP/pull/173) for more details. Thanks for the contribution!
 40 | 
 41 | ## Quick start
 42 | 
 43 | ### Requirements
 44 | 
 45 | - Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support
 46 | - Python 3.8 and above
 47 | - CUDA version
 48 |   - CUDA 11.0 and above for SM80 GPUs
 49 |   - CUDA 12.3 and above for SM90 GPUs
 50 | - PyTorch 2.1 and above
 51 | - NVLink for intranode communication
 52 | - RDMA network for internode communication
 53 | 
 54 | ### Download and install NVSHMEM dependency
 55 | 
 56 | DeepEP also depends on NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.
 57 | 
 58 | ### Development
 59 | 
 60 | ```bash
 61 | # Build and make symbolic links for SO files
 62 | NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
 63 | # You may modify the specific SO names according to your own platform
 64 | ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so
 65 | 
 66 | # Run test cases
 67 | # NOTES: you may modify the `init_dist` function in `tests/utils.py`
 68 | # according to your own cluster settings, and launch across multiple nodes
 69 | python tests/test_intranode.py
 70 | python tests/test_internode.py
 71 | python tests/test_low_latency.py
 72 | ```
 73 | 
 74 | ### Installation
 75 | 
 76 | ```bash
 77 | NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
 78 | ```
 79 | 
 80 | #### Installation environment variables
 81 | 
 82 | - `NVSHMEM_DIR`: the path to the installed NVSHMEM directory; if not specified, all internode and low-latency features are disabled
 83 | - `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; setting it to 1 is required for non-SM90 devices or CUDA 11
 84 | - `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"`
 85 | - `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details
 86 | 
 87 | Then, import `deep_ep` in your Python project, and enjoy!
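
A minimal post-install sanity check (a sketch, not part of the official test suite) is to import the package and set the static SM count used by the normal kernels:

```python
import torch

from deep_ep import Buffer

# Hypothetical smoke test: the extension loads and CUDA is visible
assert torch.cuda.is_available()

# Static SM control for the normal (high-throughput) kernels
Buffer.set_num_sms(24)
print('deep_ep is ready')
```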
 88 | 
 89 | ## Network configurations
 90 | 
 91 | DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.
 92 | 
 93 | ### Traffic isolation
 94 | 
 95 | Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).
 96 | 
 97 | To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:
 98 | 
 99 | - workloads using normal kernels
100 | - workloads using low-latency kernels
101 | - other workloads
102 | 
103 | For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.
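
For example, a sketch of pinning DeepEP's traffic to a dedicated virtual lane from Python (the lane number `1` is illustrative only; the variable must be set before the first `Buffer` is constructed, since NVSHMEM reads it at initialization):

```python
import os

# Illustrative: route DeepEP's NVSHMEM traffic over virtual lane 1
# Set this before creating the first deep_ep.Buffer, i.e. before NVSHMEM initializes
os.environ['NVSHMEM_IB_SL'] = '1'
```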
104 | 
105 | ### Adaptive routing
106 | 
107 | Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
108 | 
109 | - enable adaptive routing in environments with heavy network loads
110 | - use static routing in environments with light network loads
111 | 
112 | ### Congestion control
113 | 
114 | Congestion control is disabled as we have not observed significant congestion in our production environment.
115 | 
116 | ## Interfaces and examples
117 | 
118 | ### Example use in model training or inference prefilling
119 | 
120 | The normal kernels can be used in model training or the inference prefilling phase (without the backward part), as the example code below shows.
121 | 
122 | ```python
123 | import torch
124 | import torch.distributed as dist
125 | from typing import List, Tuple, Optional, Union
126 | 
127 | from deep_ep import Buffer, EventOverlap
128 | 
129 | # Communication buffer (will allocate at runtime)
130 | _buffer: Optional[Buffer] = None
131 | 
132 | # Set the number of SMs to use
133 | # NOTES: this is a static variable
134 | Buffer.set_num_sms(24)
135 | 
136 | 
137 | # You may call this function at the framework initialization
138 | def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
139 |     global _buffer
140 |     
141 |     # NOTES: you may also replace `get_*_config` with your auto-tuned results from running all the tests
142 |     num_nvl_bytes, num_rdma_bytes = 0, 0
143 |     for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
144 |         num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
145 |         num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
146 | 
147 |     # Allocate a buffer if one does not exist or the existing one is too small
148 |     if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
149 |         _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
150 |     return _buffer
151 | 
152 | 
153 | def get_hidden_bytes(x: torch.Tensor) -> int:
154 |     t = x[0] if isinstance(x, tuple) else x
155 |     return t.size(1) * max(t.element_size(), 2)
156 | 
157 | 
158 | def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
159 |                      topk_idx: torch.Tensor, topk_weights: torch.Tensor,
160 |                      num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
161 |         Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
162 |     # NOTES: an optional `previous_event` is a captured CUDA event that you want the dispatch kernel
163 |     # to depend on; it may be useful for communication-computation overlap. For more information, please
164 |     # refer to the docs of `Buffer.dispatch`
165 |     global _buffer
166 | 
167 |     # Calculate layout before actual dispatch
168 |     num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
169 |         _buffer.get_dispatch_layout(topk_idx, num_experts,
170 |                                     previous_event=previous_event, async_finish=True,
171 |                                     allocate_on_comm_stream=previous_event is not None)
172 |     # Do MoE dispatch
173 |     # NOTES: the CPU will wait for the GPU's signal to arrive, so this is not compatible with CUDA graphs,
174 |     # unless you specify `num_worst_tokens` (an intranode-only flag)
175 |     # For more advanced usages, please refer to the docs of the `dispatch` function
176 |     recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
177 |         _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
178 |                          num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
179 |                          is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
180 |                          previous_event=previous_event, async_finish=True,
181 |                          allocate_on_comm_stream=True)
182 |     # For event management, please refer to the docs of the `EventOverlap` class
183 |     return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
184 | 
185 | 
186 | def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
187 |         Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
188 |     global _buffer
189 | 
190 |     # The backward process of MoE dispatch is actually a combine
191 |     # For more advanced usages, please refer to the docs of the `combine` function
192 |     combined_grad_x, combined_grad_recv_topk_weights, event = \
193 |         _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)
194 | 
195 |     # For event management, please refer to the docs of the `EventOverlap` class
196 |     return combined_grad_x, combined_grad_recv_topk_weights, event
197 | 
198 | 
199 | def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
200 |         Tuple[torch.Tensor, EventOverlap]:
201 |     global _buffer
202 | 
203 |     # Do MoE combine
204 |     # For more advanced usages, please refer to the docs of the `combine` function
205 |     combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
206 |                                            allocate_on_comm_stream=previous_event is not None)
207 | 
208 |     # For event management, please refer to the docs of the `EventOverlap` class
209 |     return combined_x, event
210 | 
211 | 
212 | def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
213 |                      handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
214 |         Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
215 |     global _buffer
216 | 
217 |     # The backward process of MoE combine is actually a dispatch
218 |     # For more advanced usages, please refer to the docs of the `dispatch` function
219 |     grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
220 |                                                  previous_event=previous_event,
221 |                                                  allocate_on_comm_stream=previous_event is not None)
222 | 
223 |     # For event management, please refer to the docs of the `EventOverlap` class
224 |     return grad_x, event
225 | ```
226 | 
227 | Moreover, inside the dispatch function, we may not know how many tokens the current rank will receive, so an implicit CPU wait for the GPU's received-count signal is involved, as the following figure shows.
228 | 
229 | ![normal](figures/normal.png)
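
Putting the helpers above together, a forward pass of one MoE layer could look like the sketch below. This is only an illustration: `gate` and `expert_forward` stand for your own router and expert computation and are not part of DeepEP, and the `current_stream_wait()` synchronization on the returned event is assumed to follow the `EventOverlap` docs.

```python
def moe_forward(x: torch.Tensor, group: dist.ProcessGroup, num_experts: int) -> torch.Tensor:
    # Router: per-token expert indices and weights (your own gating implementation)
    topk_idx, topk_weights = gate(x)

    # Make sure the communication buffer is large enough for this tensor
    get_buffer(group, get_hidden_bytes(x))

    # Dispatch tokens to the ranks owning their experts
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
        dispatch_forward(x, topk_idx, topk_weights, num_experts)
    # Wait for the asynchronous dispatch before consuming its outputs on the current stream
    event.current_stream_wait()

    # Expert computation on the received tokens (your own implementation)
    expert_out = expert_forward(recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list)

    # Combine the expert outputs back onto their original ranks
    combined_x, event = combine_forward(expert_out, handle)
    event.current_stream_wait()
    return combined_x
```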
230 | 
231 | ### Example use in inference decoding
232 | 
233 | The low-latency kernels can be used in the inference decoding phase, as the example code below shows.
234 | 
235 | ```python
236 | import torch
237 | import torch.distributed as dist
238 | from typing import Tuple, Optional
239 | 
240 | from deep_ep import Buffer
241 | 
242 | # Communication buffer (will allocate at runtime)
243 | # NOTES: there is no SM control API for the low-latency kernels
244 | _buffer: Optional[Buffer] = None
245 | 
246 | 
247 | # You may call this function at the framework initialization
248 | def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
249 |     # NOTES: the low-latency mode consumes much more memory than the normal mode,
250 |     # so we recommend keeping `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) below 256
251 |     global _buffer
252 |     num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)
253 | 
254 |     # Allocate a buffer if one does not exist or the existing one is too small
255 |     if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
256 |         # NOTES: for the best performance, the QP number **must** be equal to the number of local experts
257 |         assert num_experts % group.size() == 0
258 |         _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
259 |     return _buffer
260 | 
261 | 
262 | def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
263 |     global _buffer
264 | 
265 |     # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
266 |     recv_hidden_states, recv_expert_count, handle, event, hook = \
267 |         _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
268 |                                      async_finish=False, return_recv_hook=True)
269 | 
270 |     # NOTES: the actual tensors will not be received until you call `hook()`,
271 |     # which is useful for double-batch overlapping, **without any SM occupation**
272 |     # If you don't want to overlap, please set `return_recv_hook=False`
273 |     # Later, you can use our GEMM library to do the computation with this specific format
274 |     return recv_hidden_states, recv_expert_count, handle, event, hook
275 | 
276 | 
277 | def low_latency_combine(hidden_states: torch.Tensor,
278 |                         topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
279 |     global _buffer
280 | 
281 |     # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
282 |     combined_hidden_states, event_overlap, hook = \
283 |         _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
284 |                                     async_finish=False, return_recv_hook=True)
285 | 
286 |     # NOTES: the same behavior as described in the dispatch kernel
287 |     return combined_hidden_states, event_overlap, hook
288 | ```
289 | 
290 | For two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic happens in the background without consuming any GPU SMs from the computation part. Note that the overlapped stages can be adjusted, i.e., the four stages of attention/dispatch/MoE/combine may not have exactly the same execution time, so you may adjust the stage settings according to your workload.
291 | 
292 | ![low-latency](figures/low-latency.png)
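
A hypothetical sketch of this two-micro-batch pattern using the receive hooks from the example above (`attention` and `experts` are placeholders for your own compute kernels, and the stage boundaries are illustrative only):

```python
def overlapped_decode_step(x_0, x_1, topk_idx_0, topk_weights_0,
                           num_max_dispatch_tokens_per_rank: int, num_experts: int):
    # Micro-batch 0: issue the dispatch; with `return_recv_hook=True`, the RDMA traffic
    # proceeds in the background without occupying any SM
    recv_0, count_0, handle_0, _, hook_0 = \
        low_latency_dispatch(x_0, topk_idx_0, num_max_dispatch_tokens_per_rank, num_experts)

    # Overlap: run attention for micro-batch 1 while micro-batch 0 is still in flight
    x_1 = attention(x_1)

    # Only now make sure micro-batch 0 has actually landed, then run its experts
    hook_0()
    out_0 = experts(recv_0, count_0)

    # Issue the combine for micro-batch 0; its traffic again overlaps with later compute,
    # and `combine_hook_0()` must be called before the combined result is consumed
    combined_0, _, combine_hook_0 = \
        low_latency_combine(out_0, topk_idx_0, topk_weights_0, handle_0)

    # ... continue symmetrically with micro-batch 1
    return combined_0, combine_hook_0, x_1
```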
293 | 
294 | ## Roadmap
295 | 
296 | - [x] AR support
297 | - [x] Refactor low-latency mode AR code
298 | - [x] A100 support (intranode only)
299 | - [x] Support BF16 for the low-latency dispatch kernel
300 | - [x] Support NVLink protocol for intranode low-latency kernels
301 | - [ ] TMA copy instead of LD/ST
302 |   - [x] Intranode kernels
303 |   - [ ] Internode kernels
304 |   - [ ] Low-latency kernels
305 | - [ ] SM-free kernels and refactors
306 | - [ ] Fully remove undefined-behavior PTX instructions
307 | 
308 | ## Notices
309 | 
310 | #### Easier potential overall design
311 | 
312 | The current DeepEP implementation uses queues for communication buffers, which saves memory but introduces complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39.
313 | 
314 | #### Undefined-behavior PTX usage
315 | 
316 | - For extreme performance, we discovered and use an undefined-behavior PTX instruction: the read-only PTX `ld.global.nc.L1::no_allocate.L2::256B` to **read volatile data**. The PTX modifier `.nc` indicates that a non-coherent cache is used. Correctness is nevertheless tested to hold with `.L1::no_allocate` on Hopper architectures, and performance is much better. Our guess at the reason: the non-coherent cache is unified with L1, and the L1 modifier is not just a hint but a strong option, so correctness is guaranteed because no dirty data stays in L1.
317 | - Initially, because NVCC could not automatically unroll volatile read PTX, we tried using `__ldg` (i.e., `ld.nc`). Even compared to manually unrolled volatile reads, it was significantly faster (likely due to additional compiler optimizations). However, the results could be incorrect or dirty. After consulting the PTX documentation, we discovered that L1 and non-coherent cache are unified on Hopper architectures. We speculated that `.L1::no_allocate` might resolve the issue, leading to this discovery.
318 | - If you find kernels not working on some other platforms, you may set the environment variable `DISABLE_AGGRESSIVE_PTX_INSTRS=1` when running `setup.py` to disable this, or file an issue.
319 | 
320 | #### Auto-tuning on your cluster
321 | 
322 | For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized for DeepSeek's internal cluster.
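
As a sketch of plugging auto-tuned results into the `get_buffer` helper from the example above, replacing `Buffer.get_dispatch_config(...)` and `Buffer.get_combine_config(...)`: this assumes the C++ `Config` from `csrc/config.hpp` is exposed to Python by the compiled extension (e.g. as `deep_ep_cpp.Config`), and the numbers below are placeholders rather than recommendations.

```python
import deep_ep_cpp  # assumption: the compiled extension exposes `Config`

# Config(num_sms,
#        num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
#        num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens)
tuned_config = deep_ep_cpp.Config(24, 6, 256, 6, 128)  # placeholder values from your own sweep

# Use the tuned config to size the communication buffers instead of the defaults
num_nvl_bytes = tuned_config.get_nvl_buffer_size_hint(hidden_bytes, group.size())
num_rdma_bytes = tuned_config.get_rdma_buffer_size_hint(hidden_bytes, group.size())
```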
323 | 
324 | ## License
325 | 
326 | This code repository is released under [the MIT License](LICENSE), except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to the [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).
327 | 
328 | ## Community Forks
329 | 
330 | - [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-QP solution and dual-port NIC support in IBRC transport
331 | 
332 | ## Citation
333 | 
334 | If you use this codebase or otherwise find our work valuable, please cite:
335 | 
336 | ```bibtex
337 | @misc{deepep2025,
338 |       title={DeepEP: an efficient expert-parallel communication library},
339 |       author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},
340 |       year={2025},
341 |       publisher = {GitHub},
342 |       howpublished = {\url{https://github.com/deepseek-ai/DeepEP}},
343 | }
344 | ```
345 | 


--------------------------------------------------------------------------------
/csrc/CMakeLists.txt:
--------------------------------------------------------------------------------
 1 | # NOTES: this CMake is only for debugging; for setup, please use Torch extension
 2 | cmake_minimum_required(VERSION 3.10)
 3 | project(deep_ep LANGUAGES CUDA CXX)
 4 | set(CMAKE_VERBOSE_MAKEFILE ON)
 5 | 
 6 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
 8 | set(CUDA_SEPARABLE_COMPILATION ON)
 9 | list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
10 | list(APPEND CUDA_NVCC_FLAGS "-O3")
11 | list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
12 | 
13 | set(USE_SYSTEM_NVTX on)
14 | set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
15 | set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
16 | 
17 | find_package(CUDAToolkit REQUIRED)
18 | find_package(pybind11 REQUIRED)
19 | find_package(Torch REQUIRED)
20 | find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)
21 | 
22 | add_library(nvshmem ALIAS nvshmem::nvshmem)
23 | add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
24 | add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
25 | 
26 | set(CMAKE_CXX_STANDARD 17)
27 | set(CMAKE_CUDA_STANDARD 17)
28 | 
29 | include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
30 | link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
31 | 
32 | add_subdirectory(kernels)
33 | 
34 | # Link CPP and CUDA together
35 | pybind11_add_module(deep_ep_cpp deep_ep.cpp)
36 | target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
37 | 


--------------------------------------------------------------------------------
/csrc/config.hpp:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | 
  3 | #include "kernels/api.cuh"
  4 | #include "kernels/exception.cuh"
  5 | 
  6 | namespace deep_ep {
  7 | 
  8 | template <typename dtype_t>
  9 | dtype_t ceil_div(dtype_t a, dtype_t b) {
 10 |     return (a + b - 1) / b;
 11 | }
 12 | 
 13 | template <typename dtype_t>
 14 | dtype_t align(dtype_t a, dtype_t b) {
 15 |     return ceil_div<dtype_t>(a, b) * b;
 16 | }
 17 | 
 18 | struct Config {
 19 |     int num_sms;
 20 |     int num_max_nvl_chunked_send_tokens;
 21 |     int num_max_nvl_chunked_recv_tokens;
 22 |     int num_max_rdma_chunked_send_tokens;
 23 |     int num_max_rdma_chunked_recv_tokens;
 24 | 
 25 |     Config(int num_sms,
 26 |            int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
 27 |            int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
 28 |             num_sms(num_sms),
 29 |             num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
 30 |             num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
 31 |             num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
 32 |             num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
 33 |         EP_HOST_ASSERT(num_sms >= 0);
 34 |         EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
 35 |         EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
 36 |         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
 37 | 
 38 |         // Round up the RDMA buffer size to a multiple of the send chunk size
 39 |         this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
 40 |         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
 41 |         // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push
 42 |         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
 43 |     }
 44 | 
 45 |     size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
 46 |         // Below are some assumptions
 47 |         // TODO: add assertions
 48 |         constexpr int kNumMaxTopK = 128;
 49 |         constexpr int kNumMaxScales = 128;
 50 |         EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
 51 |         EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
 52 |         const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
 53 |         const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
 54 |         const int num_channels = num_sms / 2;
 55 | 
 56 |         size_t num_bytes = 0;
 57 |         num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
 58 |         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
 59 | #ifndef DISABLE_NVSHMEM
 60 |         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
 61 | #endif
 62 |         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
 63 |         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
 64 |         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
 65 |         num_bytes = ((num_bytes + 127) / 128) * 128;
 66 |         return num_bytes;
 67 |     }
 68 | 
 69 |     size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
 70 | #ifndef DISABLE_NVSHMEM
 71 |         // Legacy mode
 72 |         if (num_ranks <= NUM_MAX_NVL_PEERS)
 73 |             return 0;
 74 | 
 75 |         // Below are some assumptions
 76 |         // TODO: add assertions
 77 |         constexpr int kNumMaxTopK = 128;
 78 |         constexpr int kNumMaxScales = 128;
 79 |         EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
 80 |         EP_HOST_ASSERT(num_sms % 2 == 0);
 81 |         const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
 82 |         const int num_channels = num_sms / 2;
 83 | 
 84 |         size_t num_bytes = 0;
 85 |         num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
 86 |         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
 87 |         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
 88 |         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
 89 |         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
 90 |         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
 91 |         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
 92 |         num_bytes = ((num_bytes + 127) / 128) * 128;
 93 |         return num_bytes;
 94 | #else
 95 |         EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
 96 | #endif
 97 |     }
 98 | };
 99 | 
100 | struct LowLatencyBuffer {
101 |     int num_clean_int = 0;
102 | 
103 |     void* dispatch_rdma_send_buffer = nullptr;
104 |     void* dispatch_rdma_recv_data_buffer = nullptr;
105 |     int* dispatch_rdma_recv_count_buffer = nullptr;
106 | 
107 |     void* combine_rdma_send_buffer = nullptr;
108 |     void* combine_rdma_recv_data_buffer = nullptr;
109 |     int* combine_rdma_recv_flag_buffer = nullptr;
110 | 
111 |     void* combine_rdma_send_buffer_data_start = nullptr;
112 |     size_t num_bytes_per_combine_msg = 0;
113 | 
114 |     std::pair<int*, int> clean_meta() {
115 |         EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
116 |         return {dispatch_rdma_recv_count_buffer, num_clean_int};
117 |     }
118 | };
119 | 
120 | struct LowLatencyLayout {
121 |     size_t total_bytes = 0;
122 |     LowLatencyBuffer buffers[2];
123 | 
124 |     template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
125 |     out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
126 |         return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
127 |     }
128 | 
129 |     LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
130 |         const int num_scales = hidden / 128;
131 | 
132 |         // Dispatch and combine layout:
133 |         //  - 2 symmetric odd/even send buffer
134 |         //  - 2 symmetric odd/even receive buffers
135 |         //  - 2 symmetric odd/even signaling buffers
136 | 
137 |         // Message sizes
138 |         // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
139 |         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
140 |         size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
141 |         size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
142 | 
143 |         // Send buffer
144 |         size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
145 |         size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
146 |         size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
147 |         EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
148 |         total_bytes += send_buffer_bytes * 2;
149 | 
150 |         // Symmetric receive buffers
151 |         // TODO: optimize memory usages
152 |         size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
153 |         size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
154 |         size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
155 |         EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
156 |         total_bytes += recv_buffer_bytes * 2;
157 | 
158 |         // Symmetric signaling buffers
159 |         size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
160 |         size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
161 |         size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
162 |         total_bytes += signaling_buffer_bytes * 2;
163 | 
164 |         // Assign pointers
165 |         // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
166 |         // so you may see some parameters are duplicated
167 |         for (int i = 0; i < 2; ++ i) {
168 |             buffers[i] = {
169 |                 static_cast<int>(signaling_buffer_bytes / sizeof(int)),
170 |                 advance(rdma_buffer, send_buffer_bytes * i),
171 |                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
172 |                 advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
173 |                 advance(rdma_buffer, send_buffer_bytes * i),
174 |                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
175 |                 advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
176 |                 advance(rdma_buffer, send_buffer_bytes * i),
177 |                 num_bytes_per_combine_msg
178 |             };
179 |         }
180 |     }
181 | };
182 | 
183 | size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
184 |     auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
185 |     return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
186 | }
187 | 
188 | } // namespace deep_ep
189 | 


--------------------------------------------------------------------------------
/csrc/deep_ep.hpp:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | 
  3 | // Forcibly disable NDEBUG
  4 | #ifdef NDEBUG
  5 | #undef NDEBUG
  6 | #endif
  7 | 
  8 | #include <pybind11/pybind11.h>
  9 | #include <pybind11/pytypes.h>
 10 | #include <torch/types.h>
 11 | #include <tuple>
 12 | #include <vector>
 13 | 
 14 | #include "config.hpp"
 15 | #include "event.hpp"
 16 | #include "kernels/configs.cuh"
 17 | #include "kernels/exception.cuh"
 18 | 
 19 | #ifndef TORCH_EXTENSION_NAME
 20 | #define TORCH_EXTENSION_NAME deep_ep_cpp
 21 | #endif
 22 | 
 23 | namespace deep_ep {
 24 | 
 25 | struct Buffer {
 26 |     EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
 27 | 
 28 | private:
 29 |     // Low-latency mode buffer
 30 |     int low_latency_buffer_idx = 0;
 31 |     bool low_latency_mode = false;
 32 | 
 33 |     // NVLink Buffer
 34 |     int64_t num_nvl_bytes;
 35 |     void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
 36 |     void** buffer_ptrs_gpu = nullptr;
 37 | 
 38 |     // NVSHMEM Buffer
 39 |     int64_t num_rdma_bytes;
 40 |     void* rdma_buffer_ptr = nullptr;
 41 | 
 42 |     // Device info and communication
 43 |     int device_id;
 44 |     int num_device_sms;
 45 |     int rank, rdma_rank, nvl_rank;
 46 |     int num_ranks, num_rdma_ranks, num_nvl_ranks;
 47 |     cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
 48 | 
 49 |     // Stream for communication
 50 |     at::cuda::CUDAStream comm_stream;
 51 | 
 52 |     // After IPC/NVSHMEM synchronization, this flag will be true
 53 |     bool available = false;
 54 | 
 55 |     // Whether explicit `destroy()` is required.
 56 |     bool explicitly_destroy;
 57 |     // After `destroy()` is called, this flag will be true
 58 |     bool destroyed = false;
 59 | 
 60 |     // Barrier signals
 61 |     int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
 62 |     int** barrier_signal_ptrs_gpu = nullptr;
 63 | 
 64 |     // Workspace
 65 |     void* workspace = nullptr;
 66 | 
 67 |     // Host-side MoE info
 68 |     volatile int* moe_recv_counter = nullptr;
 69 |     int* moe_recv_counter_mapped = nullptr;
 70 | 
 71 |     // Host-side expert-level MoE info
 72 |     volatile int* moe_recv_expert_counter = nullptr;
 73 |     int* moe_recv_expert_counter_mapped = nullptr;
 74 | 
 75 |     // Host-side RDMA-level MoE info
 76 |     volatile int* moe_recv_rdma_counter = nullptr;
 77 |     int* moe_recv_rdma_counter_mapped = nullptr;
 78 | 
 79 | public:
 80 |     Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);
 81 | 
 82 |     ~Buffer() noexcept(false);
 83 | 
 84 |     bool is_available() const;
 85 | 
 86 |     bool is_internode_available() const;
 87 | 
 88 |     int get_num_rdma_ranks() const;
 89 | 
 90 |     int get_rdma_rank() const;
 91 | 
 92 |     int get_root_rdma_rank(bool global) const;
 93 | 
 94 |     int get_local_device_id() const;
 95 | 
 96 |     pybind11::bytearray get_local_ipc_handle() const;
 97 | 
 98 |     pybind11::bytearray get_local_nvshmem_unique_id() const;
 99 | 
100 |     torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
101 | 
102 |     torch::Stream get_comm_stream() const;
103 | 
104 |     void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
105 | 
106 |     void destroy();
107 | 
108 |     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
109 |     get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
110 |                         bool async, bool allocate_on_comm_stream);
111 | 
112 |     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
113 |     intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
114 |                        const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
115 |                        const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
116 |                        int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
117 |                        int expert_alignment, int num_worst_tokens, const Config& config,
118 |                        std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
119 | 
120 |     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
121 |     intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
122 |                       const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
123 |                       const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
124 |                       const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
125 | 
126 |     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
127 |     internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
128 |                        const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
129 |                        const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
130 |                        const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
131 |                        int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
132 |                        const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
133 |                        const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
134 |                        int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
135 | 
136 |     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
137 |     internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
138 |                       const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
139 |                       const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
140 |                       const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
141 |                       const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
142 |                       const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
143 | 
144 |     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
145 | 
146 |     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
147 |     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
148 |                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
149 |                          int num_max_dispatch_tokens_per_rank, int num_experts,
150 |                          bool use_fp8, bool round_scale, bool use_ue8m0,
151 |                          bool async, bool return_recv_hook);
152 | 
153 |     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
154 |     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
155 |                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
156 |                         int num_max_dispatch_tokens_per_rank, int num_experts,
157 |                         bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
158 |                         const std::optional<torch::Tensor>& out = std::nullopt);
159 | 
160 |     torch::Tensor
161 |     get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
162 | };
163 | 
164 | } // namespace deep_ep
165 | 


--------------------------------------------------------------------------------
/csrc/event.hpp:
--------------------------------------------------------------------------------
 1 | #include <ATen/cuda/CUDAContext.h>
 2 | #include <memory>
 3 | 
 4 | #include "kernels/exception.cuh"
 5 | 
 6 | namespace deep_ep {
 7 | 
 8 | struct EventHandle {
 9 |     std::shared_ptr<torch::Event> event;
10 | 
11 |     EventHandle() {
12 |         event = std::make_shared<torch::Event>(torch::kCUDA);
13 |         event->record(at::cuda::getCurrentCUDAStream());
14 |     }
15 | 
16 |     explicit EventHandle(const at::cuda::CUDAStream& stream) {
17 |         event = std::make_shared<torch::Event>(torch::kCUDA);
18 |         event->record(stream);
19 |     }
20 | 
21 |     EventHandle(const EventHandle& other) = default;
22 | 
23 |     void current_stream_wait() const {
24 |         at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
25 |     }
26 | };
27 | 
28 | torch::Event create_event(const at::cuda::CUDAStream &s) {
29 |     auto event = torch::Event(torch::kCUDA);
30 |     event.record(s);
31 |     return event;
32 | }
33 | 
34 | void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
35 |     EP_HOST_ASSERT(s_0.id() != s_1.id());
36 |     s_0.unwrap().wait(create_event(s_1));
37 | }
38 | 
39 | void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
40 |     s.unwrap().wait(*event.event);
41 | }
42 | 
43 | } // namespace deep_ep
44 | 


--------------------------------------------------------------------------------
/csrc/kernels/CMakeLists.txt:
--------------------------------------------------------------------------------
 1 | function(add_deep_ep_library target_name source_file)
 2 |     add_library(${target_name} STATIC ${source_file})
 3 |     set_target_properties(${target_name} PROPERTIES
 4 |             POSITION_INDEPENDENT_CODE ON
 5 |             CXX_STANDARD_REQUIRED ON
 6 |             CUDA_STANDARD_REQUIRED ON
 7 |             CXX_STANDARD 17
 8 |             CUDA_STANDARD 17
 9 |             CUDA_SEPARABLE_COMPILATION ON
10 |     )
11 |     target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
12 | endfunction()
13 | 
14 | add_deep_ep_library(runtime_cuda runtime.cu)
15 | add_deep_ep_library(layout_cuda layout.cu)
16 | add_deep_ep_library(intranode_cuda intranode.cu)
17 | add_deep_ep_library(internode_cuda internode.cu)
18 | add_deep_ep_library(internode_ll_cuda internode_ll.cu)
19 | 
20 | # Later, we should link all libraries in `EP_CUDA_LIBRARIES`
21 | set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
22 | 


--------------------------------------------------------------------------------
/csrc/kernels/api.cuh:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | 
  3 | #include <vector>
  4 | 
  5 | namespace deep_ep {
  6 | 
  7 | // Intranode runtime
  8 | namespace intranode {
  9 | 
 10 | void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);
 11 | 
 12 | } // namespace intranode
 13 | 
 14 | // Internode runtime
 15 | namespace internode {
 16 | 
 17 | std::vector<uint8_t> get_unique_id();
 18 | 
 19 | int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
 20 | 
 21 | void *alloc(size_t size, size_t alignment);
 22 | 
 23 | void free(void *ptr);
 24 | 
 25 | void barrier();
 26 | 
 27 | void finalize();
 28 | 
 29 | } // namespace internode
 30 | 
 31 | // Layout kernels
 32 | namespace layout {
 33 | 
 34 | void get_dispatch_layout(const int64_t* topk_idx,
 35 |                          int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
 36 |                          int* num_tokens_per_expert, bool* is_token_in_rank,
 37 |                          int num_tokens, int num_topk, int num_ranks, int num_experts,
 38 |                          cudaStream_t stream);
 39 | 
 40 | } // namespace layout
 41 | 
 42 | // Intranode kernels
 43 | namespace intranode {
 44 | 
 45 | void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
 46 |                      const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
 47 |                      int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
 48 |                      int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
 49 |                      void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
 50 |                      cudaStream_t stream, int num_sms);
 51 | 
 52 | void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 53 |                             void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
 54 |                             cudaStream_t stream);
 55 | 
 56 | void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
 57 |               int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
 58 |               const bool* is_token_in_rank, const int* channel_prefix_matrix,
 59 |               int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
 60 |               int scale_token_stride, int scale_hidden_stride,
 61 |               void** buffer_ptrs, int rank, int num_ranks,
 62 |               cudaStream_t stream, int num_sms,
 63 |               int num_max_send_tokens, int num_recv_buffer_tokens);
 64 | 
 65 | void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
 66 |                            int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);
 67 | 
 68 | void combine(cudaDataType_t type,
 69 |              void* recv_x, float* recv_topk_weights,
 70 |              const void* x, const float* topk_weights,
 71 |              const void* bias_0, const void* bias_1,
 72 |              const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
 73 |              int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
 74 |              void** buffer_ptrs, int rank, int num_ranks,
 75 |              cudaStream_t stream, int num_sms,
 76 |              int num_max_send_tokens, int num_recv_buffer_tokens);
 77 | 
 78 | } // namespace intranode
 79 | 
 80 | // Internode kernels
 81 | namespace internode {
 82 | 
 83 | int get_source_meta_bytes();
 84 | 
 85 | void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
 86 |                      const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
 87 |                      const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
 88 |                      const bool* is_token_in_rank, int num_tokens, int num_channels,
 89 |                      int hidden_int4, int num_scales, int num_topk, int expert_alignment,
 90 |                      int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
 91 |                      int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
 92 |                      void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
 93 |                      void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
 94 |                      int** barrier_signal_ptrs, int rank,
 95 |                      cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
 96 |                      bool low_latency_mode);
 97 | 
 98 | void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
 99 |               const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
100 |               int* send_rdma_head, int* send_nvl_head,
101 |               int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
102 |               const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
103 |               const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
104 |               const bool* is_token_in_rank,
105 |               int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
106 |               int scale_token_stride, int scale_hidden_stride,
107 |               void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
108 |               void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
109 |               int rank, int num_ranks, bool is_cached_dispatch,
110 |               cudaStream_t stream, int num_channels, bool low_latency_mode);
111 | 
112 | void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
113 |                    int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
114 |                    const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
115 |                    void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
116 |                    void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
117 |                    int** barrier_signal_ptrs, int rank, cudaStream_t stream,
118 |                    int64_t num_rdma_bytes, int64_t num_nvl_bytes,
119 |                    bool is_cached_dispatch, bool low_latency_mode);
120 | 
121 | void combine(cudaDataType_t type,
122 |              void* combined_x, float* combined_topk_weights,
123 |              const bool* is_combined_token_in_rank,
124 |              const void* x, const float* topk_weights,
125 |              const void* bias_0, const void* bias_1,
126 |              const int* combined_rdma_head, const int* combined_nvl_head,
127 |              const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
128 |              int num_tokens, int num_combined_tokens, int hidden, int num_topk,
129 |              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
130 |              void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
131 |              int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
132 | 
133 | } // namespace internode
134 | 
135 | // Internode low-latency kernels
136 | namespace internode_ll {
137 | 
138 | void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
139 |                               int* clean_1, int num_clean_int_1,
140 |                               cudaStream_t stream);
141 | 
142 | void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
143 |               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
144 |               int* packed_recv_count,
145 |               int* cumulative_local_expert_recv_stats,
146 |               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
147 |               const void* x, const int64_t* topk_idx,
148 |               int* next_clean, int num_next_clean_int,
149 |               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
150 |               int num_topk, int num_experts, int rank, int num_ranks,
151 |               bool use_fp8, bool round_scale, bool use_ue8m0,
152 |               void* workspace, int num_device_sms,
153 |               cudaStream_t stream, int phases);
154 | 
155 | void combine(void* combined_x,
156 |              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
157 |              const void* x, const int64_t* topk_idx, const float* topk_weights,
158 |              const int* src_info, const int64_t* layout_range,
159 |              int* next_clean, int num_next_clean_int,
160 |              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
161 |              int num_topk, int num_experts, int rank, int num_ranks,
162 |              bool use_logfmt,
163 |              void* workspace, int num_device_sms,
164 |              cudaStream_t stream, int phases, bool zero_copy);
165 | 
166 | } // namespace internode_ll
167 | 
168 | } // namespace deep_ep
169 | 
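Editor's note (not part of `api.cuh`): the low-latency `dispatch`/`combine` declarations above take an `int phases` argument. A minimal sketch of how a caller might compose it, assuming `phases` is a bitmask of the `LOW_LATENCY_*` constants defined in `csrc/kernels/configs.cuh`:

    #include "kernels/configs.cuh"

    // Run both phases in a single launch ...
    int both_phases = LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE;

    // ... or split them across two launches (send first, receive later), which is one way
    // the hook-based communication-computation overlap described in the README can be realized.
    int send_only_phase = LOW_LATENCY_SEND_PHASE;
    int recv_only_phase = LOW_LATENCY_RECV_PHASE;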


--------------------------------------------------------------------------------
/csrc/kernels/buffer.cuh:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | 
  3 | #include "configs.cuh"
  4 | #include "exception.cuh"
  5 | 
  6 | namespace deep_ep {
  7 | 
  8 | template <typename dtype_t>
  9 | struct Buffer {
 10 | private:
 11 |     uint8_t* ptr;
 12 | 
 13 | public:
 14 |     int total_bytes;
 15 | 
 16 |     __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
 17 | 
 18 |     __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
 19 |         total_bytes = num_elems * sizeof(dtype_t);
 20 |         ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
 21 |         gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
 22 |     }
 23 | 
 24 |     __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
 25 |         gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
 26 |         return *this;
 27 |     }
 28 | 
 29 |     __device__ __forceinline__ dtype_t* buffer() {
 30 |         return reinterpret_cast<dtype_t*>(ptr);
 31 |     }
 32 | 
 33 |     __device__ __forceinline__ dtype_t& operator[](int idx) {
 34 |         return buffer()[idx];
 35 |     }
 36 | };
 37 | 
 38 | template <typename dtype_t, int kNumRanks = 1>
 39 | struct AsymBuffer {
 40 | private:
 41 |     uint8_t* ptrs[kNumRanks];
 42 |     int num_bytes;
 43 | 
 44 | public:
 45 |     int total_bytes;
 46 | 
 47 |     __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
 48 |                                           int sm_id = 0, int num_sms = 1, int offset = 0) {
 49 |         EP_STATIC_ASSERT(kNumRanks == 1, "");
 50 |         num_bytes = num_elems * sizeof(dtype_t);
 51 | 
 52 |         int per_channel_bytes = num_bytes * num_ranks;
 53 |         total_bytes = per_channel_bytes * num_sms;
 54 |         ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
 55 |         gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
 56 |     }
 57 | 
 58 |     __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
 59 |                                           int sm_id = 0, int num_sms = 1, int offset = 0) {
 60 |         EP_STATIC_ASSERT(kNumRanks > 1, "");
 61 |         num_bytes = num_elems * sizeof(dtype_t);
 62 | 
 63 |         int per_channel_bytes = num_bytes * num_ranks;
 64 |         total_bytes = per_channel_bytes * num_sms;
 65 |         for (int i = 0; i < kNumRanks; ++ i) {
 66 |             ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
 67 |             gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
 68 |         }
 69 |     }
 70 | 
 71 |     __device__ __forceinline__ void advance(int shift) {
 72 |         #pragma unroll
 73 |         for (int i = 0; i < kNumRanks; ++ i)
 74 |             ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
 75 |     }
 76 | 
 77 |     __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
 78 |         gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
 79 |         return *this;
 80 |     }
 81 | 
 82 |     template<int kNumAlsoRanks>
 83 |     __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
 84 |         for (int i = 0; i < kNumAlsoRanks; ++ i)
 85 |             gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
 86 |         return *this;
 87 |     }
 88 | 
 89 |     __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
 90 |         EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
 91 |         return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
 92 |     }
 93 | 
 94 |     __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
 95 |         EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for the multi-rank case");
 96 |         return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
 97 |     }
 98 | };
 99 | 
100 | template <typename dtype_t, bool kDecoupled = true>
101 | struct SymBuffer {
102 | private:
103 |     // NOTES: for non-decoupled case, `recv_ptr` is not used
104 |     uint8_t* send_ptr;
105 |     uint8_t* recv_ptr;
106 |     int num_bytes;
107 | 
108 | public:
109 |     int total_bytes;
110 | 
111 |     __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
112 |                                          int sm_id = 0, int num_sms = 1) {
113 |         num_bytes = num_elems * sizeof(dtype_t);
114 | 
115 |         int per_channel_bytes = num_bytes * num_ranks;
116 |         total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
117 |         send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
118 |         recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
119 |         gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
120 |     }
121 | 
122 |     __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
123 |         EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for the decoupled case");
124 |         return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
125 |     }
126 | 
127 |     __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
128 |         EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for the decoupled case");
129 |         return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
130 |     }
131 | 
132 |     __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
133 |         EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for the non-decoupled case");
134 |         return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
135 |     }
136 | };
137 | 
138 | } // namespace deep_ep
139 | 
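Editor's note (not part of `buffer.cuh`): a minimal device-side sketch of how `Buffer` carves typed sub-buffers out of one raw region and advances the running pointer; the function name and sizes below are illustrative only.

    // Assumes `buffer.cuh` is included.
    __device__ void carve_example(void* raw_region) {
        void* ptr = raw_region;
        // Each constructor consumes `num_elems * sizeof(dtype_t)` bytes and bumps `ptr`,
        // so successive buffers are laid out back to back.
        auto ints   = deep_ep::Buffer<int>(ptr, /*num_elems=*/ 256);
        auto floats = deep_ep::Buffer<float>(ptr, /*num_elems=*/ 128);
        ints[0] = 1;                 // `operator[]` indexes the typed view
        floats.buffer()[0] = 2.0f;   // `buffer()` returns the typed base pointer
    }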


--------------------------------------------------------------------------------
/csrc/kernels/configs.cuh:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | #define NUM_MAX_NVL_PEERS 8
 4 | #define NUM_MAX_RDMA_PEERS 20
 5 | #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
 6 | #define NUM_MAX_LOCAL_EXPERTS 1024
 7 | #define NUM_BUFFER_ALIGNMENT_BYTES 128
 8 | 
 9 | #define FINISHED_SUM_TAG 1024
10 | #define NUM_WAIT_NANOSECONDS 500
11 | 
12 | #ifndef ENABLE_FAST_DEBUG
13 | #define NUM_CPU_TIMEOUT_SECS 100
14 | #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
15 | #else
16 | #define NUM_CPU_TIMEOUT_SECS 10
17 | #define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s
18 | #endif
19 | 
20 | #define LOW_LATENCY_SEND_PHASE 1
21 | #define LOW_LATENCY_RECV_PHASE 2
22 | 
23 | // Make CLion CUDA indexing work
24 | #ifdef __CLION_IDE__
25 | #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
26 | #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
27 | #endif
28 | 
29 | // Remove Torch restrictions
30 | #ifdef __CUDA_NO_HALF_CONVERSIONS__
31 | #undef __CUDA_NO_HALF_CONVERSIONS__
32 | #endif
33 | #ifdef __CUDA_NO_HALF_OPERATORS__
34 | #undef __CUDA_NO_HALF_OPERATORS__
35 | #endif
36 | #ifdef __CUDA_NO_HALF2_OPERATORS__
37 | #undef __CUDA_NO_HALF2_OPERATORS__
38 | #endif
39 | #ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
40 | #undef __CUDA_NO_BFLOAT16_CONVERSIONS__
41 | #endif
42 | #ifdef __CUDA_NO_BFLOAT162_OPERATORS__
43 | #undef __CUDA_NO_BFLOAT162_OPERATORS__
44 | #endif
45 | 
46 | #include <cstdint>
47 | #include <cuda_bf16.h>
48 | #include <cuda_runtime.h>
49 | 
50 | #ifndef DISABLE_SM90_FEATURES
51 | #include <cuda_fp8.h>
52 | #else
53 | // Ampere does not support FP8 features
54 | #define __NV_E4M3 0
55 | #define __NV_E5M2 1
56 | typedef int __nv_fp8_interpretation_t;
57 | typedef int __nv_fp8x4_e4m3;
58 | typedef uint8_t __nv_fp8_storage_t;
59 | #endif
60 | 
61 | #ifndef DISABLE_NVSHMEM
62 | #include <nvshmem.h>
63 | #include <nvshmemx.h>
64 | #include <infiniband/mlx5dv.h>
65 | #include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
66 | #include <device_host_transport/nvshmem_common_ibgda.h>
67 | #endif
68 | 


--------------------------------------------------------------------------------
/csrc/kernels/exception.cuh:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | #include <string>
 4 | #include <exception>
 5 | 
 6 | #include "configs.cuh"
 7 | 
 8 | #ifndef EP_STATIC_ASSERT
 9 | #define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
10 | #endif
11 | 
12 | class EPException: public std::exception {
13 | private:
14 |     std::string message = {};
15 | 
16 | public:
17 |     explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
18 |         message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
19 |     }
20 | 
21 |     const char *what() const noexcept override { return message.c_str(); }
22 | };
23 | 
24 | #ifndef CUDA_CHECK
25 | #define CUDA_CHECK(cmd) \
26 | do { \
27 |     cudaError_t e = (cmd); \
28 |     if (e != cudaSuccess) { \
29 |         throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
30 |     } \
31 | } while (0)
32 | #endif
33 | 
34 | #ifndef EP_HOST_ASSERT
35 | #define EP_HOST_ASSERT(cond) \
36 | do { \
37 |     if (not (cond)) { \
38 |         throw EPException("Assertion", __FILE__, __LINE__, #cond); \
39 |     } \
40 | } while (0)
41 | #endif
42 | 
43 | #ifndef EP_DEVICE_ASSERT
44 | #define EP_DEVICE_ASSERT(cond) \
45 | do { \
46 |     if (not (cond)) { \
47 |         printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
48 |         asm("trap;"); \
49 |     } \
50 | } while (0)
51 | #endif
52 | 
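Editor's note (not part of `exception.cuh`): a small host-side sketch of the intended use of the macros above; the helper name is hypothetical.

    #include "kernels/exception.cuh"

    void checked_alloc(void** ptr, size_t bytes) {
        EP_HOST_ASSERT(bytes > 0);            // throws `EPException` if the condition fails
        CUDA_CHECK(cudaMalloc(ptr, bytes));   // throws `EPException` on any CUDA error
    }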


--------------------------------------------------------------------------------
/csrc/kernels/ibgda_device.cuh:
--------------------------------------------------------------------------------
  1 | // Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)
  2 | // Copyright (c) NVIDIA Corporation.
  3 | // Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).
  4 | // See full license at: https://docs.nvidia.com/nvshmem/api/sla.html
  5 | //
  6 | // Modified from original source:
  7 | //  - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh
  8 | #pragma once
  9 | 
 10 | #include "configs.cuh"
 11 | #include "exception.cuh"
 12 | #include "utils.cuh"
 13 | 
 14 | namespace deep_ep {
 15 | 
 16 | EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
 17 | 
 18 | __device__ static __forceinline__
 19 | uint64_t HtoBE64(uint64_t x) {
 20 |     uint64_t ret;
 21 |     asm("{\n\t"
 22 |         ".reg .b32 ign;\n\t"
 23 |         ".reg .b32 lo;\n\t"
 24 |         ".reg .b32 hi;\n\t"
 25 |         ".reg .b32 new_lo;\n\t"
 26 |         ".reg .b32 new_hi;\n\t"
 27 |         "mov.b64 {lo,hi}, %1;\n\t"
 28 |         "prmt.b32 new_hi, lo, ign, 0x0123;\n\t"
 29 |         "prmt.b32 new_lo, hi, ign, 0x0123;\n\t"
 30 |         "mov.b64 %0, {new_lo,new_hi};\n\t"
 31 |         "}" : "=l"(ret) : "l"(x));
 32 |     return ret;
 33 | }
 34 | 
 35 | __device__ static __forceinline__
 36 | uint32_t HtoBE32(uint32_t x) {
 37 |     uint32_t ret;
 38 |     asm("{\n\t"
 39 |         ".reg .b32 ign;\n\t"
 40 |         "prmt.b32 %0, %1, ign, 0x0123;\n\t"
 41 |         "}" : "=r"(ret) : "r"(x));
 42 |     return ret;
 43 | }
 44 | 
 45 | __device__ static __forceinline__
 46 | uint16_t HtoBE16(uint16_t x) {
 47 |     // TODO: simplify PTX using 16-bit instructions
 48 |     auto a = static_cast<uint32_t>(x);
 49 |     uint32_t d;
 50 |     asm volatile(
 51 |         "{\n\t"
 52 |         ".reg .b32 mask;\n\t"
 53 |         ".reg .b32 ign;\n\t"
 54 |         "mov.b32 mask, 0x4401;\n\t"
 55 |         "mov.b32 ign, 0x0;\n\t"
 56 |         "prmt.b32 %0, %1, ign, mask;\n\t"
 57 |         "}"
 58 |         : "=r"(d)
 59 |         : "r"(a));
 60 |     return static_cast<uint16_t>(d);
 61 | }
 62 | 
 63 | typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
 64 | 
 65 | typedef struct {
 66 |     uint32_t add_data;
 67 |     uint32_t field_boundary;
 68 |     uint64_t reserved;
 69 | } __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t;
 70 | 
 71 | __device__ static __forceinline__
 72 | nvshmemi_ibgda_device_state_t* ibgda_get_state() {
 73 |     return &nvshmemi_ibgda_device_state_d;
 74 | }
 75 | 
 76 | __device__ static __forceinline__
 77 | nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
 78 |     auto state = ibgda_get_state();
 79 |     const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
 80 |     return &state->globalmem.rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)];
 81 | }
 82 | 
 83 | __device__ static __forceinline__
 84 | void ibgda_lock_acquire(int *lock) {
 85 |     while (atomicCAS(lock, 0, 1) == 1);
 86 | 
 87 |     // Prevent reordering before the lock is acquired
 88 |     memory_fence_cta();
 89 | }
 90 | 
 91 | __device__ static __forceinline__
 92 | void ibgda_lock_release(int *lock) {
 93 |     memory_fence_cta();
 94 | 
 95 |     // Ensure prior memory operations are visible before the lock is released
 96 |     st_na_relaxed(lock, 0);
 97 | }
 98 | 
 99 | __device__ static __forceinline__
100 | void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) {
101 |     // `DBREC` contains the index of the next empty `WQEBB`
102 |     __be32 dbrec_val;
103 |     __be32 *dbrec_ptr = qp->tx_wq.dbrec;
104 | 
105 |     // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`
106 |     asm("{\n\t"
107 |         ".reg .b32 dbrec_head_16b;\n\t"
108 |         ".reg .b32 ign;\n\t"
109 |         "and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
110 |         "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
111 |         "}"
112 |         : "=r"(dbrec_val)
113 |         : "r"(dbrec_head));
114 |     st_na_release(dbrec_ptr, dbrec_val);
115 | }
116 | 
117 | __device__ static __forceinline__
118 | void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) {
119 |     auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);
120 |     ibgda_ctrl_seg_t ctrl_seg = {
121 |         .opmod_idx_opcode = HtoBE32(prod_idx << 8),
122 |         .qpn_ds = HtoBE32(qp->qpn << 8)
123 |     };
124 | 
125 |     EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), "");
126 |     st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));
127 | }
128 | 
129 | __device__ static __forceinline__
130 | void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) {
131 |     nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
132 |     uint64_t old_prod_idx;
133 | 
134 |     // Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence
135 |     ibgda_lock_acquire(&mvars->post_send_lock);
136 | 
137 |     old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
138 |     if (new_prod_idx > old_prod_idx) {
139 |         ibgda_update_dbr(qp, new_prod_idx);
140 |         ibgda_ring_db(qp, new_prod_idx);
141 |     }
142 |     ibgda_lock_release(&mvars->post_send_lock);
143 | }
144 | 
145 | template <bool kAlwaysDoPostSend>
146 | __device__ static __forceinline__
147 | void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
148 |                            uint32_t num_wqes, int message_idx = 0) {
149 |     auto state = ibgda_get_state();
150 |     nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
151 |     uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
152 | 
153 |     // WQE writes must be finished first
154 |     __threadfence();
155 | 
156 |     unsigned long long int *ready_idx =
157 |         (unsigned long long int *)(state->use_async_postsend ? qp->tx_wq.prod_idx
158 |                                                              : &mvars->tx_wq.ready_head);
159 | 
160 |     // Wait for prior WQE slots to be filled first
161 |     while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
162 | 
163 |     // Post the doorbell, either unconditionally or once every `kNumRequestInBatch` requests
164 |     if (!state->use_async_postsend) {
165 |         constexpr int kNumRequestInBatch = 4;
166 |         if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
167 |             ibgda_post_send(qp, new_wqe_idx);
168 |     }
169 | }
170 | 
171 | __device__ static __forceinline__ void
172 | ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
173 |                                __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) {
174 |     ibgda_ctrl_seg_t ctrl_seg;
175 |     struct mlx5_wqe_raddr_seg raddr_seg;
176 |     struct mlx5_wqe_inl_data_seg inl_seg;
177 | 
178 |     auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
179 |     auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
180 |     auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
181 |     auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));
182 | 
183 |     raddr_seg.raddr = HtoBE64(raddr);
184 |     raddr_seg.rkey = rkey;
185 |     raddr_seg.reserved = 0;
186 | 
187 |     inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);
188 | 
189 |     // `imm == std::numeric_limits<uint32_t>::max()` means no imm writes
190 |     ctrl_seg = {0};
191 |     ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
192 |     ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
193 |     ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));
194 |     if (imm != std::numeric_limits<uint32_t>::max())
195 |         ctrl_seg.imm = HtoBE32(imm);
196 | 
197 |     EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
198 |     EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
199 |     EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4");
200 |     st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
201 |     st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
202 |     st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));
203 |     st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));
204 | }
205 | 
206 | __device__ static __forceinline__
207 | uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
208 |                                  uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) {
209 |     auto state = ibgda_get_state();
210 |     auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
211 |     auto log2_cumem_granularity = state->log2_cumem_granularity;
212 | 
213 |     // Local key
214 |     uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx;
215 |     auto device_key = state->constmem.lkeys[idx];
216 |     auto lchunk_size = device_key.next_addr - laddr;
217 |     *lkey = device_key.key;
218 | 
219 |     // Remote key
220 |     uint64_t roffset = raddr - heap_start;
221 | 
222 |     idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized
223 |           + dst_pe * state->num_devices_initialized + dev_idx;
224 |     if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
225 |         device_key = state->constmem.rkeys[idx];
226 |     } else {
227 |         device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
228 |     }
229 |     *out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
230 |     *out_rkey = device_key.key;
231 | 
232 |     // Return the minimum of local and remote chunk sizes
233 |     auto rchunk_size = device_key.next_addr - roffset;
234 |     return min(lchunk_size, rchunk_size);
235 | }
236 | 
237 | __device__ static __forceinline__ void
238 | ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) {
239 |     auto state = ibgda_get_state();
240 |     auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
241 | 
242 |     uint64_t roffset = addr - heap_start;
243 |     uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized)
244 |                    + dst_pe * state->num_devices_initialized + dev_idx;
245 |     nvshmemi_ibgda_device_key_t device_key;
246 |     if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
247 |         device_key = state->constmem.rkeys[idx];
248 |     else
249 |         device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
250 |     *out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
251 |     *out_rkey = device_key.key;
252 | }
253 | 
254 | __device__ static __forceinline__ uint64_t
255 | ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
256 |     auto mvars = &qp->mvars;
257 |     return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
258 | }
259 | 
260 | __device__ static __forceinline__ void*
261 | ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
262 |     uint16_t cnt = qp->tx_wq.nwqes;
263 |     uint16_t idx = wqe_idx & (cnt - 1);
264 |     return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
265 | }
266 | 
267 | __device__ static __forceinline__ void
268 | nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
269 |     // Get rkey
270 |     // NOTES: the `p` operation will not cross multiple remote chunks
271 |     __be32 rkey;
272 |     uint64_t raddr;
273 |     auto qp = ibgda_get_rc(dst_pe, qp_id);
274 |     ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey, qp->dev_idx);
275 |     
276 |     // Write WQEs
277 |     uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
278 |     void *wqe_ptrs;
279 |     wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
280 |     ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);
281 | 
282 |     // Submit requests
283 |     ibgda_submit_requests<true>(qp, base_wqe_idx, 1);
284 | }
285 | 
286 | __device__ static __forceinline__ void
287 | ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
288 |                            uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
289 |                            void** out_wqes) {
290 |     ibgda_ctrl_seg_t ctrl_seg;
291 |     struct mlx5_wqe_raddr_seg raddr_seg;
292 |     struct mlx5_wqe_data_seg data_seg;
293 | 
294 |     auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
295 |     void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
296 |     struct mlx5_wqe_raddr_seg *raddr_seg_ptr;
297 |     struct mlx5_wqe_data_seg *data_seg_ptr;
298 | 
299 |     raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
300 |     data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
301 | 
302 |     raddr_seg.raddr = HtoBE64(raddr);
303 |     raddr_seg.rkey = rkey;
304 |     raddr_seg.reserved = 0;
305 | 
306 |     data_seg.byte_count = HtoBE32(bytes);
307 |     data_seg.lkey = lkey;
308 |     data_seg.addr = HtoBE64(laddr);
309 | 
310 |     ctrl_seg = {0};
311 |     ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
312 |     ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
313 |     ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
314 | 
315 |     EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
316 |     EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
317 |     EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
318 |     st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
319 |     st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
320 |     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
321 | }
322 | 
323 | __device__ static __forceinline__ void
324 | ibgda_write_empty_recv_wqe(void *out_wqe) {
325 |     auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
326 |     struct mlx5_wqe_data_seg data_seg;
327 | 
328 |     // Make the first segment in the WQE invalid, then the entire list will be invalid
329 |     data_seg.byte_count = 0;
330 |     data_seg.lkey = HtoBE32(MLX5_INVALID_LKEY);
331 |     data_seg.addr = 0;
332 | 
333 |     EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
334 |     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
335 | }
336 | 
337 | template <bool kAlwaysDoPostSend = false>
338 | __device__ static __forceinline__ void
339 | nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
340 |     // Get lkey and rkey, store them into lanes
341 |     uint32_t num_wqes = 0;
342 |     __be32 my_lkey = 0;
343 |     uint64_t my_laddr = 0;
344 |     __be32 my_rkey = 0;
345 |     uint64_t my_raddr = 0;
346 |     uint64_t my_chunk_size = 0;
347 | 
348 |     auto qp = ibgda_get_rc(dst_pe, qp_id);
349 | 
350 |     // Decide how many messages are needed (theoretically 3 at maximum)
351 |     auto remaining_bytes = bytes;
352 |     while (remaining_bytes > 0) {
353 |         if (lane_id == num_wqes) {
354 |             my_chunk_size = min(remaining_bytes, 
355 |                                 ibgda_get_lkey_and_rkey(my_laddr = req_lptr, 
356 |                                                         &my_lkey, 
357 |                                                         req_rptr, 
358 |                                                         dst_pe, 
359 |                                                         &my_raddr, 
360 |                                                         &my_rkey, 
361 |                                                         qp->dev_idx));
362 |         }
363 | 
364 |         // Move one more message
365 |         auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
366 |         remaining_bytes -= chunk_size;
367 |         req_lptr += chunk_size;
368 |         req_rptr += chunk_size;
369 |         ++ num_wqes;
370 |     }
371 |     EP_DEVICE_ASSERT(num_wqes <= 32);
372 | 
373 |     // Process WQE
374 |     uint64_t base_wqe_idx = 0;
375 |     if (lane_id == 0)
376 |         base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
377 |     base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
378 |     if (lane_id < num_wqes) {
379 |         auto wqe_idx = base_wqe_idx + lane_id;
380 |         auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
381 |         ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size,
382 |                                    wqe_idx, &wqe_ptr);
383 |     }
384 |     __syncwarp();
385 | 
386 |     // Submit
387 |     if (lane_id == 0)
388 |         ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
389 |     __syncwarp();
390 | }
391 | 
392 | __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
393 |         nvshmemi_ibgda_device_qp_t *qp, const int &value,
394 |         uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
395 |         uint16_t wqe_idx, void** out_wqes) {
396 |     ibgda_ctrl_seg_t ctrl_seg = {0};
397 |     struct mlx5_wqe_raddr_seg raddr_seg;
398 |     struct mlx5_wqe_atomic_seg atomic_seg_1;
399 |     struct mlx5_wqe_data_seg data_seg;
400 | 
401 |     auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
402 |     auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
403 |     auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
404 |     auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));
405 | 
406 |     raddr_seg.raddr = HtoBE64(raddr);
407 |     raddr_seg.rkey = rkey;
408 |     raddr_seg.reserved = 0;
409 | 
410 |     // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
411 |     ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
412 |     auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);
413 |     atomic_32_masked_fa_seg->add_data = HtoBE32(value);
414 |     atomic_32_masked_fa_seg->field_boundary = 0;
415 | 
416 |     ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
417 |     ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
418 | 
419 |     data_seg.byte_count = HtoBE32(sizeof(int));
420 |     data_seg.lkey = lkey;
421 |     data_seg.addr = HtoBE64(laddr);
422 | 
423 |     EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
424 |     EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
425 |     EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
426 |     EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
427 |     st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<int4*>(&ctrl_seg));
428 |     st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<int4*>(&raddr_seg));
429 |     st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<int4*>(&atomic_seg_1));
430 |     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
431 | }
432 | 
433 | __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
434 |     if (is_local_copy) {
435 |         atomicAdd(static_cast<unsigned long long*>(rptr), value);
436 |     } else {
437 |         nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
438 | 
439 |         __be32 rkey;
440 |         uint64_t raddr;
441 |         ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx);
442 | 
443 |         uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
444 |         void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
445 | 
446 |         ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
447 |                                 qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
448 | 
449 |         ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
450 |     }
451 | }
452 | 
453 | __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {
454 |     // Local rank, no need for mapping
455 |     if (rank == dst_rank)
456 |         return ptr;
457 |     auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);
458 | 
459 |     // RDMA connected
460 |     if (peer_base == 0)
461 |         return 0;
462 | 
463 |     // NVLink P2P is enabled
464 |     return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
465 | }
466 | 
467 | // This is a simplified version of NVSHMEM's `ibgda_poll_cq`. 
468 | // Note that this implementation does not guarantee thread safety,
469 | // so we must ensure that no other threads are concurrently using the same QP.
470 | __device__ static __forceinline__ void
471 | ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
472 |     const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe);
473 |     const uint32_t ncqes = cq->ncqes;
474 |     memory_fence_cta();
475 | 
476 |     // NOTES: the following explains the condition of the do-while loop below.
477 |     // `wqe_counter` is the HW consumer index, while we always maintain `index + 1`,
478 |     // so to compare against our index we need `wqe_counter + 1`.
479 |     // Because `wqe_counter` is a `uint16_t`, it may overflow. Still, we know for
480 |     // sure that if `idx - wqe_counter - 1 < ncqes`, then `wqe_counter + 1` is less than
481 |     // `idx`, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1`;
482 |     // that's why we use `- 2` here, to make that case overflow.
483 |     uint16_t wqe_counter;
484 |     do {
485 |         wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter));
486 |     } while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes));
487 |     *cq->cons_idx = idx;
488 | 
489 |     // Prevent reordering of this function and later instructions
490 |     memory_fence_cta();
491 | }
492 | 
493 | // Wait until wqe `idx - 1` is completed.
494 | __device__ static __forceinline__ void
495 | nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
496 |     auto qp = ibgda_get_rc(dst_pe, qp_id);
497 |     auto state = ibgda_get_state();
498 |     uint64_t prod_idx = state->use_async_postsend ? ld_na_relaxed(qp->tx_wq.prod_idx) : ld_na_relaxed(&qp->mvars.tx_wq.ready_head);
499 |     ibgda_poll_cq(qp->tx_wq.cq, prod_idx);
500 | }
501 | 
502 | } // namespace deep_ep
503 | 
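Editor's note (not part of `ibgda_device.cuh`): a hedged sketch of how a warp typically drives the primitives above: the whole warp cooperatively posts an RDMA write with `nvshmemi_ibgda_put_nbi_warp`, then a single lane waits for completion with `nvshmemi_ibgda_quiet`. Pointer and size arguments are illustrative.

    // Assumes `ibgda_device.cuh` is included.
    __device__ void put_then_wait_example(uint64_t remote_ptr, uint64_t local_ptr,
                                          size_t bytes, int dst_pe, int qp_id) {
        const int lane_id = static_cast<int>(threadIdx.x % 32);

        // Warp-collective: lanes split the transfer into chunks and fill the WQEs together.
        deep_ep::nvshmemi_ibgda_put_nbi_warp<true>(remote_ptr, local_ptr, bytes, dst_pe,
                                                   qp_id, lane_id, /*message_idx=*/ 0);

        // One lane polls the CQ until the posted WQEs have completed.
        if (lane_id == 0)
            deep_ep::nvshmemi_ibgda_quiet(dst_pe, qp_id);
        __syncwarp();
    }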


--------------------------------------------------------------------------------
/csrc/kernels/launch.cuh:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | #include "configs.cuh"
 4 | #include "exception.cuh"
 5 | 
 6 | #ifndef SETUP_LAUNCH_CONFIG
 7 | #ifndef DISABLE_SM90_FEATURES
 8 | #define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
 9 |     cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
10 |     cudaLaunchAttribute attr[2]; \
11 |     attr[0].id = cudaLaunchAttributeCooperative; \
12 |     attr[0].val.cooperative = 1; \
13 |     attr[1].id = cudaLaunchAttributeClusterDimension; \
14 |     attr[1].val.clusterDim.x = (num_sms % 2 == 0 ? 2 : 1); \
15 |     attr[1].val.clusterDim.y = 1; \
16 |     attr[1].val.clusterDim.z = 1; \
17 |     cfg.attrs = attr; \
18 |     cfg.numAttrs = 2
19 | #else
20 | #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
21 |     int __num_sms = (sms); \
22 |     int __num_threads = (threads); \
23 |     auto __stream = (stream)
24 | #endif
25 | #endif
26 | 
27 | #ifndef LAUNCH_KERNEL
28 | #ifndef DISABLE_SM90_FEATURES
29 | #define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
30 | #else
31 | #define LAUNCH_KERNEL(config, kernel, ...) \
32 | do { \
33 |     kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__); \
34 |     cudaError_t e = cudaGetLastError(); \
35 |     if (e != cudaSuccess) { \
36 |         EPException cuda_exception("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
37 |         fprintf(stderr, "%s\n", cuda_exception.what()); \
38 |         throw cuda_exception; \
39 |     } \
40 | } while (0)
41 | #endif
42 | #endif
43 | 
44 | #ifndef SET_SHARED_MEMORY_FOR_TMA
45 | #ifndef DISABLE_SM90_FEATURES
46 | #define SET_SHARED_MEMORY_FOR_TMA(kernel) \
47 | EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
48 | cfg.dynamicSmemBytes = smem_size;
49 | #else
50 | #define SET_SHARED_MEMORY_FOR_TMA(kernel) void()
51 | #endif
52 | #endif
53 | 
54 | #define SWITCH_RANKS(case_macro) \
55 |     switch (num_ranks) { \
56 |         case 2: case_macro(2); \
57 |         case 4: case_macro(4); \
58 |         case 8: case_macro(8); \
59 |         default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
60 |     } while (false)
61 | 
62 | #define SWITCH_RDMA_RANKS(case_macro) \
63 |     switch (num_ranks / NUM_MAX_NVL_PEERS) { \
64 |         case 2: case_macro(2); \
65 |         case 4: case_macro(4); \
66 |         case 8: case_macro(8); \
67 |         case 16: case_macro(16); \
68 |         default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
69 |     } while (false)
70 | 
71 | #define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
72 |     switch (num_ranks) { \
73 |         case 2: case_macro(dtype, 2); \
74 |         case 4: case_macro(dtype, 4); \
75 |         case 8: case_macro(dtype, 8); \
76 |         default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
77 |     } while (false)
78 | 
79 | #define SWITCH_TYPES(case_macro) \
80 |     switch (type) { \
81 |         case CUDA_R_16BF: case_macro(nv_bfloat16); \
82 |         default: EP_HOST_ASSERT(false and "Unsupported type"); \
83 |     } while (false)
84 | 
85 | #define SWITCH_HIDDEN(case_macro) \
86 |     switch (hidden) { \
87 |         case 2048: case_macro(2048); \
88 |         case 2560: case_macro(2560); \
89 |         case 4096: case_macro(4096); \
90 |         case 5120: case_macro(5120); \
91 |         case 7168: case_macro(7168); \
92 |         case 8192: case_macro(8192); \
93 |         default: EP_HOST_ASSERT(false and "Unsupported hidden"); \
94 |     } while (false)
95 | 
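Editor's note (not part of `launch.cuh`): the SWITCH_* helpers expect a `case_macro` that launches the kernel and ends with `break`, as `BARRIER_LAUNCH_CASE` in `runtime.cu` below illustrates. A hypothetical sketch of the same pattern with `SWITCH_TYPES`; the kernel name is illustrative only.

    // Assumes `launch.cuh` is included.
    template <typename dtype_t>
    __global__ void my_kernel(dtype_t* data) { /* ... */ }

    void launch_my_kernel(cudaDataType_t type, void* data, cudaStream_t stream) {
    #define MY_LAUNCH_CASE(dtype) \
        LAUNCH_KERNEL(&cfg, my_kernel<dtype>, reinterpret_cast<dtype*>(data)); \
        break

        SETUP_LAUNCH_CONFIG(/*num_sms=*/ 1, /*num_threads=*/ 128, stream);
        SWITCH_TYPES(MY_LAUNCH_CASE);
    #undef MY_LAUNCH_CASE
    }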


--------------------------------------------------------------------------------
/csrc/kernels/layout.cu:
--------------------------------------------------------------------------------
  1 | #include "configs.cuh"
  2 | #include "exception.cuh"
  3 | #include "launch.cuh"
  4 | 
  5 | namespace deep_ep {
  6 | 
  7 | namespace layout {
  8 | 
  9 | template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
 10 | __global__ void get_dispatch_layout(const int64_t* topk_idx,
 11 |                                     int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
 12 |                                     int* num_tokens_per_expert, bool* is_token_in_rank,
 13 |                                     int num_tokens, int num_topk, int num_ranks, int num_experts) {
 14 |     auto sm_id = static_cast<int>(blockIdx.x);
 15 |     auto thread_id = static_cast<int>(threadIdx.x);
 16 | 
 17 |     // Count expert statistics
 18 |     __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
 19 |     int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
 20 |     if (expert_begin_idx < expert_end_idx) {
 21 |         // Per-thread count
 22 |         #pragma unroll
 23 |         for (int i = 0; i < kNumExpertsPerSM; ++ i)
 24 |             num_tokens_per_expert_per_thread[thread_id][i] = 0;
 25 |         #pragma unroll
 26 |         for (int i = thread_id; i < num_tokens; i += kNumThreads) {
 27 |             auto shifted_topk_idx = topk_idx + i * num_topk;
 28 |             #pragma unroll
 29 |             for (int j = 0, expert_idx; j < num_topk; ++ j) {
 30 |                 expert_idx = static_cast<int>(shifted_topk_idx[j]);
 31 |                 if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
 32 |                     ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
 33 |             }
 34 |         }
 35 |         __syncthreads();
 36 | 
 37 |         // Sum up
 38 |         EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
 39 |         if (expert_begin_idx + thread_id < expert_end_idx) {
 40 |             int sum = 0;
 41 |             #pragma unroll
 42 |             for (int i = 0; i < kNumThreads; ++ i)
 43 |                 sum += num_tokens_per_expert_per_thread[i][thread_id];
 44 |             num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
 45 |         }
 46 |         return;
 47 |     }
 48 | 
 49 |     if (num_tokens_per_rdma_rank != nullptr)
 50 |         EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
 51 | 
 52 |     // Count rank statistics
 53 |     constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
 54 |     __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
 55 |     __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
 56 |     auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
 57 |     int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
 58 |     int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
 59 |     if (rank_begin_idx < rank_end_idx) {
 60 |         const auto num_expert_per_rank = num_experts / num_ranks;
 61 |         auto expert_begin = rank_begin_idx * num_expert_per_rank;
 62 |         auto expert_end = rank_end_idx * num_expert_per_rank;
 63 | 
 64 |         // Per-thread count
 65 |         #pragma unroll
 66 |         for (int i = 0; i < kNumRanksPerSM; ++ i)
 67 |             num_tokens_per_rank_per_thread[thread_id][i] = 0;
 68 |         #pragma unroll
 69 |         for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
 70 |             num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
 71 |         #pragma unroll
 72 |         for (int i = thread_id; i < num_tokens; i += kNumThreads) {
 73 |             auto shifted_topk_idx = topk_idx + i * num_topk;
 74 |             int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
 75 |             #pragma unroll
 76 |             for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
 77 |                 expert_idx = static_cast<int>(shifted_topk_idx[j]);
 78 |                 if (expert_begin <= expert_idx and expert_idx < expert_end) {
 79 |                     // Count single rank
 80 |                     rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
 81 |                     is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
 82 |                 }
 83 |             }
 84 | 
 85 |             auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
 86 |             #pragma unroll
 87 |             for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
 88 |                 shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
 89 |                 num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
 90 |             }
 91 | 
 92 |             #pragma unroll
 93 |             for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
 94 |                 num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
 95 |         }
 96 |         __syncthreads();
 97 | 
 98 |         // Sum up
 99 |         EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
100 |         if (rank_begin_idx + thread_id < rank_end_idx) {
101 |             int sum = 0;
102 |             #pragma unroll
103 |             for (int i = 0; i < kNumThreads; ++ i)
104 |                 sum += num_tokens_per_rank_per_thread[i][thread_id];
105 |             num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
106 |         }
107 | 
108 |         if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
109 |             int sum = 0;
110 |             #pragma unroll
111 |             for (int i = 0; i < kNumThreads; ++ i)
112 |                 sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
113 |             num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
114 |         }
115 |     }
116 | }
117 | 
118 | void get_dispatch_layout(const int64_t* topk_idx,
119 |                          int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
120 |                          int* num_tokens_per_expert, bool* is_token_in_rank,
121 |                          int num_tokens, int num_topk, int num_ranks, int num_experts,
122 |                          cudaStream_t stream) {
123 |     constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8;
124 |     int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
125 |     EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");
126 | 
127 |     SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
128 |     LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
129 |                   topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
130 |                   num_tokens, num_topk, num_ranks, num_experts);
131 | }
132 | 
133 | } // namespace layout
134 | 
135 | } // namespace deep_ep
136 | 
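Editor's note (not part of `layout.cu`): a hedged host-side sketch of calling the launcher above. Allocation sizes follow the kernel's expectations, and `num_tokens_per_rdma_rank` may be passed as `nullptr` in the single-node case (see the device assertion above). Freeing and error handling are omitted.

    // Assumes the DeepEP kernel headers (`api.cuh`, `exception.cuh`, `configs.cuh`) are included.
    void example_layout(const int64_t* d_topk_idx, int num_tokens, int num_topk,
                        int num_ranks, int num_experts, cudaStream_t stream) {
        int *d_per_rank, *d_per_rdma_rank, *d_per_expert;
        bool* d_in_rank;
        CUDA_CHECK(cudaMalloc(&d_per_rank, num_ranks * sizeof(int)));
        CUDA_CHECK(cudaMalloc(&d_per_rdma_rank, (num_ranks / NUM_MAX_NVL_PEERS) * sizeof(int)));
        CUDA_CHECK(cudaMalloc(&d_per_expert, num_experts * sizeof(int)));
        CUDA_CHECK(cudaMalloc(&d_in_rank, sizeof(bool) * num_tokens * num_ranks));

        deep_ep::layout::get_dispatch_layout(d_topk_idx, d_per_rank, d_per_rdma_rank,
                                             d_per_expert, d_in_rank,
                                             num_tokens, num_topk, num_ranks, num_experts,
                                             stream);
    }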


--------------------------------------------------------------------------------
/csrc/kernels/runtime.cu:
--------------------------------------------------------------------------------
 1 | #include <vector>
 2 | #include <cstring>
 3 | 
 4 | #include "configs.cuh"
 5 | #include "exception.cuh"
 6 | #include "launch.cuh"
 7 | #include "utils.cuh"
 8 | 
 9 | #ifndef DISABLE_NVSHMEM
10 | #include "nvshmem.h"
11 | #include "ibgda_device.cuh"
12 | #endif
13 | 
14 | namespace deep_ep {
15 | 
16 | namespace intranode {
17 | 
18 | template<int kNumRanks>
19 | __global__ void barrier(int** barrier_signal_ptrs, int rank) {
20 |     barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
21 | }
22 | 
23 | void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) {
24 | #define BARRIER_LAUNCH_CASE(ranks) \
25 |     LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \
26 |     break
27 | 
28 |     SETUP_LAUNCH_CONFIG(1, 32, stream);
29 |     SWITCH_RANKS(BARRIER_LAUNCH_CASE);
30 | #undef BARRIER_LAUNCH_CASE
31 | }
32 | 
33 | } // namespace intranode
34 | 
35 | namespace internode {
36 | 
37 | #ifndef DISABLE_NVSHMEM
38 | nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
39 | nvshmem_team_config_t cpu_rdma_team_config;
40 | 
41 | std::vector<uint8_t> get_unique_id() {
42 |     nvshmemx_uniqueid_t unique_id;
43 |     nvshmemx_get_uniqueid(&unique_id);
44 |     std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
45 |     std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
46 |     return result;
47 | }
48 | 
49 | int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
50 |     nvshmemx_uniqueid_t root_unique_id;
51 |     nvshmemx_init_attr_t attr;
52 |     std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
53 |     nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
54 |     nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
55 | 
56 |     // Create sub-RDMA teams
57 |     // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
58 |     if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
59 |         EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
60 |         EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
61 |         EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
62 |                                                   num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
63 |         EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
64 |     }
65 | 
66 |     nvshmem_barrier_all();
67 |     return nvshmem_my_pe();
68 | }
69 | 
70 | void* alloc(size_t size, size_t alignment) {
71 |     return nvshmem_align(alignment, size);
72 | }
73 | 
74 | void free(void* ptr) {
75 |     nvshmem_free(ptr);
76 | }
77 | 
78 | void barrier() {
79 |     nvshmem_barrier_all();
80 | }
81 | 
82 | void finalize() {
83 |     if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
84 |         nvshmem_team_destroy(cpu_rdma_team);
85 |         cpu_rdma_team = NVSHMEM_TEAM_INVALID;
86 |     }
87 |     nvshmem_finalize();
88 | }
89 | #endif
90 | 
91 | } // namespace internode
92 | 
93 | } // namespace deep_ep
94 | 
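Editor's note (not part of `runtime.cu`): a hedged sketch of the bootstrap flow implied by the functions above: rank 0 obtains an NVSHMEM unique ID, shares it with the other ranks out of band, and every rank then calls `init`. `broadcast_bytes` is a placeholder for whatever the hosting framework provides (e.g. a distributed key-value store).

    // Assumes the `internode` declarations from `api.cuh` are included.
    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> broadcast_bytes(const std::vector<uint8_t>& bytes, int root);  // hypothetical

    int bootstrap_nvshmem(int rank, int num_ranks, bool low_latency_mode) {
        std::vector<uint8_t> unique_id;
        if (rank == 0)
            unique_id = deep_ep::internode::get_unique_id();
        unique_id = broadcast_bytes(unique_id, /*root=*/ 0);

        // Initializes NVSHMEM (and the sub-RDMA team in low-latency mode),
        // then returns this rank's PE id after a global barrier.
        return deep_ep::internode::init(unique_id, rank, num_ranks, low_latency_mode);
    }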


--------------------------------------------------------------------------------
/csrc/kernels/utils.cuh:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | 
  3 | #include "exception.cuh"
  4 | 
  5 | #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
  6 | { \
  7 |     constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
  8 |     typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
  9 |     auto __src = (SRC); \
 10 |     auto __dst = (DST); \
 11 |     for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
 12 |         _Pragma("unroll") \
 13 |         for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
 14 |             unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \
 15 |         _Pragma("unroll") \
 16 |         for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
 17 |             ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \
 18 |     } \
 19 |     for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \
 20 |         ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
 21 | }
 22 | 
 23 | namespace deep_ep {
 24 | 
 25 | template <int kBytes>
 26 | struct VecInt {};
 27 | template<> struct VecInt<1> { using vec_t = int8_t; };
 28 | template<> struct VecInt<2> { using vec_t = int16_t; };
 29 | template<> struct VecInt<4> { using vec_t = int; };
 30 | template<> struct VecInt<8> { using vec_t = int64_t; };
 31 | template<> struct VecInt<16> { using vec_t = int4; };
 32 | 
 33 | template <typename FuncT>
 34 | struct PatternVisitor {
 35 |     FuncT func;
 36 | 
 37 |     __device__ __host__
 38 |     explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
 39 | 
 40 |     __device__ __host__
 41 |     auto operator [](const uint32_t& i) {
 42 |         return func(i);
 43 |     }
 44 | };
 45 | 
 46 | __device__ __forceinline__ void trap() {
 47 |     asm("trap;");
 48 | }
 49 | 
 50 | __device__ __forceinline__ void memory_fence() {
 51 |     asm volatile("fence.acq_rel.sys;":: : "memory");
 52 | }
 53 | 
 54 | __device__ __forceinline__ void memory_fence_gpu() {
 55 |     asm volatile("fence.acq_rel.gpu;":: : "memory");
 56 | }
 57 | 
 58 | __device__ __forceinline__ void memory_fence_cta() {
 59 |     asm volatile("fence.acq_rel.cta;":: : "memory");
 60 | }
 61 | 
 62 | __device__  __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
 63 |     asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
 64 | }
 65 | 
 66 | __device__  __forceinline__ void st_release_sys_global(const int *ptr, int val) {
 67 |     asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
 68 | }
 69 | 
 70 | __device__  __forceinline__ void st_release_cta(const int *ptr, int val) {
 71 |     asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
 72 | }
 73 | 
 74 | __device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
 75 |     int ret;
 76 |     asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
 77 |     return ret;
 78 | }
 79 | 
 80 | __device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
 81 |     uint64_t ret;
 82 |     asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
 83 |     return ret;
 84 | }
 85 | 
 86 | __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
 87 |     int ret;
 88 |     asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
 89 |     return ret;
 90 | }
 91 | 
 92 | __device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
 93 |     int ret;
 94 |     asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
 95 |     return ret;
 96 | }
 97 | 
 98 | __device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
 99 |     int ret;
100 |     asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
101 |     return ret;
102 | }
103 | 
104 | __device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
105 |     int ret;
106 |     asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
107 |     return ret;
108 | }
109 | 
110 | __device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
111 |     uint16_t ret;
112 |     asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
113 |     return static_cast<uint8_t>(ret);
114 | }
115 | 
116 | __device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
117 |     uint16_t ret;
118 |     asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
119 |     return ret;
120 | }
121 | 
122 | __device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
123 |     uint32_t ret;
124 |     asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
125 |     return ret;
126 | }
127 | 
128 | __device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
129 |     uint64_t ret;
130 |     asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
131 |     return ret;
132 | }
133 | 
134 | __device__  __forceinline__ int ld_volatile_global(const int *ptr) {
135 |     int ret;
136 |     asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
137 |     return ret;
138 | }
139 | 
140 | __device__  __forceinline__ float ld_volatile_global(const float *ptr) {
141 |     float ret;
142 |     asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
143 |     return ret;
144 | }
145 | 
146 | __device__  __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
147 |     int64_t ret;
148 |     asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
149 |     return ret;
150 | }
151 | 
152 | __device__  __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
153 |     int64_t ret;
154 |     asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
155 |     return ret;
156 | }
157 | 
158 | #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
159 | #define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
160 | #else
161 | #define LD_NC_FUNC "ld.volatile.global.L2::256B"
162 | #endif
163 | 
164 | // `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS
165 | template <typename dtype_t>
166 | __device__  __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
167 |     auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
168 |     return *reinterpret_cast<dtype_t*>(&ret);
169 | }
170 | 
171 | template <>
172 | __device__  __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
173 |     uint16_t ret;
174 |     // NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
175 |     asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
176 |     return static_cast<uint8_t>(ret);
177 | }
178 | 
179 | template <>
180 | __device__  __forceinline__ int ld_nc_global(const int *ptr) {
181 |     int ret;
182 |     asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
183 |     return ret;
184 | }
185 | 
186 | template <>
187 | __device__  __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
188 |     int64_t ret;
189 |     asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
190 |     return ret;
191 | }
192 | 
193 | template <>
194 | __device__  __forceinline__ float ld_nc_global(const float *ptr) {
195 |     float ret;
196 |     asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
197 |     return ret;
198 | }
199 | 
200 | template <>
201 | __device__  __forceinline__ int2 ld_nc_global(const int2 *ptr) {
202 |     int2 ret;
203 |     asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
204 |     return ret;
205 | }
206 | 
207 | template <>
208 | __device__  __forceinline__ int4 ld_nc_global(const int4 *ptr) {
209 |     int4 ret;
210 |     asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
211 |             : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
212 |     return ret;
213 | }
214 | 
215 | __device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
216 |     asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
217 | }
218 | 
219 | __device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
220 |     asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
221 | }
222 | 
223 | __device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
224 |     asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
225 | }
226 | 
227 | __device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
228 |     asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
229 | }
230 | 
231 | __device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
232 |     asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
233 |             : : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
234 | }
235 | 
236 | __device__ __forceinline__ void st_na_release(const int *ptr, int val) {
237 |     asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
238 | }
239 | 
240 | __device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
241 |     asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
242 | }
243 | 
244 | __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
245 |     asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
246 | }
247 | 
248 | // `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS
249 | #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
250 | #define ST_NA_FUNC "st.global.L1::no_allocate"
251 | #else
252 | #define ST_NA_FUNC "st.global"
253 | #endif
254 | 
255 | template <typename dtype_t>
256 | __device__  __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
257 |     st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
258 |                  *reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
259 | }
260 | 
261 | template <>
262 | __device__  __forceinline__ void st_na_global(const int *ptr, const int& value) {
263 |     asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
264 | }
265 | 
266 | template <>
267 | __device__  __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
268 |     asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
269 | }
270 | 
271 | template <>
272 | __device__  __forceinline__ void st_na_global(const float *ptr, const float& value) {
273 |     asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
274 | }
275 | 
276 | template <>
277 | __device__  __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
278 |     asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
279 |             ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
280 | }
281 | 
282 | __device__  __forceinline__ float log2f_approx(const float &x) {
283 |     float ret;
284 |     asm volatile("lg2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
285 |     return ret;
286 | }
287 | 
288 | __device__  __forceinline__ float exp2f_approx(const float &x) {
289 |     float ret;
290 |     asm volatile("ex2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
291 |     return ret;
292 | }
293 | 
294 | // TMA PTX instructions
295 | #ifndef DISABLE_SM90_FEATURES
296 | 
297 | __device__ __forceinline__ uint32_t elect_one_sync(int lane_id) {
298 |     uint32_t pred = 0;
299 |     asm volatile(
300 |       "{\n"
301 |       ".reg .b32 %%rx;\n"
302 |       ".reg .pred %%px;\n"
303 |       "      elect.sync %%rx|%%px, %2;\n"
304 |       "@%%px mov.s32 %1, 1;\n"
305 |       "      mov.s32 %0, %%rx;\n"
306 |       "}\n"
307 |       : "+r"(lane_id), "+r"(pred)
308 |       : "r"(0xffffffff));
309 |     return pred;
310 | }
311 | 
312 | __device__ __forceinline__ void fence_view_async_shared() {
313 |     asm volatile("fence.proxy.async.shared::cta; \n" :: );
314 | }
315 | 
316 | __device__ __forceinline__ void fence_barrier_init() {
317 |     asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
318 | }
319 | 
320 | __device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
321 |     auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
322 |     asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
323 | }
324 | 
325 | __device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
326 |     auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
327 |     asm volatile("{\n\t"
328 |                  ".reg .pred       P1; \n\t"
329 |                  "LAB_WAIT: \n\t"
330 |                  "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
331 |                  "@P1 bra DONE; \n\t"
332 |                  "bra     LAB_WAIT; \n\t"
333 |                  "DONE: \n\t"
334 |                  "}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
335 |     phase ^= 1;
336 | }
337 | 
338 | __device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
339 |     auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
340 |     asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
341 | }
342 | 
343 | __device__ __forceinline__ void tma_store_fence() {
344 |     asm volatile ("fence.proxy.async.shared::cta;");
345 | }
346 | 
347 | constexpr uint64_t kEvictFirst = 0x12f0000000000000;
348 | constexpr uint64_t kEvictNormal = 0x1000000000000000;
349 | 
350 | __device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
351 |                                             bool evict_first = true) {
352 |     auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
353 |     auto smem_int_ptr  = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
354 |     const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
355 |     asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
356 |                  :: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory");
357 | }
358 | 
359 | __device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
360 |                                              bool evict_first = true) {
361 |     auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
362 |     const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
363 |     asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
364 |                  :: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory");
365 |     asm volatile("cp.async.bulk.commit_group;");
366 | }
367 | 
368 | template <int N = 0>
369 | __device__ __forceinline__ void tma_store_wait() {
370 |     asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
371 | }
372 | 
373 | #endif
374 | 
375 | template <typename dtype_t>
376 | __host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
377 |     return (a + b - 1) / b;
378 | }
379 | 
380 | template <typename dtype_t>
381 | __host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) {
382 |     return ceil_div<dtype_t>(a, b) * b;
383 | }
384 | 
385 | __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
386 |                                                        int& token_start_idx, int& token_end_idx) {
387 |     int num_tokens_per_sm = ceil_div(num_tokens, num_sms);
388 |     token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
389 |     token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
390 | }
391 | 
392 | template <typename dtype_a_t, typename dtype_b_t>
393 | __device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
394 |     EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
395 |     dtype_b_t packed;
396 |     auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
397 |     unpacked_ptr[0] = x, unpacked_ptr[1] = y;
398 |     return packed;
399 | }
400 | 
401 | template <typename dtype_a_t, typename dtype_b_t>
402 | __device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
403 |     EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
404 |     auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
405 |     x = unpacked_ptr[0], y = unpacked_ptr[1];
406 | }
407 | 
408 | template <typename dtype_t>
409 | __device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
410 |     EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
411 |     auto send_int_values = reinterpret_cast<int*>(&ptr);
412 |     int recv_int_values[sizeof(dtype_t) / sizeof(int)];
413 |     #pragma unroll
414 |     for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
415 |         recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
416 |     return *reinterpret_cast<dtype_t*>(recv_int_values);
417 | }
418 | 
419 | __forceinline__ __device__ int get_lane_id() {
420 |     int lane_id;
421 |     asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
422 |     return lane_id;
423 | }
424 | 
425 | constexpr float kFP8Margin = 1e-4;
426 | constexpr float kFinfoAmaxE4M3 = 448.0f;
427 | constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
428 | 
429 | __forceinline__ __device__ float fast_pow2(int x) {
430 |     // The caller ensures `-126 <= x <= 127`
431 |     uint32_t bits_x = (x + 127) << 23;
432 |     return *reinterpret_cast<float*>(&bits_x);
433 | }
434 | 
435 | __forceinline__ __device__ int fast_log2_ceil(float x) {
436 |     auto bits_x = *reinterpret_cast<uint32_t*>(&x);
437 |     auto exp_x = (bits_x >> 23) & 0xff;
438 |     auto man_bits = bits_x & ((1 << 23) - 1);
439 |     return exp_x - 127 + (man_bits != 0);
440 | }
441 | 
442 | __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
443 |     if (round_scale) {
444 |         auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
445 |         scale = fast_pow2(-exp_scale_inv);
446 |         scale_inv = fast_pow2(exp_scale_inv);
447 |     } else {
448 |         scale_inv = amax * kFinfoAmaxInvE4M3;
449 |         scale = kFinfoAmaxE4M3 / amax;
450 |     }
451 | }
452 | 
453 | template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
454 | __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
455 |     if constexpr (kIsUE8M0) {
456 |         return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
457 |     } else {
458 |         return value;
459 |     }
460 | }
461 | 
462 | template <int kNumRanks, bool kSyncOnly = false>
463 | __forceinline__ __device__ void
464 | barrier_block(int** barrier_signal_ptrs, int rank) {
465 |     auto thread_id = static_cast<int>(threadIdx.x);
466 | 
467 |     // For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope
468 |     if constexpr (not kSyncOnly) {
469 |         memory_fence();
470 |         __syncthreads();
471 |     }
472 | 
473 |     // Add to the self rank's signals, subtract from the other ranks' signals
474 |     if (thread_id < kNumRanks) {
475 |         atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
476 |         atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
477 |     }
478 |     EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
479 | 
480 |     // Check timeout
481 |     auto start_time = clock64();
482 |     while (true) {
483 |         auto value = thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;
484 |         if (__all_sync(0xffffffff, value <= 0))
485 |             break;
486 | 
487 |         if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) {
488 |             printf("DeepEP timeout check failed: rank = %d, thread = %d, value = %d\n", rank, thread_id, value);
489 |             trap();
490 |         }
491 |     }
492 |     __syncthreads();
493 | }
494 | 
495 | __forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) {
496 |     int ret;
497 |     asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(addr), "r"(x), "r"(y) : "memory");
498 |     return ret;
499 | }
500 | 
501 | __forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) {
502 |     int ret;
503 |     asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(x) : "memory");
504 |     return ret;
505 | }
506 | 
507 | __forceinline__ __device__ void acquire_lock(int* mutex) {
508 |     // To make later memory operations valid, we must use `acquire` for memory semantics
509 |     while (atomic_cas_cta_acquire(mutex, 0, 1) != 0);
510 | }
511 | 
512 | __forceinline__ __device__ void release_lock(int* mutex) {
513 |     // To make previous memory operations visible to other threads, we must use `release` for memory semantics
514 |     atomic_exch_cta_release(mutex, 0);
515 | }
516 | 
517 | // Operation functors
518 | template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
519 | template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
520 | template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
521 | 
522 | // Unified reduction function
523 | template <uint32_t kNumLanes, typename T, typename Op>
524 | __forceinline__ __device__ T warp_reduce(T value, Op op) {
525 |     EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or
526 |                      kNumLanes ==  4 or kNumLanes == 2  or kNumLanes == 1,
527 |                      "Invalid number of lanes");
528 | 
529 |     if constexpr (kNumLanes >= 32) value = op(value, __shfl_xor_sync(0xffffffff, value, 16));
530 |     if constexpr (kNumLanes >= 16) value = op(value, __shfl_xor_sync(0xffffffff, value,  8));
531 |     if constexpr (kNumLanes >=  8) value = op(value, __shfl_xor_sync(0xffffffff, value,  4));
532 |     if constexpr (kNumLanes >=  4) value = op(value, __shfl_xor_sync(0xffffffff, value,  2));
533 |     if constexpr (kNumLanes >=  2) value = op(value, __shfl_xor_sync(0xffffffff, value,  1));
534 |     return value;
535 | }
536 | 
537 | // Convenience aliases
538 | template <uint32_t kNumLanes = 32, typename T>
539 | __forceinline__ __device__ T warp_reduce_sum(T value) {
540 |     return warp_reduce<kNumLanes, T>(value, ReduceSum<T>{});
541 | }
542 | 
543 | template <uint32_t kNumLanes = 32, typename T>
544 | __forceinline__ __device__ T warp_reduce_max(T value) {
545 |     return warp_reduce<kNumLanes, T>(value, ReduceMax<T>{});
546 | }
547 | 
548 | template <uint32_t kNumLanes = 32, typename T>
549 | __forceinline__ __device__ T warp_reduce_min(T value) {
550 |     return warp_reduce<kNumLanes, T>(value, ReduceMin<T>{});
551 | }
552 | 
553 | } // namespace deep_ep
554 | 
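The rounded (power-of-two) FP8 scaling path above (`fast_log2_ceil`, `fast_pow2`, `calculate_fp8_scales`, and the UE8M0-style exponent extraction) is plain bit arithmetic and can be reproduced on the host. The Python sketch below mirrors that arithmetic for illustration only; the helper names are ad-hoc and not part of the library.

```python
import struct

def _bits(x: float) -> int:
    # Reinterpret a float32 as its raw 32-bit pattern (like the CUDA pointer casts above)
    return struct.unpack('<I', struct.pack('<f', x))[0]

def fast_log2_ceil(x: float) -> int:
    bits = _bits(x)
    exp, mantissa = (bits >> 23) & 0xff, bits & ((1 << 23) - 1)
    return exp - 127 + (1 if mantissa != 0 else 0)

def fast_pow2(x: int) -> float:
    # Valid for -126 <= x <= 127, the same range the device-side helper assumes
    return struct.unpack('<f', struct.pack('<I', (x + 127) << 23))[0]

def calculate_fp8_scales(amax: float, round_scale: bool):
    finfo_amax_e4m3 = 448.0
    if round_scale:
        # Round `scale_inv` up to a power of two, so that `amax * scale <= 448`
        exp_scale_inv = fast_log2_ceil(amax / finfo_amax_e4m3)
        return fast_pow2(-exp_scale_inv), fast_pow2(exp_scale_inv)
    return finfo_amax_e4m3 / amax, amax / finfo_amax_e4m3

def extract_ue8m0(scale_inv: float) -> int:
    # UE8M0 keeps only the biased exponent byte of the power-of-two scale
    return _bits(scale_inv) >> 23

# Example: amax = 10 -> scale_inv rounds up to 2 ** -5, so scale = 32
scale, scale_inv = calculate_fp8_scales(10.0, round_scale=True)
assert (scale, scale_inv) == (32.0, 0.03125)
assert extract_ue8m0(scale_inv) == 127 - 5
```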


--------------------------------------------------------------------------------
/deep_ep/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .utils import EventOverlap
4 | from .buffer import Buffer
5 | 
6 | # noinspection PyUnresolvedReferences
7 | from deep_ep_cpp import Config
8 | 
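`Buffer`, `EventOverlap`, and `Config` are the user-facing entry points re-exported here; `Config` comes from the compiled `deep_ep_cpp` extension, so the smoke test below (illustrative only) works only after the wheel has been built and installed, e.g. via `install.sh`.

```python
import deep_ep

# The three user-facing entry points re-exported by the package `__init__`
print(deep_ep.Buffer, deep_ep.EventOverlap, deep_ep.Config)
```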


--------------------------------------------------------------------------------
/deep_ep/utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import subprocess
  3 | import torch
  4 | import torch.distributed as dist
  5 | from typing import Any, Optional, Tuple
  6 | 
  7 | # noinspection PyUnresolvedReferences
  8 | from deep_ep_cpp import Config, EventHandle
  9 | 
 10 | 
 11 | class EventOverlap:
 12 |     """
 13 |     A wrapper class to manage CUDA events, and a convenience helper for communication-computation overlapping.
 14 | 
 15 |     Attributes:
 16 |         event: the CUDA event captured.
 17 |         extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, which may be useful with CUDA graphs.
 18 |     """
 19 | 
 20 |     def __init__(self, event: Optional[EventHandle] = None,
 21 |                  extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None:
 22 |         """
 23 |         Initialize the class.
 24 | 
 25 |         Arguments:
 26 |             event: the CUDA event captured.
 27 |             extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, which may be useful with CUDA graphs.
 28 |         """
 29 |         self.event = event
 30 | 
 31 |         # NOTES: we use extra tensors to achieve stream recording, otherwise,
 32 |         # stream recording will be incompatible with CUDA graph.
 33 |         self.extra_tensors = extra_tensors
 34 | 
 35 |     def current_stream_wait(self) -> None:
 36 |         """
 37 |         The current stream `torch.cuda.current_stream()` waits for the event to finish.
 38 |         """
 39 |         assert self.event is not None
 40 |         self.event.current_stream_wait()
 41 | 
 42 |     def __enter__(self) -> Any:
 43 |         """
 44 |         Utility for overlapping and Python `with` syntax.
 45 | 
 46 |         You can overlap the kernels on the current stream with the following example:
 47 |         ```python
 48 |         event_overlap = event_after_all_to_all_kernels()
 49 |         with event_overlap:
 50 |             do_something_on_current_stream()
 51 |         # After exiting the `with` scope, the current stream will wait for the event to finish.
 52 |         ```
 53 |         """
 54 |         return self
 55 | 
 56 |     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
 57 |         """
 58 |         Utility for overlapping and Python `with` syntax.
 59 | 
 60 |         Please follow the example in the `__enter__` function.
 61 |         """
 62 |         if self.event is not None:
 63 |             self.event.current_stream_wait()
 64 | 
 65 | 
 66 | def check_nvlink_connections(group: dist.ProcessGroup):
 67 |     """
 68 |     Check NVLink connection between every pair of GPUs.
 69 | 
 70 |     Arguments:
 71 |         group: the communication group.
 72 |     """
 73 |     # Check NVLink connection
 74 |     # NOTES: some A100 PCIE GPUs only have pairwise NVLink connections, so only EP2 can be used
 75 |     # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink
 76 |     if 'PCIE' in torch.cuda.get_device_name():
 77 |         assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections'
 78 | 
 79 |         # noinspection PyUnresolvedReferences
 80 |         import pynvml
 81 |         pynvml.nvmlInit()
 82 | 
 83 |         # noinspection PyTypeChecker
 84 |         devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',')
 85 |         physical_device_idx = int(devices[torch.cuda.current_device()])
 86 |         physical_device_indices = [0, ] * group.size()
 87 |         dist.all_gather_object(physical_device_indices, physical_device_idx, group)
 88 | 
 89 |         # Check whether they are all connected via NVLink
 90 |         # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438
 91 |         handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices]
 92 |         for i, handle in enumerate(handles):
 93 |             for j, peer_handle in enumerate(handles):
 94 |                 if i >= j:
 95 |                     continue
 96 |                 status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
 97 |                 assert status == pynvml.NVML_P2P_STATUS_OK,\
 98 |                     f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink'
 99 | 
100 |         # Close NVML
101 |         pynvml.nvmlShutdown()
102 | 
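As a usage note for `EventOverlap`, the sketch below shows the intended overlap pattern with an asynchronous dispatch, following the class docstring and `tests/test_internode.py`. It is illustrative only: `buffer` (an already-constructed `deep_ep.Buffer`), `dispatch_args` (dispatch keyword arguments without `async_finish`), and `run_other_kernels` are assumed placeholders, not names defined in this repository.

```python
# Hypothetical overlap sketch: `buffer`, `dispatch_args` and `run_other_kernels`
# are placeholders assumed to be defined elsewhere
recv_x, recv_topk_idx, recv_topk_weights, recv_counts, handle, event = \
    buffer.dispatch(**dispatch_args, async_finish=True)

# Work launched inside the `with` block runs on the current stream while the
# dispatch finishes; on exit, the current stream waits for the captured event
with event:
    run_other_kernels()
```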


--------------------------------------------------------------------------------
/figures/low-latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepEP/0eee87b8ca19e255b70b5631b8cc42ca1b88c8d7/figures/low-latency.png


--------------------------------------------------------------------------------
/figures/normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepEP/0eee87b8ca19e255b70b5631b8cc42ca1b88c8d7/figures/normal.png


--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
 1 | # Change current directory into project root
 2 | original_dir=$(pwd)
 3 | script_dir=$(dirname "$0")
 4 | cd "$script_dir"
 5 | 
 6 | # Remove old dist file, build, and install
 7 | rm -rf dist
 8 | python setup.py bdist_wheel
 9 | pip install dist/*.whl
10 | 
11 | # Return to the user's original directory
12 | cd "$original_dir"
13 | 


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import subprocess
  3 | import setuptools
  4 | import importlib
  5 | import importlib.resources
  6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
  7 | 
  8 | # Wheel specific: The wheels only include the soname of the host library (libnvshmem_host.so.X)
  9 | def get_nvshmem_host_lib_name():
 10 |     for path in importlib.resources.files('nvidia.nvshmem').iterdir():
 11 |         for file in path.rglob('libnvshmem_host.so.*'):
 12 |             return file.name
 13 |     raise ModuleNotFoundError('libnvshmem_host.so not found')
 14 | 
 15 | if __name__ == '__main__':
 16 |     disable_nvshmem = False
 17 |     nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
 18 |     nvshmem_host_lib = 'libnvshmem_host.so'
 19 |     if nvshmem_dir is None:
 20 |         try:
 21 |             nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0]
 22 |             nvshmem_host_lib = get_nvshmem_host_lib_name()
 23 |             import nvidia.nvshmem as nvshmem
 24 |         except (ModuleNotFoundError, AttributeError, IndexError):
 25 |             print('Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n')
 26 |             disable_nvshmem = True
 27 |     else:
 28 |         disable_nvshmem = False
 29 | 
 30 |     if not disable_nvshmem:
 31 |         assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}'
 32 | 
 33 |     cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
 34 |                  '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
 35 |     nvcc_flags = ['-O3', '-Xcompiler', '-O3']
 36 |     sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu']
 37 |     include_dirs = ['csrc/']
 38 |     library_dirs = []
 39 |     nvcc_dlink = []
 40 |     extra_link_args = []
 41 | 
 42 |     # NVSHMEM flags
 43 |     if disable_nvshmem:
 44 |         cxx_flags.append('-DDISABLE_NVSHMEM')
 45 |         nvcc_flags.append('-DDISABLE_NVSHMEM')
 46 |     else:
 47 |         sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'])
 48 |         include_dirs.extend([f'{nvshmem_dir}/include'])
 49 |         library_dirs.extend([f'{nvshmem_dir}/lib'])
 50 |         nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device'])
 51 |         extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib'])
 52 | 
 53 |     if int(os.getenv('DISABLE_SM90_FEATURES', 0)):
 54 |         # Prefer A100
 55 |         os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0')
 56 | 
 57 |         # Disable some SM90 features: FP8, launch methods, and TMA
 58 |         cxx_flags.append('-DDISABLE_SM90_FEATURES')
 59 |         nvcc_flags.append('-DDISABLE_SM90_FEATURES')
 60 | 
 61 |         # Disable internode and low-latency kernels
 62 |         assert disable_nvshmem
 63 |     else:
 64 |         # Prefer H800 series
 65 |         os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0')
 66 | 
 67 |         # CUDA 12 flags
 68 |         nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10'])
 69 | 
 70 |     # Disable LD/ST tricks, as some CUDA versions do not support `.L1::no_allocate`
 71 |     if os.environ['TORCH_CUDA_ARCH_LIST'].strip() != '9.0':
 72 |         assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1
 73 |         os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
 74 | 
 75 |     # Disable aggressive PTX instructions
 76 |     if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')):
 77 |         cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
 78 |         nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
 79 | 
 80 |     # Put them together
 81 |     extra_compile_args = {
 82 |         'cxx': cxx_flags,
 83 |         'nvcc': nvcc_flags,
 84 |     }
 85 |     if len(nvcc_dlink) > 0:
 86 |         extra_compile_args['nvcc_dlink'] = nvcc_dlink
 87 | 
 88 |     # Summary
 89 |     print(f'Build summary:')
 90 |     print(f' > Sources: {sources}')
 91 |     print(f' > Includes: {include_dirs}')
 92 |     print(f' > Libraries: {library_dirs}')
 93 |     print(f' > Compilation flags: {extra_compile_args}')
 94 |     print(f' > Link flags: {extra_link_args}')
 95 |     print(f' > Arch list: {os.environ["TORCH_CUDA_ARCH_LIST"]}')
 96 |     print(f' > NVSHMEM path: {nvshmem_dir}')
 97 |     print()
 98 | 
 99 |     # noinspection PyBroadException
100 |     try:
101 |         cmd = ['git', 'rev-parse', '--short', 'HEAD']
102 |         revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
103 |     except Exception as _:
104 |         revision = ''
105 | 
106 |     setuptools.setup(
107 |         name='deep_ep',
108 |         version='1.1.0' + revision,
109 |         packages=setuptools.find_packages(
110 |             include=['deep_ep']
111 |         ),
112 |         ext_modules=[
113 |             CUDAExtension(
114 |                 name='deep_ep_cpp',
115 |                 include_dirs=include_dirs,
116 |                 library_dirs=library_dirs,
117 |                 sources=sources,
118 |                 extra_compile_args=extra_compile_args,
119 |                 extra_link_args=extra_link_args
120 |             )
121 |         ],
122 |         cmdclass={
123 |             'build_ext': BuildExtension
124 |         }
125 |     )
126 | 
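The environment variables read above can be set explicitly when building. The sketch below is illustrative only (the NVSHMEM path is a placeholder) and mirrors the `python setup.py bdist_wheel` invocation used by `install.sh`.

```python
import os
import subprocess

# Illustrative build invocation; '/path/to/nvshmem' is a placeholder path
env = dict(os.environ)
env['NVSHMEM_DIR'] = '/path/to/nvshmem'      # omit to fall back to the pip-installed `nvidia.nvshmem` package
env['TORCH_CUDA_ARCH_LIST'] = '9.0'          # the default when SM90 features are enabled
env['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '0'   # opt in to the `.L1::no_allocate` LD/ST path (SM90 only)
subprocess.check_call(['python', 'setup.py', 'bdist_wheel'], env=env)
```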


--------------------------------------------------------------------------------
/tests/test_internode.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | import time
  4 | import torch
  5 | import torch.distributed as dist
  6 | 
  7 | # noinspection PyUnresolvedReferences
  8 | import deep_ep
  9 | from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
 10 | 
 11 | # Test compatibility with low latency functions
 12 | import test_low_latency
 13 | 
 14 | 
 15 | # noinspection PyShadowingNames
 16 | def test_main(args: argparse.Namespace, num_sms: int,
 17 |               local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int,
 18 |               buffer: deep_ep.Buffer, group: dist.ProcessGroup):
 19 |     # Settings
 20 |     num_tokens, hidden = args.num_tokens, args.hidden
 21 |     num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts
 22 | 
 23 |     assert num_experts % num_ranks == 0 and num_local_ranks == 8
 24 |     if local_rank == 0:
 25 |         print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)
 26 | 
 27 |     # Random data
 28 |     x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
 29 |     x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
 30 |     x_e4m3 = per_token_cast_to_fp8(x)
 31 |     x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T)
 32 |     scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
 33 |     group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
 34 |     group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
 35 |     masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
 36 |     topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
 37 |     topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
 38 |     topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
 39 |     rank_idx = topk_idx // (num_experts // num_ranks)
 40 |     rank_idx.masked_fill_(topk_idx == -1, -1)
 41 |     inplace_unique(rank_idx, num_ranks)
 42 |     rdma_rank_idx = rank_idx // num_local_ranks
 43 |     rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
 44 |     inplace_unique(rdma_rank_idx, num_nodes)
 45 | 
 46 |     # RDMA dispatch counts
 47 |     rdma_idx = topk_idx // (num_experts // num_nodes)
 48 |     rdma_idx.masked_fill_(topk_idx == -1, -1)
 49 |     inplace_unique(rdma_idx, num_nodes)
 50 |     num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
 51 | 
 52 |     # Expert meta
 53 |     num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
 54 |     for i in range(num_experts):
 55 |         num_tokens_per_expert[i] = (topk_idx == i).sum()
 56 |     gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
 57 |     dist.all_reduce(gbl_num_tokens_per_expert, group=group)
 58 | 
 59 |     # Rank layout meta
 60 |     num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
 61 |     num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')
 62 |     token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
 63 |     for i in range(num_ranks):
 64 |         num_tokens_per_rank[i] = (rank_idx == i).sum()
 65 |         token_sel = (rank_idx == i).max(dim=-1)[0]
 66 |         count = token_sel.sum().item()
 67 |         tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
 68 |         tokens[:count] = torch.sort(tokens[:count])[0]
 69 |         token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
 70 |     for i in range(num_nodes):
 71 |         num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
 72 |     token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
 73 |     is_token_in_rank = token_idx_in_rank >= 0
 74 |     gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
 75 |     dist.all_reduce(gbl_num_tokens_per_rank, group=group)
 76 | 
 77 |     ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
 78 |         buffer.get_dispatch_layout(topk_idx, num_experts)
 79 |     assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
 80 |     assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
 81 |     assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
 82 |     assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
 83 |     t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
 84 |     if local_rank == 0:
 85 |         print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
 86 |         print('', flush=True)
 87 |     group.barrier()
 88 |     time.sleep(1)
 89 | 
 90 |     # Config
 91 |     rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
 92 |     config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
 93 | 
 94 |     # Test dispatch
 95 |     # noinspection PyShadowingNames
 96 |     def check_data(check_x, recv_gbl_rank_prefix_sum):
 97 |         assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
 98 |         check_start = 0
 99 |         for i in range(num_ranks):
100 |             check_end = recv_gbl_rank_prefix_sum[i].item()
101 |             assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
102 |             check_start = check_end
103 | 
104 |     for previous_mode in (False, True):
105 |         for async_mode in (False, True):
106 |             for current_x in (x_pure_rand, x, x_e4m3):
107 |                 for with_topk in (False, True):
108 |                     if local_rank == 0:
109 |                         print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
110 |                     dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,  'is_token_in_rank': is_token_in_rank,
111 |                                      'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
112 |                     if with_topk:
113 |                         dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
114 |                     if previous_mode:
115 |                         dispatch_args.update({'previous_event': buffer.capture()})
116 |                     recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
117 |                     event.current_stream_wait() if async_mode else ()
118 |                     recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
119 | 
120 |                     # Checks
121 |                     recv_gbl_rank_prefix_sum = handle[-4]
122 |                     assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
123 |                     assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
124 |                     if current_x is not x_pure_rand:
125 |                         check_data(recv_x, recv_gbl_rank_prefix_sum)
126 |                     if with_topk:
127 |                         # Check `topk_idx`
128 |                         assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
129 |                         for i, count in enumerate(recv_num_tokens_per_expert_list):
130 |                             assert recv_topk_idx.eq(i).sum().item() == count
131 | 
132 |                         # Check `topk_weights`
133 |                         if current_x is not x_pure_rand:
134 |                             recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
135 |                             check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
136 | 
137 |                     # Test cached dispatch (must be without top-k stuff)
138 |                     if not with_topk:
139 |                         dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
140 |                         if previous_mode:
141 |                             dispatch_args.update({'previous_event': buffer.capture()})
142 |                         recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
143 |                         event.current_stream_wait() if async_mode else ()
144 |                         recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
145 |                         if current_x is not x_pure_rand:
146 |                             check_data(recv_x, recv_gbl_rank_prefix_sum)
147 | 
148 |                     # Test combine
149 |                     bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
150 |                     bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
151 |                     combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
152 |                     if with_topk:
153 |                         combine_args.update({'topk_weights': recv_topk_weights})
154 |                     if previous_mode:
155 |                         combine_args.update({'previous_event': buffer.capture()})
156 |                     combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
157 |                     event.current_stream_wait() if async_mode else ()
158 |                     check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
159 |                     ref_x = x_pure_rand if current_x is x_pure_rand else x
160 |                     assert calc_diff(check_x, ref_x) < 5e-6
161 |                     if with_topk:
162 |                         check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
163 |                         ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
164 |                         assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
165 | 
166 |                     # For later tuning
167 |                     dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
168 |                     dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
169 |                     combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
170 |                     combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
171 | 
172 |                     if local_rank == 0:
173 |                         print(' passed', flush=True)
174 |     if local_rank == 0:
175 |         print('', flush=True)
176 | 
177 |     # Tune dispatch performance
178 |     best_dispatch_results = None
179 |     fp8_factor = (1 + 4 / 128) / 2
180 |     for current_x in (x_e4m3, x):
181 |         best_time, best_results = 1e10, None
182 |         rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
183 |         nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
184 |         for nvl_chunk_size in range(4, 45, 4):
185 |             for rdma_chunk_size in range(4, 33, 4):
186 |                 config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
187 |                 tune_args = {'x': current_x, 'handle': handle, 'config': config}
188 |                 t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'))
189 |                 if t < best_time:
190 |                     best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)
191 |                 if local_rank == 0:
192 |                     print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
193 |         if local_rank == 0:
194 |             print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
195 |             print('', flush=True)
196 | 
197 |         if isinstance(current_x, tuple):
198 |             # Gather the best FP8 config from rank 0
199 |             best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
200 |             all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
201 |             dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
202 |             best_dispatch_results = all_best_fp8_results_list[0].tolist()
203 |     dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)
204 | 
205 |     dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
206 |                      'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
207 |                      'config': dispatch_config if dispatch_config is not None else config}
208 |     recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
209 | 
210 |     # Tune combine performance
211 |     best_time, best_results = 1e10, None
212 |     for nvl_chunk_size in range(1, 13, 1):
213 |         for rdma_chunk_size in range(8, 33, 4):
214 |             config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
215 |             tune_args = {'x': recv_x, 'handle': handle, 'config': config}
216 |             t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'))
217 |             if local_rank == 0:
218 |                 print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
219 |                 if t < best_time:
220 |                     best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)
221 | 
222 |     if local_rank == 0:
223 |         print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
224 |         print('', flush=True)
225 | 
226 | 
227 | # noinspection PyUnboundLocalVariable,PyShadowingNames
228 | def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
229 |     num_nodes = int(os.getenv('WORLD_SIZE', 1))
230 |     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
231 |     if args.test_ll_compatibility:
232 |         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
233 | 
234 |     num_sms = 24
235 |     num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
236 | 
237 |     buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
238 |                             num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
239 |     assert num_local_ranks == 8 and num_ranks > 8
240 |     torch.manual_seed(rank)
241 | 
242 |     for i in (num_sms, ):
243 |         test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
244 |         if local_rank == 0:
245 |             print('', flush=True)
246 | 
247 |     # Test compatibility with low latency functions
248 |     if args.test_ll_compatibility:
249 |         buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
250 |         test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
251 | 
252 |     # Destroy the buffer runtime and communication group
253 |     buffer.destroy()
254 |     dist.barrier()
255 |     dist.destroy_process_group()
256 | 
257 | 
258 | if __name__ == '__main__':
259 |     parser = argparse.ArgumentParser(description='Test internode EP kernels')
260 |     parser.add_argument('--num-processes', type=int, default=8,
261 |                        help='Number of processes to spawn (default: 8)')
262 |     parser.add_argument('--num-tokens', type=int, default=4096,
263 |                        help='Number of tokens (default: 4096)')
264 |     parser.add_argument('--hidden', type=int, default=7168,
265 |                        help='Hidden dimension size (default: 7168)')
266 |     parser.add_argument('--num-topk-groups', type=int, default=None,
267 |                        help='Number of top-k groups (default: `min(num_nodes, 4)`)')
268 |     parser.add_argument('--num-topk', type=int, default=8,
269 |                        help='Number of top-k experts (default: 8)')
270 |     parser.add_argument('--num-experts', type=int, default=256,
271 |                        help='Number of experts (default: 256')
272 |     parser.add_argument('--test-ll-compatibility', action='store_true',
273 |                         help='whether to test compatibility with low-latency kernels')
274 |     args = parser.parse_args()
275 | 
276 |     # Set default `num_topk_groups` if not provided
277 |     if args.num_topk_groups is None:
278 |         num_nodes = int(os.getenv('WORLD_SIZE', 1))
279 |         args.num_topk_groups = min(num_nodes, 4)
280 | 
281 |     num_processes = args.num_processes
282 |     torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
283 | 


--------------------------------------------------------------------------------
/tests/test_intranode.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import time
  3 | import torch
  4 | import torch.distributed as dist
  5 | 
  6 | # noinspection PyUnresolvedReferences
  7 | import deep_ep
  8 | from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
  9 | 
 10 | # Test compatibility with low latency functions
 11 | import test_low_latency
 12 | 
 13 | 
 14 | # noinspection PyShadowingNames
 15 | def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int,
 16 |               buffer: deep_ep.Buffer, group: dist.ProcessGroup):
 17 |     # Settings
 18 |     num_tokens, hidden = args.num_tokens, args.hidden
 19 |     num_topk, num_experts = args.num_topk, args.num_experts
 20 | 
 21 |     assert num_experts % num_ranks == 0
 22 |     if local_rank == 0:
 23 |         print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
 24 | 
 25 |     # Random data
 26 |     x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
 27 |     x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
 28 |     x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
 29 |     x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
 30 |     scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
 31 |     topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
 32 |     topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
 33 |     topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
 34 |     rank_idx = topk_idx // (num_experts // num_ranks)
 35 |     rank_idx.masked_fill_(topk_idx == -1, -1)
 36 |     inplace_unique(rank_idx, num_ranks)
 37 | 
 38 |     # Expert meta
 39 |     num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
 40 |     for i in range(num_experts):
 41 |         num_tokens_per_expert[i] = (topk_idx == i).sum()
 42 |     gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
 43 |     dist.all_reduce(gbl_num_tokens_per_expert, group=group)
 44 | 
 45 |     # Rank layout meta
 46 |     num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
 47 |     token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
 48 |     for i in range(num_ranks):
 49 |         num_tokens_per_rank[i] = (rank_idx == i).sum()
 50 |         token_sel = (rank_idx == i).max(dim=-1)[0]
 51 |         count = token_sel.sum().item()
 52 |         tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
 53 |         tokens[:count] = torch.sort(tokens[:count])[0]
 54 |         token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
 55 |     token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
 56 |     is_token_in_rank = token_idx_in_rank >= 0
 57 |     gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
 58 |     dist.all_reduce(gbl_num_tokens_per_rank, group=group)
 59 | 
 60 |     ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
 61 |         buffer.get_dispatch_layout(topk_idx, num_experts)
 62 |     assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
 63 |     assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
 64 |     assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
 65 |     t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
 66 |     if local_rank == 0:
 67 |         print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
 68 |         print('', flush=True)
 69 |     group.barrier()
 70 |     time.sleep(1)
 71 | 
 72 |     # Config
 73 |     nvl_buffer_size = 256
 74 |     config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
 75 | 
 76 |     # Test dispatch
 77 |     # noinspection PyShadowingNames
 78 |     def check_data(check_x, rank_prefix_matrix):
 79 |         assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
 80 |         check_start = 0
 81 |         for i in range(num_ranks):
 82 |             check_end = rank_prefix_matrix[i][rank].item()
 83 |             assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
 84 |             check_start = check_end
 85 | 
 86 |     for previous_mode in (False, True):
 87 |         for async_mode in (False, True):
 88 |             for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)):
 89 |                 for with_topk in (False, True):
 90 |                     if local_rank == 0:
 91 |                         print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
 92 |                     dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank,  'is_token_in_rank': is_token_in_rank,
 93 |                                      'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
 94 |                     if with_topk:
 95 |                         dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
 96 |                     if previous_mode:
 97 |                         dispatch_args.update({'previous_event': buffer.capture()})
 98 |                     recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
 99 |                     event.current_stream_wait() if async_mode else ()
100 |                     recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
101 | 
102 |                     # Checks
103 |                     rank_prefix_matrix = handle[0]
104 |                     assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
105 |                     assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
106 |                     if current_x is not x_pure_rand:
107 |                         check_data(recv_x, rank_prefix_matrix)
108 |                     recv_topk_weights_clone = None
109 |                     if with_topk:
110 |                         # Check `topk_idx`
111 |                         assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
112 |                         for i, count in enumerate(recv_num_tokens_per_expert_list):
113 |                             assert recv_topk_idx.eq(i).sum().item() == count
114 | 
115 |                         # Check `topk_weights`
116 |                         recv_topk_weights_clone = recv_topk_weights.clone()
117 |                         if current_x is not x_pure_rand:
118 |                             recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
119 |                             check_data(recv_topk_weights, rank_prefix_matrix)
120 | 
121 |                     # Test `num_worst_tokens != 0`
122 |                     if with_topk:
123 |                         num_worst_tokens = num_tokens * num_ranks
124 |                         dispatch_args.update({'num_worst_tokens': num_worst_tokens})
125 |                         recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
126 |                         event.current_stream_wait() if async_mode else ()
127 |                         recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
128 |                         assert len(empty_list) == 0
129 |                         assert num_worst_tokens == recv_worst_x.size(0)
130 |                         assert num_worst_tokens == recv_worst_topk_idx.size(0)
131 |                         assert num_worst_tokens == recv_worst_topk_weights.size(0)
132 |                         assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
133 |                         assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
134 |                         assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
135 |                         assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
136 | 
137 |                     # Test cached dispatch (must be run without top-k stuff)
138 |                     if not with_topk:
139 |                         dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
140 |                         if previous_mode:
141 |                             dispatch_args.update({'previous_event': buffer.capture()})
142 |                         recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
143 |                         event.current_stream_wait() if async_mode else ()
144 |                         recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
145 |                         if current_x is not x_pure_rand:
146 |                             check_data(recv_x, rank_prefix_matrix)
147 | 
148 |                     # Test combine
149 |                     combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
150 |                     if with_topk:
151 |                         combine_args.update({'topk_weights': recv_topk_weights})
152 |                     if previous_mode:
153 |                         combine_args.update({'previous_event': buffer.capture()})
154 |                     combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
155 |                     event.current_stream_wait() if async_mode else ()
156 |                     check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
157 |                     ref_x = x_pure_rand if current_x is x_pure_rand else x
158 |                     assert calc_diff(check_x, ref_x) < 5e-6
159 |                     if with_topk:
160 |                         check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
161 |                         ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
162 |                         assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
163 | 
164 |                     # For later tuning
165 |                     dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
166 |                     combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
167 | 
168 |                     if local_rank == 0:
169 |                         print(' passed', flush=True)
170 |     if local_rank == 0:
171 |         print('', flush=True)
172 | 
173 |     # Tune dispatch performance
174 |     best_dispatch_results = None
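    |     # FP8 sends 1 byte per element plus one 4-byte FP32 scale per 128 elements,
    |     # versus 2 bytes per element for BF16, hence the (1 + 4 / 128) / 2 ratio below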
175 |     fp8_factor = (1 + 4 / 128) / 2
176 |     for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)):
177 |         best_time, best_results = 1e10, None
178 |         nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
179 |         for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
180 |             if nvl_chunk_size > 0:
181 |                 config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
182 |             else:
183 |                 # Test default config as well
184 |                 deep_ep.Buffer.set_num_sms(num_sms)
185 |                 config = deep_ep.Buffer.get_dispatch_config(num_ranks)
186 |             tune_args = {'x': current_x, 'handle': handle, 'config': config}
187 |             t = bench(lambda: buffer.dispatch(**tune_args))[0]
188 |             if t < best_time and nvl_chunk_size > 0:
189 |                 best_time, best_results = t, (num_sms, nvl_chunk_size)
190 |             if local_rank == 0:
191 |                 print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
192 |                       f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
193 |         if local_rank == 0:
194 |             print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
195 |             print('', flush=True)
196 | 
197 |         # Gather the best config from rank 0 and the first test setting
198 |         if best_dispatch_results is None:
199 |             best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
200 |             all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
201 |             dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
202 |             best_dispatch_results = all_best_fp8_results_list[0].tolist()
203 |     dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)
204 | 
205 |     dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank,
206 |                      'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
207 |                      'config': dispatch_config if dispatch_config is not None else config}
208 |     recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
209 | 
210 |     # Tune combine performance
211 |     best_time, best_results = 1e10, None
212 |     for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
213 |         if nvl_chunk_size > 0:
214 |             config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
215 |         else:
216 |             # Test default config as well
217 |             deep_ep.Buffer.set_num_sms(num_sms)
218 |             config = deep_ep.Buffer.get_combine_config(num_ranks)
219 |         tune_args = {'x': recv_x, 'handle': handle, 'config': config}
220 |         t = bench(lambda: buffer.combine(**tune_args))[0]
221 |         if local_rank == 0:
222 |             print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
223 |                   f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
224 |             if t < best_time and nvl_chunk_size > 0:
225 |                 best_time, best_results = t, (num_sms, nvl_chunk_size)
226 | 
227 |     if local_rank == 0:
228 |         print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
229 |         print('', flush=True)
230 | 
231 | 
232 | # noinspection PyUnboundLocalVariable,PyShadowingNames
233 | def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
234 |     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
235 |     test_ll_compatibility, num_rdma_bytes = False, 0
236 |     if test_ll_compatibility:
237 |         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
238 |         num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
239 | 
240 |     buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
241 |                             num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
242 |     torch.manual_seed(rank)
243 | 
244 |     for i in (24, ):
245 |         test_main(args, i, local_rank, num_ranks, rank, buffer, group)
246 |         if local_rank == 0:
247 |             print('', flush=True)
248 | 
249 |     # Test compatibility with low latency functions
250 |     if test_ll_compatibility:
251 |         buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
252 |         test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
253 | 
254 |     # Destroy the buffer runtime and communication group
255 |     buffer.destroy()
256 |     dist.barrier()
257 |     dist.destroy_process_group()
258 | 
259 | 
260 | if __name__ == '__main__':
261 |     parser = argparse.ArgumentParser(description='Test intranode EP kernels')
262 |     parser.add_argument('--num-processes', type=int, default=8,
263 |                        help='Number of processes to spawn (default: 8)')
264 |     parser.add_argument('--num-tokens', type=int, default=4096,
265 |                        help='Number of tokens (default: 4096)')
266 |     parser.add_argument('--hidden', type=int, default=7168,
267 |                        help='Hidden dimension size (default: 7168)')
268 |     parser.add_argument('--num-topk', type=int, default=8,
269 |                        help='Number of top-k experts (default: 8)')
270 |     parser.add_argument('--num-experts', type=int, default=256,
271 |                        help='Number of experts (default: 256)')
272 |     args = parser.parse_args()
273 | 
274 |     num_processes = args.num_processes
275 |     torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
276 | 


--------------------------------------------------------------------------------
/tests/test_low_latency.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import random
  3 | import torch
  4 | import torch.distributed as dist
  5 | from functools import partial
  6 | 
  7 | import deep_ep
  8 | from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
  9 | 
 10 | 
 11 | def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
 12 |               rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
 13 |               use_logfmt: bool = False, seed: int = 0):
 14 |     torch.manual_seed(seed + rank)
 15 |     random.seed(seed + rank)
 16 | 
 17 |     assert num_experts % num_ranks == 0
 18 |     num_local_experts = num_experts // num_ranks
 19 | 
 20 |     # NOTES: the integers greater than 256 exceed the BF16 precision limit
 21 |     rank_offset = 128
 22 |     assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
 23 | 
 24 |     x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
 25 |     x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
 26 |     x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
 27 |     scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
 28 |     topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
 29 |     topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
 30 | 
 31 |     # Randomly mask some positions
 32 |     for i in range(10):
 33 |         topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
 34 | 
 35 |     # Check dispatch correctness
 36 |     do_check = True
 37 |     hash_value, num_times = 0, 0
 38 |     for current_x in (x, x_pure_rand):
 39 |         for return_recv_hook in (False, True):
 40 |             for dispatch_use_fp8 in (False, True):
 41 |                 for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
 42 |                     for use_ue8m0 in (False, True) if round_scale else (False, ):
 43 |                         num_times += 1
 44 |                         for i in range((num_times % 2) + 1):
 45 |                             cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
 46 |                             packed_recv_x, packed_recv_count, handle, event, hook = \
 47 |                                 buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
 48 |                                                             use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
 49 |                                                             cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
 50 |                                                             async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
 51 |                             hook() if return_recv_hook else event.current_stream_wait()
 52 |                         packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
 53 |                         simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
 54 |                             if dispatch_use_fp8 else packed_recv_x.clone()
 55 |                         all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
 56 |                         dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
 57 |                         for i in range(num_local_experts if do_check else 0):
 58 |                             expert_id = rank * num_local_experts + i
 59 |                             recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
 60 |                             recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
 61 | 
 62 |                             # Check expert indices
 63 |                             int_mask = (2 ** 32) - 1
 64 |                             num_valid_tokens = recv_count.item()
 65 |                             assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
 66 |                             assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {(recv_layout_range & int_mask).sum().item()}'
 67 |                             assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
 68 | 
 69 |                             # Check received data
 70 |                             if current_x is not x_pure_rand:
 71 |                                 recv_x = recv_x[:num_valid_tokens]
 72 |                                 recv_x_amin = recv_x[:, :-128].amin(dim=-1)
 73 |                                 recv_src_info = recv_src_info[:num_valid_tokens]
 74 |                                 assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
 75 |                                 if round_scale:
 76 |                                     assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
 77 |                                 else:
 78 |                                     assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
 79 |                                 for j in range(num_ranks):
 80 |                                     begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
 81 |                                     if not round_scale:
 82 |                                         assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
 83 |                                     assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
 84 |                             if dispatch_use_fp8:
 85 |                                 hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
 86 |                                 hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
 87 |                             else:
 88 |                                 hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 89 | 
 90 |                         # Check combine correctness
 91 |                         for zero_copy in (False, ) if use_logfmt else (False, True):
 92 |                             if zero_copy:
 93 |                                 buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
 94 |                             out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
 95 |                             combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
 96 |                                                                                 use_logfmt=use_logfmt,
 97 |                                                                                 async_finish=not return_recv_hook, zero_copy=zero_copy,
 98 |                                                                                 return_recv_hook=return_recv_hook, out=out)
 99 |                             hook() if return_recv_hook else event.current_stream_wait()
100 |                             if do_check:
101 |                                 diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
102 |                                 assert torch.isnan(combined_x).sum().item() == 0
103 |                                 assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}'
104 |                                 hash_value ^= hash_tensor(combined_x)
105 | 
106 |     # noinspection PyShadowingNames
107 |     def large_gemm_with_hook(hook):
108 |         mat_0 = torch.randn((8192, 8192), dtype=torch.float)
109 |         mat_1 = torch.randn((8192, 8192), dtype=torch.float)
110 |         mat_0 @ mat_1
111 |         hook()
112 | 
113 |     # noinspection PyShadowingNames
114 |     def test_func(return_recv_hook: bool):
115 |         recv_x, recv_count, handle, event, hook = \
116 |             buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
117 |                                         cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
118 |                                         use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
119 |         large_gemm_with_hook(hook) if return_recv_hook else None
120 |         combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
121 |                                                              use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
122 |         large_gemm_with_hook(hook) if return_recv_hook else None
123 | 
124 |     # Calculate bandwidth
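    |     # Per-token FP8 dispatch payload: `hidden` FP8 bytes + one 4-byte scale per 128
    |     # elements + a small fixed per-token overhead (16 bytes); combine uses BF16 (2 bytes/element)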
125 |     num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
126 |     num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
127 |     for i in range(num_tokens):
128 |         num_selections = (topk_idx[i] != -1).sum().item()
129 |         num_dispatch_comm_bytes += num_fp8_bytes * num_selections
130 |         num_combine_comm_bytes += num_bf16_bytes * num_selections
131 | 
132 |     # Dispatch + combine testing
133 |     avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
134 |     print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
135 |           f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
136 | 
137 |     # Separate profiling
138 |     for return_recv_hook in (False, True):
139 |         group.barrier()
140 |         dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
141 |                                              kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
142 |                                              suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
143 |         if not return_recv_hook:
144 |             print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
145 |                   f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
146 |         else:
147 |             print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
148 |                   f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
149 |     return hash_value
150 | 
151 | 
152 | # noinspection PyUnboundLocalVariable,PyShadowingNames
153 | def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
154 |     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
155 |     num_tokens, hidden = args.num_tokens, args.hidden
156 |     num_topk, num_experts = args.num_topk, args.num_experts
157 | 
158 |     num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
159 |     if local_rank == 0:
160 |         print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
161 |     buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
162 |                             num_qps_per_rank=num_experts // num_ranks,
163 |                             allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True)
164 |     test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
165 |               use_logfmt=args.use_logfmt, seed=1)
166 | 
167 |     do_pressure_test = False
168 |     for seed in range(int(1e9) if do_pressure_test else 0):
169 |         if local_rank == 0:
170 |             print(f'Testing with seed {seed} ...', flush=True)
171 |         ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
172 |                              use_logfmt=args.use_logfmt, seed=seed)
173 |         for i in range(20):
174 |             assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
175 |                              use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
176 | 
177 |     # Destroy the buffer runtime and communication group
178 |     buffer.destroy()
179 |     dist.barrier()
180 |     dist.destroy_process_group()
181 | 
182 | 
183 | if __name__ == '__main__':
184 |     # TODO: you may modify NUMA binding for less CPU overhead
185 |     # TODO: buggy with `num_tokens=512`
186 |     parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
187 |     parser.add_argument('--num-processes', type=int, default=8,
188 |                        help='Number of processes to spawn (default: 8)')
189 |     parser.add_argument('--num-tokens', type=int, default=128,
190 |                        help='Number of tokens (default: 128)')
191 |     parser.add_argument('--hidden', type=int, default=7168,
192 |                        help='Hidden dimension size (default: 7168)')
193 |     parser.add_argument('--num-topk', type=int, default=8,
194 |                        help='Number of top-k experts (default: 8)')
195 |     parser.add_argument('--num-experts', type=int, default=288,
196 |                        help='Number of experts (default: 288)')
197 |     parser.add_argument('--disable-nvlink', action='store_true',
198 |                         help='Whether to disable NVLink for testing')
199 |     parser.add_argument('--use-logfmt', action='store_true',
200 |                         help='Whether to test LogFMT combine')
201 |     args = parser.parse_args()
202 | 
203 |     num_processes = args.num_processes
204 |     torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
205 | 


--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
  1 | import inspect
  2 | import json
  3 | import tempfile
  4 | from pathlib import Path
  5 | 
  6 | import numpy as np
  7 | import os
  8 | import sys
  9 | import torch
 10 | import torch.distributed as dist
 11 | from typing import Optional, Union
 12 | 
 13 | 
 14 | def init_dist(local_rank: int, num_local_ranks: int):
 15 |     # NOTES: you may rewrite this function with your own cluster settings
 16 |     ip = os.getenv('MASTER_ADDR', '127.0.0.1')
 17 |     port = int(os.getenv('MASTER_PORT', '8361'))
 18 |     num_nodes = int(os.getenv('WORLD_SIZE', 1))
 19 |     node_rank = int(os.getenv('RANK', 0))
 20 | 
 21 |     sig = inspect.signature(dist.init_process_group)
 22 |     params = {
 23 |         'backend': 'nccl',
 24 |         'init_method': f'tcp://{ip}:{port}',
 25 |         'world_size': num_nodes * num_local_ranks,
 26 |         'rank': node_rank * num_local_ranks + local_rank,
 27 |     }
 28 |     if 'device_id' in sig.parameters:
 29 |         # noinspection PyTypeChecker
 30 |         params['device_id'] = torch.device(f'cuda:{local_rank}')
 31 |     dist.init_process_group(**params)
 32 |     torch.set_default_dtype(torch.bfloat16)
 33 |     torch.set_default_device('cuda')
 34 |     torch.cuda.set_device(local_rank)
 35 | 
 36 |     return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))
 37 | 
 38 | 
 39 | def calc_diff(x: torch.Tensor, y: torch.Tensor):
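    |     # Symmetric relative difference: 1 - 2 * sum(x * y) / sum(x * x + y * y); 0 means identical
    |     # The +1 shift below presumably guards against near-zero denominators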
 40 |     x, y = x.double() + 1, y.double() + 1
 41 |     denominator = (x * x + y * y).sum()
 42 |     sim = 2 * (x * y).sum() / denominator
 43 |     return (1 - sim).item()
 44 | 
 45 | 
 46 | def per_token_cast_to_fp8(x: torch.Tensor):
 47 |     assert x.dim() == 2 and x.size(1) % 128 == 0
 48 |     m, n = x.shape
 49 |     x_view = x.view(m, -1, 128)
 50 |     x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
 51 |     return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
 52 | 
 53 | 
 54 | def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
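    |     # Integer scales are packed exponent bytes (the UE8M0 path in the tests): shifting them
    |     # into the FP32 exponent field and bit-casting to float recovers power-of-two scales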
 55 |     if x_scales.dtype == torch.int:
 56 |         x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
 57 |         x_scales = x_scales.view(dtype=torch.float)
 58 |     x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
 59 |     x_scales = x_scales.view(x_fp8.size(0), -1, 1)
 60 |     return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
 61 | 
 62 | 
 63 | def inplace_unique(x: torch.Tensor, num_slots: int):
 64 |     assert x.dim() == 2
 65 |     mask = x < 0
 66 |     x_padded = x.masked_fill(mask, num_slots)
 67 |     bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
 68 |     bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
 69 |     bin_count = bin_count[:, :num_slots]
 70 |     sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
 71 |     sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
 72 |     sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
 73 |     x[:, :].fill_(-1)
 74 |     valid_len = min(num_slots, x.size(1))
 75 |     x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
 76 | 
 77 | 
 78 | def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
 79 |     num_tokens, num_experts = scores.shape
 80 |     scores = scores.view(num_tokens, num_groups, -1)
 81 |     mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
 82 |     mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
 83 |     return (scores * mask).view(num_tokens, num_experts)
 84 | 
 85 | 
 86 | def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None):
 87 |     # Flush L2 cache with 256 MB data
 88 |     torch.cuda.synchronize()
 89 |     cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
 90 | 
 91 |     # Warmup
 92 |     for _ in range(num_warmups):
 93 |         fn()
 94 | 
 95 |     # Flush L2
 96 |     cache.zero_()
 97 | 
 98 |     # Testing
 99 |     start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
100 |     end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
101 |     for i in range(num_tests):
102 |         # Record
103 |         start_events[i].record()
104 |         fn()
105 |         end_events[i].record()
106 |         if post_fn is not None:
107 |             post_fn()
108 |     torch.cuda.synchronize()
109 | 
110 |     times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
111 |     return np.average(times), np.min(times), np.max(times)
112 | 
113 | 
114 | class empty_suppress:
115 |     def __enter__(self):
116 |         return self
117 | 
118 |     def __exit__(self, *_):
119 |         pass
120 | 
121 | 
122 | class suppress_stdout_stderr:
123 |     def __enter__(self):
124 |         self.outnull_file = open(os.devnull, 'w')
125 |         self.errnull_file = open(os.devnull, 'w')
126 | 
127 |         self.old_stdout_fileno_undup = sys.stdout.fileno()
128 |         self.old_stderr_fileno_undup = sys.stderr.fileno()
129 | 
130 |         self.old_stdout_fileno = os.dup(sys.stdout.fileno())
131 |         self.old_stderr_fileno = os.dup(sys.stderr.fileno())
132 | 
133 |         self.old_stdout = sys.stdout
134 |         self.old_stderr = sys.stderr
135 | 
136 |         os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
137 |         os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
138 | 
139 |         sys.stdout = self.outnull_file
140 |         sys.stderr = self.errnull_file
141 |         return self
142 | 
143 |     def __exit__(self, *_):
144 |         sys.stdout = self.old_stdout
145 |         sys.stderr = self.old_stderr
146 | 
147 |         os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
148 |         os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
149 | 
150 |         os.close(self.old_stdout_fileno)
151 |         os.close(self.old_stderr_fileno)
152 | 
153 |         self.outnull_file.close()
154 |         self.errnull_file.close()
155 | 
156 | 
157 | def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False,
158 |                  trace_path: Optional[str] = None, barrier_comm_profiling: bool = False,
159 |                  num_kernels_per_period: int = 1):
160 |     # Profile
161 |     suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
162 |     with suppress():
163 |         schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
164 |         with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
165 |             for i in range(2):
166 |                 # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
167 |                 if barrier_comm_profiling:
168 |                     lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
169 |                     rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
170 |                     lhs @ rhs
171 |                     dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
172 |                 for _ in range(num_tests):
173 |                     fn()
174 |                 prof.step()
175 | 
176 |     # Parse the profiling table
177 |     assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
178 |     is_tuple = isinstance(kernel_names, tuple)
179 |     prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
180 |     kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
181 |     assert all([isinstance(name, str) for name in kernel_names])
182 |     for name in kernel_names:
183 |         assert sum([name in line for line in prof_lines]) == 1, f'Kernel {name} must appear exactly once in the profiling table'
184 | 
185 |     # Save chrome traces
186 |     if trace_path is not None:
187 |         prof.export_chrome_trace(trace_path)
188 | 
189 |     # Return average kernel durations
190 |     units = {'ms': 1e3, 'us': 1e6}
191 |     kernel_durations = []
192 |     for name in kernel_names:
193 |         for line in prof_lines:
194 |             if name in line:
195 |                 time_str = line.split()[-2]
196 |                 for unit, scale in units.items():
197 |                     if unit in time_str:
198 |                         kernel_durations.append(float(time_str.replace(unit, '')) / scale)
199 |                         break
200 |                 break
201 | 
202 |     # Expand the kernels by periods
203 |     if num_kernels_per_period > 1:
204 |         with tempfile.NamedTemporaryFile(suffix='.json') as tmp:
205 |             prof.export_chrome_trace(tmp.name)
206 |             profile_data = json.loads(Path(tmp.name).read_text())
207 | 
208 |         for i, kernel_name in enumerate(kernel_names):
209 |             events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']]
210 |             events = sorted(events, key=lambda event: event['ts'])
211 |             durations = [event['dur'] / 1e6 for event in events]
212 |             assert len(durations) % num_kernels_per_period == 0
213 |             num_kernel_patterns = len(durations) // num_kernels_per_period
214 |             kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns
215 |                                for j in range(num_kernels_per_period)]
216 | 
217 |     # Return execution durations
218 |     return kernel_durations if is_tuple else kernel_durations[0]
219 | 
220 | 
221 | def hash_tensor(t: torch.Tensor):
222 |     return t.view(torch.int64).sum().item()
223 | 


--------------------------------------------------------------------------------
/third-party/README.md:
--------------------------------------------------------------------------------
 1 | # Install NVSHMEM
 2 | 
 3 | ## Important notices
 4 | 
 5 | **This project is neither sponsored nor supported by NVIDIA.**
 6 | 
 7 | **Use of NVIDIA NVSHMEM is governed by the terms at [NVSHMEM Software License Agreement](https://docs.nvidia.com/nvshmem/api/sla.html).**
 8 | 
 9 | ## Prerequisites
10 | 
11 | Hardware requirements:
12 |    - GPUs inside one node need to be connected by NVLink
13 |    - GPUs across different nodes need to be connected by RDMA devices, see [GPUDirect RDMA Documentation](https://docs.nvidia.com/cuda/gpudirect-rdma/)
14 |    - InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/)
15 |    - For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements)
16 | 
17 | Software requirements:
18 |    - NVSHMEM v3.3.9 or later
19 | 
20 | ## Installation procedure
21 | 
22 | ### 1. Install NVSHMEM binaries
23 | 
24 | NVSHMEM 3.3.9 binaries are available in several formats:
25 |    - Tarballs for [x86_64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-x86_64/libnvshmem-linux-x86_64-3.3.9_cuda12-archive.tar.xz) and [aarch64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-sbsa/libnvshmem-linux-sbsa-3.3.9_cuda12-archive.tar.xz)
26 |    - RPM and deb packages: instructions can be found on the [NVSHMEM installer page](https://developer.nvidia.com/nvshmem-downloads?target_os=Linux)
27 |    - Conda packages through conda-forge
28 |    - pip wheels through PyPI: `pip install nvidia-nvshmem-cu12`
   | 
29 | DeepEP is compatible with upstream NVSHMEM 3.3.9 and later.
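   | 
   | For example (an illustrative sketch only, not an official installation script), installing from the x86_64 tarball into a user-chosen prefix might look like this; the install directory is a placeholder:
   | 
   | ```bash
   | # Placeholder install location; any writable directory works
   | export NVSHMEM_DIR=/opt/nvshmem
   | mkdir -p "${NVSHMEM_DIR}"
   | wget https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-x86_64/libnvshmem-linux-x86_64-3.3.9_cuda12-archive.tar.xz
   | tar -xf libnvshmem-linux-x86_64-3.3.9_cuda12-archive.tar.xz -C "${NVSHMEM_DIR}" --strip-components=1
   | ```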
30 | 
31 | 
32 | ### 2. Enable NVSHMEM IBGDA support
33 | 
34 | NVSHMEM supports two modes with different requirements. Either of the following methods can be used to enable IBGDA support.
35 | 
36 | #### 2.1 Configure NVIDIA driver
37 | 
38 | This configuration enables traditional IBGDA support.
39 | 
40 | Modify `/etc/modprobe.d/nvidia.conf`:
41 | 
42 | ```bash
43 | options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"
44 | ```
45 | 
46 | Update kernel configuration:
47 | 
48 | ```bash
49 | sudo update-initramfs -u
50 | sudo reboot
51 | ```
52 | 
53 | #### 2.2 Install GDRCopy and load the gdrdrv kernel module
54 | 
55 | This configuration enables IBGDA through asynchronous post-send operations assisted by the CPU. More information about CPU-assisted IBGDA can be found in [this blog](https://developer.nvidia.com/blog/enhancing-application-portability-and-compatibility-across-new-platforms-using-nvidia-magnum-io-nvshmem-3-0/#cpu-assisted_infiniband_gpu_direct_async%C2%A0).
56 | It comes with a small performance penalty, but can be used when modifying the driver regkeys is not an option.
57 | 
58 | Download GDRCopy:
59 | GDRCopy is available as prebuilt deb and rpm packages [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/), or as source code on the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy).
60 | 
61 | Install GDRCopy following the instructions on the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy?tab=readme-ov-file#build-and-installation); a minimal source-build sketch is shown below.
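   | 
   | As a rough illustration (the GDRCopy README is the authoritative reference; the prefix and CUDA path below are placeholders), a source build and module load typically look like:
   | 
   | ```bash
   | git clone https://github.com/NVIDIA/gdrcopy.git
   | cd gdrcopy
   | # Build and install the library and tools (adjust prefix and CUDA path for your system)
   | sudo make prefix=/opt/gdrcopy CUDA=/usr/local/cuda all install
   | # Load the gdrdrv kernel module
   | sudo ./insmod.sh
   | # Confirm the module is loaded
   | lsmod | grep gdrdrv
   | ```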
62 | 
63 | ## Post-installation configuration
64 | 
65 | When not installing NVSHMEM from RPM or deb packages, set the following environment variables in your shell configuration:
66 | 
67 | ```bash
68 | export NVSHMEM_DIR=/path/to/your/dir/to/install  # Use for DeepEP installation
69 | export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
70 | export PATH="${NVSHMEM_DIR}/bin:$PATH"
71 | ```
72 | 
73 | ## Verification
74 | 
75 | ```bash
76 | nvshmem-info -a # Should display details of nvshmem
77 | ```
78 | 


--------------------------------------------------------------------------------
/third-party/nvshmem.patch:
--------------------------------------------------------------------------------
  1 | From 9e6cc27cceb3130784e4ea7b61ea3171156365fd Mon Sep 17 00:00:00 2001
  2 | From: Shangyan Zhou <sy.zhou@deepseek.com>
  3 | Date: Fri, 20 Dec 2024 10:57:12 +0800
  4 | Subject: [PATCH 1/4] Change QP creating order.
  5 | 
  6 | ---
  7 |  src/modules/transport/ibgda/ibgda.cpp | 13 ++++++++-----
  8 |  1 file changed, 8 insertions(+), 5 deletions(-)
  9 | 
 10 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
 11 | index ef325cd..286132e 100644
 12 | --- a/src/modules/transport/ibgda/ibgda.cpp
 13 | +++ b/src/modules/transport/ibgda/ibgda.cpp
 14 | @@ -2936,17 +2936,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id
 15 |          INFO(ibgda_state->log_level, "Creating %d RC QPs", device->rc.num_eps_per_pe);
 16 |          for (int i = 0; i < num_rc_eps; ++i) {
 17 |              // Do not create loopback to self
 18 | -            if (i / device->rc.num_eps_per_pe == mype) {
 19 | +            int dst_pe = (i + 1 + mype) % n_pes;
 20 | +            int offset = i / n_pes;
 21 | +            int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset;
 22 | +            if (dst_pe == mype) {
 23 |                  continue;
 24 |              }
 25 | -            status = ibgda_create_qp(&device->rc.eps[i], device, portid, i,
 26 | +            status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i,
 27 |                                       NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC);
 28 |              NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
 29 | -                                  "ibgda_create_dci failed on RC #%d.", i);
 30 | +                                  "ibgda_create_dci failed on RC #%d.", mapped_i);
 31 | 
 32 | -            status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device);
 33 | +            status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device);
 34 |              NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
 35 | -                                  "ibgda_get_rc_handle failed on RC #%d.", i);
 36 | +                                  "ibgda_get_rc_handle failed on RC #%d.", mapped_i);
 37 |          }
 38 | 
 39 |          if (num_rc_eps) {
 40 | --
 41 | 2.25.1
 42 | 
 43 | 
 44 | From b11d41e4f3727f2f6ccc00a8c852e59e2ee33c8a Mon Sep 17 00:00:00 2001
 45 | From: Shangyan Zhou <sy.zhou@deepseek.com>
 46 | Date: Fri, 10 Jan 2025 11:53:38 +0800
 47 | Subject: [PATCH 2/4] Add recv queue and recv cq for rc qps.
 48 | 
 49 | Let the ibgda rc qps use regular recv queue.
 50 | 
 51 | Add recv queue to ibgda dev qp.
 52 | 
 53 | IBGDA create recv cq
 54 | 
 55 | Setup recv cq.
 56 | 
 57 | fix recv queue.
 58 | 
 59 | Remove some useless idx.
 60 | 
 61 | Longer recv queue.
 62 | ---
 63 |  .../nvshmem_common_ibgda.h                    | 19 +++++-
 64 |  src/modules/transport/ibgda/ibgda.cpp         | 65 ++++++++++++++++---
 65 |  2 files changed, 71 insertions(+), 13 deletions(-)
 66 | 
 67 | diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
 68 | index 8b8a263..1be3dec 100644
 69 | --- a/src/include/device_host_transport/nvshmem_common_ibgda.h
 70 | +++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
 71 | @@ -168,14 +168,17 @@ typedef struct {
 72 |          uint64_t get_head;    // last wqe idx + 1 with a "fetch" operation (g, get, amo_fetch)
 73 |          uint64_t get_tail;    // last wqe idx + 1 polled with cst; get_tail > get_head is possible
 74 |      } tx_wq;
 75 | +    struct {
 76 | +        uint64_t resv_head;   // last reserved wqe idx + 1
 77 | +    } rx_wq;
 78 |      struct {
 79 |          uint64_t head;
 80 |          uint64_t tail;
 81 |      } ibuf;
 82 |      char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
 83 |  } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
 84 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96,
 85 | -              "ibgda_device_qp_management_v1 must be 96 bytes.");
 86 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
 87 | +              "ibgda_device_qp_management_v1 must be 104 bytes.");
 88 | 
 89 |  typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
 90 | 
 91 | @@ -199,9 +202,19 @@ typedef struct nvshmemi_ibgda_device_qp {
 92 |          // May point to mvars.prod_idx or internal prod_idx
 93 |          uint64_t *prod_idx;
 94 |      } tx_wq;
 95 | +    struct {
 96 | +        uint16_t nwqes;
 97 | +        uint64_t tail;
 98 | +        void *wqe;
 99 | +        __be32 *dbrec;
100 | +        void *bf;
101 | +        nvshmemi_ibgda_device_cq_t *cq;
102 | +        // May point to mvars.prod_idx or internal prod_idx
103 | +        uint64_t *prod_idx;
104 | +    } rx_wq;
105 |      nvshmemi_ibgda_device_qp_management_v1 mvars;  // management variables
106 |  } nvshmemi_ibgda_device_qp_v1;
107 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, "ibgda_device_qp_v1 must be 184 bytes.");
108 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
109 | 
110 |  typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
111 | 
112 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
113 | index 286132e..e0b2d5c 100644
114 | --- a/src/modules/transport/ibgda/ibgda.cpp
115 | +++ b/src/modules/transport/ibgda/ibgda.cpp
116 | @@ -198,6 +198,7 @@ struct ibgda_ep {
117 |      off_t dbr_offset;
118 | 
119 |      struct ibgda_cq *send_cq;
120 | +    struct ibgda_cq *recv_cq;
121 |      struct ibv_ah *ah;
122 | 
123 |      uint32_t user_index;
124 | @@ -1538,7 +1539,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
125 | 
126 |      struct ibv_context *context = device->context;
127 | 
128 | -    unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes;
129 | +    // Each RC qp has one send CQ and one recv CQ.
130 | +    unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2;
131 | 
132 |      assert(ibgda_qp_depth > 0);
133 |      size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
134 | @@ -1701,7 +1703,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
135 |      }
136 | 
137 |      // Allocate and map WQ buffer for all QPs.
138 | -    wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB;  // num_wqebb is always a power of 2
139 | +    // Todo: reduce the size of wq buffer.
140 | +    wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2;  // num_wqebb is always a power of 2
141 |      wq_buf_size = wq_buf_size_per_qp * num_eps;
142 |      status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE);
143 |      NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "cannot allocate wq buf.\n");
144 | @@ -1882,8 +1885,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
145 |      int cqe_version = 0;
146 | 
147 |      struct ibgda_cq *send_cq = NULL;
148 | +    struct ibgda_cq *recv_cq = NULL;
149 | 
150 |      size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
151 | +    size_t num_recv_wqe = ibgda_qp_depth;
152 | +    size_t recv_wqe_size = 16;
153 | 
154 |      int status = 0;
155 | 
156 | @@ -1911,6 +1917,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
157 |      status = ibgda_create_cq(&send_cq, device);
158 |      NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
159 | 
160 | +    if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
161 | +        status = ibgda_create_cq(&recv_cq, device);
162 | +        NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
163 | +    }
164 | +
165 |      ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep));
166 |      NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
167 |                              "Unable to allocate mem for ep.\n");
168 | @@ -1939,12 +1950,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
169 |      DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED);
170 |      DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn);
171 |      DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id);  // BF register
172 | -    DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE);        // Shared Receive Queue
173 | -    DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
174 |      DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn);
175 | -    DEVX_SET(qpc, qp_context, cqn_rcv, device->qp_shared_object.rcqn);
176 | +    DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn);
177 |      DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb));
178 | -    DEVX_SET(qpc, qp_context, log_rq_size, 0);
179 |      DEVX_SET(qpc, qp_context, cs_req, 0);                                     // Disable CS Request
180 |      DEVX_SET(qpc, qp_context, cs_res, 0);                                     // Disable CS Response
181 |      DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);  // Enable dbr_umem_id
182 | @@ -1953,6 +1961,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
183 |      DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id);  // DBR buffer
184 |      DEVX_SET(qpc, qp_context, user_index, qp_idx);
185 |      DEVX_SET(qpc, qp_context, page_offset, 0);
186 | +    if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC){
187 | +        DEVX_SET(qpc, qp_context, rq_type, 0);        // Regular recv queue
188 | +        DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe)); // 4 wqe
189 | +        DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4); // max recv wqe size = 16B
190 | +    } else {
191 | +        DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE);        // Shared Receive Queue, DC must use this.
192 | +        DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
193 | +        DEVX_SET(qpc, qp_context, log_rq_size, 0);
194 | +    }
195 | 
196 |      ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));
197 |      NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out,
198 | @@ -1962,9 +1979,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
199 |      ep->portid = portid;
200 | 
201 |      ep->sq_cnt = num_wqebb;
202 | -    ep->sq_buf_offset = 0;
203 | +    ep->sq_buf_offset = num_recv_wqe * recv_wqe_size;
204 | 
205 | -    ep->rq_cnt = 0;
206 | +    ep->rq_cnt = num_recv_wqe;
207 |      ep->rq_buf_offset = 0;
208 | 
209 |      ep->wq_mobject = device->qp_shared_object.wq_mobject;
210 | @@ -1978,6 +1995,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
211 |      ep->uar_mobject = uar_mobject;
212 | 
213 |      ep->send_cq = send_cq;
214 | +    ep->recv_cq = recv_cq;
215 | 
216 |      ep->qp_type = qp_type;
217 | 
218 | @@ -1989,6 +2007,7 @@ out:
219 |      if (status) {
220 |          if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject);
221 |          if (send_cq) ibgda_destroy_cq(send_cq);
222 | +        if (recv_cq) ibgda_destroy_cq(recv_cq);
223 |          if (ep) free(ep);
224 |      }
225 | 
226 | @@ -2287,6 +2306,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) {
227 |          ibgda_destroy_cq(ep->send_cq);
228 |      }
229 | 
230 | +    if (ep->recv_cq) {
231 | +        ibgda_destroy_cq(ep->recv_cq);
232 | +    }
233 | +
234 |      if (ep->ah) {
235 |          ftable.destroy_ah(ep->ah);
236 |      }
237 | @@ -2318,7 +2341,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
238 |      dev_qp->qpn = ep->qpn;
239 | 
240 |      assert(ep->wq_mobject->has_gpu_mapping);
241 | -    dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset);
242 | +    dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset);
243 | 
244 |      if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) {
245 |          assert(ep->dbr_mobject->has_gpu_mapping);
246 | @@ -2330,6 +2353,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
247 |      }
248 | 
249 |      dev_qp->tx_wq.nwqes = ep->sq_cnt;
250 | +    if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
251 | +        dev_qp->rx_wq.nwqes = ep->rq_cnt;
252 | +        dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset);
253 | +        dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset);
254 | +        dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr;
255 | +    }
256 | 
257 |      ibuf_dci_start = (uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr;
258 |      ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps);
259 | @@ -2379,6 +2408,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
260 |      nvshmemi_ibgda_device_cq_t *cq_d = NULL;
261 |      nvshmemi_ibgda_device_cq_t *cq_h = NULL;
262 | 
263 | +    nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL;
264 | +    nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL;
265 | +
266 |      uint8_t *qp_group_switches_d = NULL;
267 | 
268 |      const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars);
269 | @@ -2386,6 +2418,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
270 |      const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx);
271 |      const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
272 |      const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
273 | +    const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
274 | 
275 |      nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
276 |      nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
277 | @@ -2421,7 +2454,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
278 |          num_dct_handles += device->dct.num_eps * n_pes;
279 |          num_dci_handles += device->dci.num_eps;
280 |          num_rc_handles += device->rc.num_eps_per_pe * n_pes;
281 | -        num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1));
282 | +        num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2);
283 |          num_shared_dci_handles += device->dci.num_shared_eps;
284 |      }
285 |      assert(num_dci_handles - num_shared_dci_handles >= 0);
286 | @@ -2456,6 +2489,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
287 |      for (int i = 0; i < num_cq_handles; i++) {
288 |          nvshmemi_init_ibgda_device_cq(cq_h[i]);
289 |      }
290 | +
291 | +    recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h));
292 | +    NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "recv_cq calloc err.");
293 | +    nvshmemi_init_ibgda_device_cq(recv_cq_h[0]);
294 |      /* allocate host memory for dct, rc, cq, dci end */
295 | 
296 |      /* allocate device memory for dct, rc, cq, dci start */
297 | @@ -2559,6 +2596,14 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
298 |                  }
299 | 
300 |                  ++cq_idx;
301 | +
302 | +                rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx];
303 | +
304 | +                ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
305 | +                cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
306 | +                cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
307 | +                cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
308 | +                ++cq_idx;
309 |              }
310 |          }
311 |      }
312 | --
313 | 2.25.1
314 | 
315 | 
316 | From af479f9f23103d4a1579fae38676d6b3022df887 Mon Sep 17 00:00:00 2001
317 | From: Shangyan Zhou <sy.zhou@deepseek.com>
318 | Date: Sat, 8 Feb 2025 18:02:39 +0800
319 | Subject: [PATCH 3/4] Maintain recv queue's cons_idx.
320 | 
321 | ---
322 |  src/include/device_host_transport/nvshmem_common_ibgda.h | 5 +++--
323 |  src/modules/transport/ibgda/ibgda.cpp                    | 6 ++++--
324 |  2 files changed, 7 insertions(+), 4 deletions(-)
325 | 
326 | diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
327 | index 1be3dec..ea1e284 100644
328 | --- a/src/include/device_host_transport/nvshmem_common_ibgda.h
329 | +++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
330 | @@ -170,6 +170,7 @@ typedef struct {
331 |      } tx_wq;
332 |      struct {
333 |          uint64_t resv_head;   // last reserved wqe idx + 1
334 | +        uint64_t cons_idx;    // polled wqe idx + 1 (consumer index + 1)
335 |      } rx_wq;
336 |      struct {
337 |          uint64_t head;
338 | @@ -177,7 +178,7 @@ typedef struct {
339 |      } ibuf;
340 |      char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
341 |  } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
342 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
343 | -              "ibgda_device_qp_management_v1 must be 104 bytes.");
344 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112,
345 | +              "ibgda_device_qp_management_v1 must be 112 bytes.");
346 | 
347 |  typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
348 | @@ -214,7 +215,7 @@ typedef struct nvshmemi_ibgda_device_qp {
349 |      } rx_wq;
350 |      nvshmemi_ibgda_device_qp_management_v1 mvars;  // management variables
351 |  } nvshmemi_ibgda_device_qp_v1;
352 | -static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
353 | +static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, "ibgda_device_qp_v1 must be 256 bytes.");
354 | 
355 |  typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
356 | 
357 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
358 | index e0b2d5c..bc339c5 100644
359 | --- a/src/modules/transport/ibgda/ibgda.cpp
360 | +++ b/src/modules/transport/ibgda/ibgda.cpp
361 | @@ -1067,7 +1067,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) {
362 |          ibgda_host_mem_free(mobject);
363 |  }
364 | 
365 | -static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) {
366 | +static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) {
367 |      int status = 0;
368 | 
369 |      struct ibgda_cq *gcq = NULL;
370 | @@ -1118,7 +1118,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device)
371 |      cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context);
372 |      DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);
373 |      DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B);
374 | -    DEVX_SET(cqc, cq_context, cc, 0x1);  // Use collapsed CQ
375 | +    DEVX_SET(cqc, cq_context, cc, cc);  // Use collapsed CQ
376 |      DEVX_SET(cqc, cq_context, oi, 0x1);  // Allow overrun
377 |      DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id);
378 |      DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe));
379 | @@ -2419,6 +2419,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
380 |      const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
381 |      const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
382 |      const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
383 | +    const size_t rx_cons_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx);
384 | 
385 |      nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
386 |      nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
387 | @@ -2601,6 +2602,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
388 | 
389 |                  ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
390 |                  cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
391 | +                cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset);
392 |                  cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
393 |                  cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
394 |                  ++cq_idx;
395 | --
396 | 2.25.1
397 | 
398 | 
399 | From e0ba3fa21b4b633b481c6684c3ad04f2670c8df4 Mon Sep 17 00:00:00 2001
400 | From: Shangyan Zhou <sy.zhou@deepseek.com>
401 | Date: Tue, 11 Feb 2025 11:00:57 +0800
402 | Subject: [PATCH 4/4] Init rx_wq counters.
403 | 
404 | ---
405 |  src/include/device_host_transport/nvshmem_common_ibgda.h | 2 ++
406 |  1 file changed, 2 insertions(+)
407 | 
408 | diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
409 | index ea1e284..e6640d6 100644
410 | --- a/src/include/device_host_transport/nvshmem_common_ibgda.h
411 | +++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
412 | @@ -46,6 +46,8 @@
413 |          qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \
414 |          qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \
415 |          qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \
416 | +        qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \
417 | +        qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \
418 |          qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                         \
419 |          qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                         \
420 |      } while (0);
421 | --
422 | 2.25.1
423 | 
424 | diff --git a/src/modules/transport/common/transport_ib_common.cpp b/src/modules/transport/common/transport_ib_common.cpp
425 | index c89f408..f99018a 100644
426 | --- a/src/modules/transport/common/transport_ib_common.cpp
427 | +++ b/src/modules/transport/common/transport_ib_common.cpp
428 | @@ -26,6 +26,9 @@ int nvshmemt_ib_common_nv_peer_mem_available() {
429 |      if (access("/sys/kernel/mm/memory_peers/nvidia-peermem/version", F_OK) == 0) {
430 |          return NVSHMEMX_SUCCESS;
431 |      }
432 | +    if (access("/sys/module/nvidia_peermem/version", F_OK) == 0) {
433 | +        return NVSHMEMX_SUCCESS;
434 | +    }
435 |  
436 |      return NVSHMEMX_ERROR_INTERNAL;
437 |  }
438 | 
439 | 
440 | From 099f608fcd9a1d34c866ad75d0af5d02d2020374 Mon Sep 17 00:00:00 2001
441 | From: Kaichao You <youkaichao@gmail.com>
442 | Date: Tue, 10 Jun 2025 00:35:03 -0700
443 | Subject: [PATCH] remove gdrcopy dependency
444 | 
445 | ---
446 |  src/modules/transport/ibgda/ibgda.cpp | 6 ++++++
447 |  1 file changed, 6 insertions(+)
448 | 
449 | diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
450 | index ef325cd..16ee09c 100644
451 | --- a/src/modules/transport/ibgda/ibgda.cpp
452 | +++ b/src/modules/transport/ibgda/ibgda.cpp
453 | @@ -406,6 +406,7 @@ static size_t ibgda_get_host_page_size() {
454 |      return host_page_size;
455 |  }
456 | 
457 | +#ifdef NVSHMEM_USE_GDRCOPY
458 |  int nvshmemt_ibgda_progress(nvshmem_transport_t t) {
459 |      nvshmemt_ibgda_state_t *ibgda_state = (nvshmemt_ibgda_state_t *)t->state;
460 |      int n_devs_selected = ibgda_state->n_devs_selected;
461 | @@ -459,6 +460,11 @@ int nvshmemt_ibgda_progress(nvshmem_transport_t t) {
462 |      }
463 |      return 0;
464 |  }
465 | +#else
466 | +int nvshmemt_ibgda_progress(nvshmem_transport_t t) {
467 | +    return NVSHMEMX_ERROR_NOT_SUPPORTED;
468 | +}
469 | +#endif
470 | 
471 |  int nvshmemt_ibgda_show_info(struct nvshmem_transport *transport, int style) {
472 |      NVSHMEMI_ERROR_PRINT("ibgda show info not implemented");
473 | --
474 | 2.34.1
475 | 


--------------------------------------------------------------------------------