├── tests
│   ├── __init__.py
│   ├── fabric_lib
│   │   ├── __init__.py
│   │   ├── test_types.py
│   │   └── test_handle.py
│   ├── p2p_all_to_all
│   │   ├── __init__.py
│   │   └── data.py
│   ├── fabric.py
│   └── markers.py
├── benchmarks
│   └── __init__.py
├── python
│   └── pplx_garden
│       ├── py.typed
│       ├── __init__.py
│       ├── kernels
│       │   ├── __init__.py
│       │   └── all_to_all.py
│       ├── native
│       │   ├── __init__.py
│       │   ├── p2p_all_to_all.py
│       │   ├── cumem.py
│       │   ├── cumem.pyi
│       │   └── p2p_all_to_all.pyi
│       ├── utils
│       │   ├── __init__.py
│       │   ├── logging_utils.py
│       │   └── math.py
│       ├── distributed
│       │   ├── __init__.py
│       │   ├── distributed_ops.py
│       │   ├── nccl_all_reduce.py
│       │   └── parallel_group.py
│       ├── fabric_lib.py
│       └── fabric_lib.pyi
├── fabric-lib
│   ├── libfabric-sys
│   │   ├── .gitignore
│   │   ├── wrapper.h
│   │   ├── Cargo.toml
│   │   ├── src
│   │   │   └── lib.rs
│   │   └── build.rs
│   ├── src
│   │   ├── utils
│   │   │   ├── mod.rs
│   │   │   ├── defer.rs
│   │   │   ├── hex.rs
│   │   │   ├── memory.rs
│   │   │   └── obj_pool.rs
│   │   ├── verbs
│   │   │   ├── mod.rs
│   │   │   ├── verbs_devinfo.rs
│   │   │   └── verbs_address.rs
│   │   ├── efa
│   │   │   ├── mod.rs
│   │   │   ├── efa_mr.rs
│   │   │   └── efa_devinfo.rs
│   │   ├── provider_dispatch.rs
│   │   ├── lib.rs
│   │   ├── provider.rs
│   │   ├── error.rs
│   │   ├── rdma_op.rs
│   │   ├── mr.rs
│   │   ├── transfer_engine_builder.rs
│   │   ├── host_buffer.rs
│   │   ├── imm_count.rs
│   │   ├── interface.rs
│   │   └── api.rs
│   ├── libibverbs-sys
│   │   ├── src
│   │   │   └── lib.rs
│   │   ├── Cargo.toml
│   │   ├── wrapper.h
│   │   └── build.rs
│   ├── fabric-debug
│   │   └── Cargo.toml
│   └── Cargo.toml
├── rust-toolchain.toml
├── rust
│   ├── cuda-lib
│   │   ├── cudart-sys
│   │   │   ├── wrapper.h
│   │   │   ├── src
│   │   │   │   └── lib.rs
│   │   │   ├── Cargo.toml
│   │   │   └── build.rs
│   │   ├── gdrapi-sys
│   │   │   ├── src
│   │   │   │   └── lib.rs
│   │   │   ├── Cargo.toml
│   │   │   └── build.rs
│   │   ├── cuda-sys
│   │   │   ├── Cargo.toml
│   │   │   ├── src
│   │   │   │   └── lib.rs
│   │   │   └── build.rs
│   │   ├── src
│   │   │   ├── device.rs
│   │   │   ├── test_driver.rs
│   │   │   ├── lib.rs
│   │   │   ├── test_gdr.rs
│   │   │   ├── error.rs
│   │   │   ├── event.rs
│   │   │   ├── driver.rs
│   │   │   ├── rt.rs
│   │   │   ├── mem.rs
│   │   │   └── gdr.rs
│   │   └── Cargo.toml
│   ├── build-utils
│   │   ├── Cargo.toml
│   │   └── src
│   │       └── lib.rs
│   ├── thread-lib
│   │   ├── Cargo.toml
│   │   └── src
│   │       └── lib.rs
│   ├── proc-lib
│   │   ├── Cargo.toml
│   │   └── src
│   │       └── lib.rs
│   ├── torch-lib
│   │   ├── Cargo.toml
│   │   ├── src
│   │   │   ├── torch_lib.h
│   │   │   ├── test_torch.rs
│   │   │   ├── lib.rs
│   │   │   └── torch_lib.cc
│   │   └── build.rs
│   └── logging-lib
│       ├── Cargo.toml
│       └── src
│           └── lib.rs
├── MANIFEST.in
├── rustfmt.toml
├── p2p-all-to-all
│   ├── src
│   │   ├── lib.rs
│   │   └── a2a_handles.rs
│   ├── a2a-kernels
│   │   ├── Cargo.toml
│   │   ├── src
│   │   │   ├── core
│   │   │   │   ├── common_utils.h
│   │   │   │   ├── combine_utils.cuh
│   │   │   │   ├── memory.cuh
│   │   │   │   ├── vector.cuh
│   │   │   │   ├── launch_utils.cuh
│   │   │   │   └── device_utils.cuh
│   │   │   ├── a2a
│   │   │   │   └── a2a_kernels.h
│   │   │   └── lib.rs
│   │   └── build.rs
│   └── Cargo.toml
├── .gitignore
├── python-ext
│   ├── src
│   │   ├── lib.rs
│   │   ├── py_device.rs
│   │   └── py_p2p_all_to_all.rs
│   └── Cargo.toml
├── scripts
│   └── run-docker.sh
├── LICENSE
├── Cargo.toml
├── docker
│   └── dev.Dockerfile
├── pyproject.toml
└── README.md

/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/pplx_garden/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/fabric_lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/pplx_garden/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/p2p_all_to_all/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/pplx_garden/kernels/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/pplx_garden/native/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/pplx_garden/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fabric-lib/libfabric-sys/.gitignore:
--------------------------------------------------------------------------------
1 | tmp/
2 | artifacts/
--------------------------------------------------------------------------------
/rust-toolchain.toml:
--------------------------------------------------------------------------------
1 | [toolchain]
2 | channel = "1.91.0"
3 | profile = "default"
--------------------------------------------------------------------------------
/rust/cuda-lib/cudart-sys/wrapper.h:
--------------------------------------------------------------------------------
1 | #include <cuda_runtime.h>
2 | #include <cuda_profiler_api.h>
3 |
--------------------------------------------------------------------------------
/fabric-lib/src/utils/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod defer;
2 | pub mod hex;
3 | pub mod memory;
4 | pub mod obj_pool;
--------------------------------------------------------------------------------
/rust/build-utils/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "build-utils"
4 | publish = false
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include Cargo.toml
2 | graft fabric-lib
3 | graft p2p-all-to-all
4 | graft python-ext
5 | graft rust
--------------------------------------------------------------------------------
/rust/cuda-lib/cudart-sys/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(warnings)]
2 | include!(concat!(env!("OUT_DIR"), "/cudart-bindings.rs"));
--------------------------------------------------------------------------------
/rust/cuda-lib/gdrapi-sys/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(warnings)]
2 | include!(concat!(env!("OUT_DIR"), "/gdrapi-bindings.rs"));
--------------------------------------------------------------------------------
/fabric-lib/libibverbs-sys/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(warnings)]
2 | include!(concat!(env!("OUT_DIR"), "/libibverbs-bindings.rs"));
3 |
--------------------------------------------------------------------------------
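Each of the *-sys crates above is only two lines long because the real work happens at build time: a build script (several appear later in this dump) runs bindgen, writes the generated bindings into OUT_DIR, and lib.rs splices them in with include!. A minimal sketch of that pairing, with a hypothetical crate name (mylib) for illustration:

// build.rs (sketch): generate bindings into OUT_DIR at build time.
use std::{env, path::PathBuf};

fn main() {
    let bindings = bindgen::Builder::default()
        .header("wrapper.h")
        .generate()
        .expect("unable to generate bindings");
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    bindings
        .write_to_file(out_dir.join("mylib-bindings.rs"))
        .expect("couldn't write bindings");
    // Tell rustc which native library the bindings refer to.
    println!("cargo:rustc-link-lib=mylib");
}

// src/lib.rs (sketch): splice the generated file into the crate.
// #![allow(warnings)]
// include!(concat!(env!("OUT_DIR"), "/mylib-bindings.rs"));

The #![allow(warnings)] is there because bindgen output does not follow Rust naming conventions.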
/rustfmt.toml:
--------------------------------------------------------------------------------
1 | style_edition = "2024"
2 | max_width = 88
3 | newline_style = "Unix"
4 | use_field_init_shorthand = true
5 | use_small_heuristics = "Max"
--------------------------------------------------------------------------------
/rust/thread-lib/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "thread-lib"
4 |
5 | [dependencies]
6 | libc = { workspace = true }
7 | syscalls = { workspace = true }
--------------------------------------------------------------------------------
/p2p-all-to-all/src/lib.rs:
--------------------------------------------------------------------------------
1 | mod a2a_context;
2 | mod a2a_handles;
3 | mod a2a_worker;
4 |
5 | pub use a2a_context::AllToAllContext;
6 | pub use a2a_handles::AllToAllRankHandle;
--------------------------------------------------------------------------------
/fabric-lib/src/verbs/mod.rs:
--------------------------------------------------------------------------------
1 | mod verbs_address;
2 | mod verbs_devinfo;
3 | mod verbs_domain;
4 | mod verbs_qp;
5 | mod verbs_rdma_op;
6 |
7 | pub use verbs_devinfo::{VerbsDeviceInfo, VerbsDeviceList};
8 | pub use verbs_domain::VerbsDomain;
--------------------------------------------------------------------------------
/python/pplx_garden/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | from .parallel_group import (
2 |     ParallelGroup as ParallelGroup,
3 | )
4 | from .process_group import (
5 |     ParallelLaunch as ParallelLaunch,
6 |     ProcessGroup as ProcessGroup,
7 | )
--------------------------------------------------------------------------------
/python/pplx_garden/native/p2p_all_to_all.py:
--------------------------------------------------------------------------------
1 | # pyright: reportMissingModuleSource=false, reportAttributeAccessIssue=false, reportMissingImports=false
2 |
3 | from pplx_garden._rust import (
4 |     AllToAllContext as AllToAllContext,
5 | )
--------------------------------------------------------------------------------
/fabric-lib/libfabric-sys/wrapper.h:
--------------------------------------------------------------------------------
1 | #include <rdma/fabric.h>
2 | #include <rdma/fi_cm.h>
3 | #include <rdma/fi_domain.h>
4 | #include <rdma/fi_endpoint.h>
5 | #include <rdma/fi_eq.h>
6 | #include <rdma/fi_errno.h>
7 | #include <rdma/fi_rma.h>
--------------------------------------------------------------------------------
/rust/cuda-lib/cuda-sys/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "cuda"
4 | name = "cuda-sys"
5 | publish = false
6 |
7 | [build-dependencies]
8 | build-utils = { workspace = true }
9 |
10 | bindgen = {workspace = true}
11 |
--------------------------------------------------------------------------------
/rust/cuda-lib/cudart-sys/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "cudart"
4 | name = "cudart-sys"
5 | publish = false
6 |
7 | [build-dependencies]
8 | build-utils = { workspace = true }
9 |
10 | bindgen = {workspace = true}
11 |
--------------------------------------------------------------------------------
/rust/cuda-lib/gdrapi-sys/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "gdrapi"
4 | name = "gdrapi-sys"
5 | publish = false
6 |
7 | [build-dependencies]
8 | build-utils = { workspace = true }
9 |
10 | bindgen = {workspace = true}
11 |
--------------------------------------------------------------------------------
/rust/proc-lib/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "proc-lib"
3 | edition = "2024"
4 |
5 | [lib]
6 | proc-macro = true
7 |
8 | [dependencies]
9 | quote = "1.0"
10 | proc-macro2 = "1.0"
11 |
12 | cudart-sys = { path = "../cuda-lib/cudart-sys" }
--------------------------------------------------------------------------------
/fabric-lib/libfabric-sys/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "libfabric"
4 | name = "libfabric-sys"
5 | publish = false
6 |
7 | [build-dependencies]
8 | build-utils = { workspace = true }
9 |
10 | bindgen = { workspace = true }
--------------------------------------------------------------------------------
/fabric-lib/src/efa/mod.rs:
--------------------------------------------------------------------------------
1 | mod efa_devinfo;
2 | mod efa_domain;
3 | mod efa_mr;
4 | mod efa_rdma_op;
5 |
6 | pub use efa_devinfo::{EfaDomainInfo, get_efa_domains};
7 | pub use efa_domain::EfaDomain;
8 |
9 | // TODO(lequn): Remove pub
10 | pub use efa_mr::EfaMemDesc;
--------------------------------------------------------------------------------
/fabric-lib/libibverbs-sys/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "ibverbs"
4 | name = "libibverbs-sys"
5 | publish = false
6 |
7 | [build-dependencies]
8 | build-utils = { workspace = true }
9 |
10 | bindgen = { workspace = true }
11 | cc = { workspace = true }
--------------------------------------------------------------------------------
/rust/cuda-lib/src/device.rs:
--------------------------------------------------------------------------------
1 | use bincode::{Decode, Encode};
2 |
3 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encode, Decode)]
4 | pub struct CudaDeviceId(pub u8);
5 |
6 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
7 | pub enum Device {
8 |     Host,
9 |     Cuda(CudaDeviceId),
10 | }
--------------------------------------------------------------------------------
/rust/cuda-lib/cuda-sys/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(warnings)]
2 | include!(concat!(env!("OUT_DIR"), "/cuda-bindings.rs"));
3 |
4 | pub unsafe fn cuMemAlloc(dptr: *mut u64, bytesize: usize) -> CUresult {
5 |     cuMemAlloc_v2(dptr, bytesize)
6 | }
7 |
8 | pub unsafe fn cuMemFree(dptr: u64) -> CUresult {
9 |     cuMemFree_v2(dptr)
10 | }
--------------------------------------------------------------------------------
/rust/torch-lib/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "torch-lib"
4 | name = "torch-lib"
5 | publish = false
6 |
7 | [dependencies]
8 | cxx = { workspace = true }
9 | pyo3 = { workspace = true }
10 | cuda-lib = { workspace = true }
11 |
12 | [build-dependencies]
13 | cxx-build = { workspace = true }
14 | pkg-config = { workspace = true }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .DS_Store
3 | data
4 |
5 | __pycache__
6 | .pytest_cache
7 | .mypy_cache
8 | *.pyc
9 | *.egg-info
10 | dist
11 | .venv
12 | _version.py
13 |
14 | build
15 | build-cmake
16 | target
17 | *.so
18 |
19 | *.nsys-rep
20 | *.qdstrm
21 |
22 | tarpaulin-report.html
23 | cobertura.xml
24 | codecov.json
25 | codecov.txt
26 | .coverage*
27 | *.profraw
--------------------------------------------------------------------------------
/fabric-lib/libibverbs-sys/wrapper.h:
--------------------------------------------------------------------------------
1 | #include <infiniband/verbs.h>
2 |
3 | static inline int ibv_query_port_wrap(
4 |     struct ibv_context *context,
5 |     uint8_t port_num,
6 |     struct ibv_port_attr *port_attr)
7 | {
8 |     // ibv_query_port is a macro. Have to use this trick to make bindgen work.
9 |     return ibv_query_port(context, port_num, port_attr);
10 | }
--------------------------------------------------------------------------------
/rust/logging-lib/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "logging-lib"
4 |
5 | [dependencies]
6 | anyhow = { workspace = true }
7 | is-terminal = { workspace = true }
8 | clap = { workspace = true }
9 | tracing = { workspace = true }
10 | tracing-core = { workspace = true }
11 | tracing-log = { workspace = true }
12 | tracing-subscriber = { workspace = true }
--------------------------------------------------------------------------------
/p2p-all-to-all/a2a-kernels/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | links = "a2a_kernels"
4 | name = "a2a-kernels"
5 | publish = false
6 |
7 | [dependencies]
8 | cxx = { workspace = true }
9 | torch-lib = { workspace = true }
10 |
11 | [build-dependencies]
12 | cc = {workspace = true, features = ["parallel"]}
13 | cxx-build = { workspace = true }
14 | build-utils = {workspace = true}
--------------------------------------------------------------------------------
/fabric-lib/fabric-debug/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "fabric-debug"
4 | publish = false
5 |
6 | [dependencies]
7 | fabric-lib = { workspace = true }
8 | cuda-lib = { workspace = true }
9 | logging-lib = { workspace = true, features=[] }
10 |
11 | anyhow = { workspace = true }
12 | bytes = { workspace = true }
13 | postcard = { workspace = true }
14 | serde = { workspace = true }
--------------------------------------------------------------------------------
/fabric-lib/libfabric-sys/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(warnings)]
2 | include!(concat!(env!("OUT_DIR"), "/libfabric-bindings.rs"));
3 |
4 | pub const FI_ADDR_UNSPEC: fi_addr_t = u64::MAX;
5 |
6 | pub fn make_fi_version(major: u16, minor: u16) -> u32 {
7 |     ((major as u32) << 16) | (minor as u32)
8 | }
9 |
10 | pub unsafe fn fi_close(fid: *mut fid) {
11 |     (*(*fid).ops).close.unwrap_unchecked()(fid);
12 | }
--------------------------------------------------------------------------------
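make_fi_version packs a libfabric (major, minor) pair into the single u32 layout that libfabric's FI_VERSION macro uses: major in the high 16 bits, minor in the low 16. A quick worked check of the packing:

fn main() {
    // Same packing as make_fi_version(1, 18) above.
    let version: u32 = (1u32 << 16) | 18;
    assert_eq!(version, 0x0001_0012);
    assert_eq!(version >> 16, 1); // major
    assert_eq!(version & 0xffff, 18); // minor
}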
/tests/fabric_lib/test_types.py:
--------------------------------------------------------------------------------
1 | from pplx_garden.fabric_lib import DomainAddress
2 |
3 |
4 | def test_domain_address() -> None:
5 |     str_addr = "fe800000000000000455eefffe35f1c500000000d85fb3680000000000000000"
6 |     addr1 = DomainAddress.from_str(str_addr)
7 |     addr2 = DomainAddress.from_str(str_addr)
8 |     assert addr1 == addr2
9 |     assert hash(addr1) == hash(addr2)
10 |     assert len({addr1, addr2}) == 1
--------------------------------------------------------------------------------
/p2p-all-to-all/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "p2p-all-to-all"
4 |
5 | [dependencies]
6 | a2a-kernels = { path = "a2a-kernels" }
7 | cuda-lib = { workspace = true }
8 | fabric-lib = { workspace = true }
9 | thread-lib = { workspace = true }
10 | torch-lib = { workspace = true }
11 |
12 | anyhow = { workspace = true }
13 | nvtx = { workspace = true }
14 | oneshot = { workspace = true }
15 | tracing = { workspace = true }
--------------------------------------------------------------------------------
/rust/cuda-lib/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "cuda-lib"
4 | publish = false
5 |
6 | [dependencies]
7 | cuda-sys = { path = "./cuda-sys" }
8 | cudart-sys = { path = "./cudart-sys" }
9 | gdrapi-sys = { path = "./gdrapi-sys" }
10 |
11 | libc = { workspace = true }
12 | thiserror = { workspace = true }
13 | bincode = { workspace = true, features = ["derive", "alloc"] }
14 |
15 | [dev-dependencies]
16 | proc-lib.workspace = true
--------------------------------------------------------------------------------
/rust/cuda-lib/src/test_driver.rs:
--------------------------------------------------------------------------------
1 | use cuda_sys::CUDA_ERROR_OUT_OF_MEMORY;
2 |
3 | use crate::driver::CudaDriverError;
4 | use proc_lib::gpu_test;
5 |
6 | #[gpu_test]
7 | #[test]
8 | fn CudaDriverError_display() {
9 |     let e = CudaDriverError::new(CUDA_ERROR_OUT_OF_MEMORY, "some test context");
10 |     assert_eq!(
11 |         format!("{}", e),
12 |         "CudaDriverError: code 2 (\"out of memory\"), context: some test context"
13 |     );
14 | }
--------------------------------------------------------------------------------
/tests/fabric.py:
--------------------------------------------------------------------------------
1 | from functools import cache
2 | from pathlib import Path
3 |
4 |
5 | @cache
6 | def count_sys_nvidia() -> int:
7 |     return len(list(Path("/sys/bus/pci/drivers/nvidia/").glob("0000:*")))
8 |
9 |
10 | @cache
11 | def count_sys_infiniband_verbs() -> int:
12 |     return len(list(Path("/sys/class/infiniband_verbs/").glob("uverbs*")))
13 |
14 |
15 | def get_nets_per_gpu() -> int:
16 |     return count_sys_infiniband_verbs() // count_sys_nvidia()
--------------------------------------------------------------------------------
/fabric-lib/src/utils/defer.rs:
--------------------------------------------------------------------------------
1 | pub struct Defer<F: FnMut()> {
2 |     f: F,
3 |     canceled: bool,
4 | }
5 |
6 | impl<F: FnMut()> Defer<F> {
7 |     pub fn new(f: F) -> Self {
8 |         Self { f, canceled: false }
9 |     }
10 |
11 |     pub fn cancel(&mut self) {
12 |         self.canceled = true;
13 |     }
14 | }
15 |
16 | impl<F: FnMut()> Drop for Defer<F> {
17 |     fn drop(&mut self) {
18 |         if !self.canceled {
19 |             (self.f)();
20 |         }
21 |     }
22 | }
23 |
--------------------------------------------------------------------------------
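Defer is a small scope guard: the closure runs when the value is dropped unless cancel() disarms it first. A usage sketch (the utils module is private to fabric-lib, so this pattern is only available inside the crate):

fn sketch() {
    let mut rollback = Defer::new(|| println!("undoing partial work"));
    // ... fallible setup; an early return or panic still runs the closure ...
    // Setup succeeded, so disarm the guard.
    rollback.cancel();
} // nothing printed here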
/rust/cuda-lib/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(non_snake_case)]
2 |
3 | pub use cuda_sys;
4 | pub use cudart_sys;
5 | pub use gdrapi_sys;
6 | pub mod driver;
7 | pub mod event;
8 | pub mod gdr;
9 | pub mod rt;
10 |
11 | pub mod cumem;
12 | mod error;
13 | mod mem;
14 | pub use error::{CudaError, CudaResult};
15 | pub use mem::{CudaDeviceMemory, CudaHostMemory};
16 | mod device;
17 | pub use device::{CudaDeviceId, Device};
18 |
19 | #[cfg(test)]
20 | mod test_driver;
21 |
22 | #[cfg(test)]
23 | mod test_gdr;
--------------------------------------------------------------------------------
/python/pplx_garden/native/cumem.py:
--------------------------------------------------------------------------------
1 | # pyright: reportMissingModuleSource=false, reportAttributeAccessIssue=false, reportMissingImports=false
2 |
3 | from pplx_garden._rust import (
4 |     CUMemAllocHandle as CUMemAllocHandle,
5 |     CUMemExportHandle as CUMemExportHandle,
6 |     CUMemHandleKind as CUMemHandleKind,
7 |     CUMemImportHandle as CUMemImportHandle,
8 |     CUMemMapping as CUMemMapping,
9 |     CUMulticastExportHandle as CUMulticastExportHandle,
10 |     CUMulticastHandle as CUMulticastHandle,
11 | )
--------------------------------------------------------------------------------
/python/pplx_garden/fabric_lib.py:
--------------------------------------------------------------------------------
1 | # pyright: reportMissingModuleSource=false, reportAttributeAccessIssue=false, reportMissingImports=false
2 |
3 | from pplx_garden._rust import (
4 |     DomainAddress as DomainAddress,
5 |     DomainInfo as DomainInfo,
6 |     MemoryRegionDescriptor as MemoryRegionDescriptor,
7 |     MemoryRegionHandle as MemoryRegionHandle,
8 |     PageIndices as PageIndices,
9 |     TopologyGroup as TopologyGroup,
10 |     TransferEngine as TransferEngine,
11 |     TransferEngineBuilder as TransferEngineBuilder,
12 | )
--------------------------------------------------------------------------------
/rust/thread-lib/src/lib.rs:
--------------------------------------------------------------------------------
1 | use libc::{CPU_SET, CPU_ZERO, cpu_set_t, pthread_self, pthread_setaffinity_np};
2 | use syscalls::Errno;
3 |
4 | pub fn pin_cpu(cpu: usize) -> Result<(), Errno> {
5 |     unsafe {
6 |         let mut cpuset = std::mem::zeroed();
7 |         CPU_ZERO(&mut cpuset);
8 |         CPU_SET(cpu, &mut cpuset);
9 |         let ret =
10 |             pthread_setaffinity_np(pthread_self(), size_of::<cpu_set_t>(), &cpuset);
11 |         if ret != 0 {
12 |             return Err(Errno::new(ret));
13 |         }
14 |         Ok(())
15 |     }
16 | }
17 |
--------------------------------------------------------------------------------
/fabric-lib/src/utils/hex.rs:
--------------------------------------------------------------------------------
1 | use std::num::ParseIntError;
2 |
3 | use bytes::{BufMut, Bytes, BytesMut};
4 |
5 | pub fn fmt_hex(f: &mut std::fmt::Formatter<'_>, bytes: &[u8]) -> std::fmt::Result {
6 |     for x in bytes {
7 |         write!(f, "{:02x}", x)?;
8 |     }
9 |     Ok(())
10 | }
11 |
12 | pub fn from_hex(s: &str) -> Result<Bytes, ParseIntError> {
13 |     let mut bytes = BytesMut::with_capacity(s.len() / 2);
14 |     for i in (0..s.len()).step_by(2) {
15 |         bytes.put_u8(u8::from_str_radix(&s[i..i + 2], 16)?);
16 |     }
17 |     Ok(bytes.freeze())
18 | }
--------------------------------------------------------------------------------
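fmt_hex and from_hex are inverses over even-length hex strings; the DomainAddress string seen in test_types.py is this encoding. A round-trip sketch (from_hex slices two characters at a time, so callers should pass an even-length string):

fn sketch() {
    let bytes = from_hex("d85fb368").unwrap();
    assert_eq!(&bytes[..], &[0xd8, 0x5f, 0xb3, 0x68][..]);
    // fmt_hex writes the same bytes back out as "d85fb368"
    // when called from a Display implementation.
}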
/p2p-all-to-all/src/a2a_handles.rs:
--------------------------------------------------------------------------------
1 | use fabric_lib::api::{DomainAddress, MemoryRegionDescriptor};
2 |
3 | pub struct AllToAllRankHandle {
4 |     pub address: DomainAddress,
5 |     pub num_routed_desc: MemoryRegionDescriptor,
6 |     pub recv_buffer_desc: MemoryRegionDescriptor,
7 | }
8 |
9 | impl AllToAllRankHandle {
10 |     pub fn new(
11 |         address: DomainAddress,
12 |         num_routed_desc: MemoryRegionDescriptor,
13 |         recv_buffer_desc: MemoryRegionDescriptor,
14 |     ) -> Self {
15 |         Self { address, num_routed_desc, recv_buffer_desc }
16 |     }
17 | }
--------------------------------------------------------------------------------
/python-ext/src/lib.rs:
--------------------------------------------------------------------------------
1 | mod py_cumem;
2 | mod py_device;
3 | mod py_fabric_lib;
4 | mod py_p2p_all_to_all;
5 |
6 | use pyo3::{Bound, PyResult, pymodule, types::PyModule};
7 |
8 | #[pymodule]
9 | fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
10 |     let _ = logging_lib::init(&logging_lib::LoggingOpts {
11 |         log_color: logging_lib::LogColor::Auto,
12 |         log_format: logging_lib::LogFormat::Text,
13 |         log_directives: None,
14 |     });
15 |
16 |     py_cumem::init(m)?;
17 |     py_p2p_all_to_all::init(m)?;
18 |     py_fabric_lib::init(m)?;
19 |
20 |     Ok(())
21 | }
--------------------------------------------------------------------------------
/scripts/run-docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | verbs=""
4 | for path in /dev/infiniband/uverbs*; do
5 |     verbs="$verbs --device=$path"
6 | done
7 |
8 | root_dir=$(realpath $(dirname $0)/..)
9 |
10 |
11 | set -x
12 | exec docker run --rm -it --name=dev-pplx-garden \
13 |     -v $root_dir:/app \
14 |     --init \
15 |     --shm-size=32g \
16 |     --ulimit=memlock=-1 \
17 |     --ulimit=stack=67108864 \
18 |     --gpus=all \
19 |     $verbs \
20 |     --device=/dev/gdrdrv \
21 |     --cap-add=IPC_LOCK \
22 |     --cap-add=SYS_ADMIN \
23 |     --cap-add=SYS_PTRACE \
24 |     --security-opt=seccomp=unconfined \
25 |     --network host \
26 |     pplx-garden-dev
--------------------------------------------------------------------------------
/python-ext/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "pplx-garden-python-ext"
4 |
5 | [lib]
6 | name = "pplx_garden_python_ext"
7 | crate-type = ["cdylib"]
8 |
9 | [dependencies]
10 | cuda-lib = { workspace = true }
11 | fabric-lib = { workspace = true }
12 | logging-lib = { workspace = true }
13 | thread-lib = { workspace = true }
14 | torch-lib = { workspace = true }
15 | p2p-all-to-all = { workspace = true }
16 |
17 | bincode = { workspace = true }
18 | bytes = { workspace = true }
19 | parking_lot = { workspace = true }
20 | postcard = { workspace = true }
21 | pyo3 = { workspace = true, features = ["extension-module", "abi3-py310"] }
22 | serde = { workspace = true }
23 | tracing = { workspace = true }
--------------------------------------------------------------------------------
/fabric-lib/src/provider_dispatch.rs:
--------------------------------------------------------------------------------
1 | use std::borrow::Cow;
2 |
3 | use crate::{efa::EfaDomainInfo, provider::RdmaDomainInfo, verbs::VerbsDeviceInfo};
4 |
5 | #[derive(Clone)]
6 | pub enum DomainInfo {
7 |     Efa(EfaDomainInfo),
8 |     Verbs(VerbsDeviceInfo),
9 | }
10 |
11 | impl RdmaDomainInfo for DomainInfo {
12 |     fn name(&self) -> Cow<'_, str> {
13 |         match self {
14 |             DomainInfo::Efa(info) => info.name(),
15 |             DomainInfo::Verbs(info) => info.name(),
16 |         }
17 |     }
18 |
19 |     fn link_speed(&self) -> u64 {
20 |         match self {
21 |             DomainInfo::Efa(info) => info.link_speed(),
22 |             DomainInfo::Verbs(info) => info.link_speed(),
23 |         }
24 |     }
25 | }
--------------------------------------------------------------------------------
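DomainInfo is the closed set of supported providers (EFA via libfabric, plain verbs via libibverbs); matching on the enum forwards every RdmaDomainInfo method to the wrapped provider type, so callers can stay generic over the trait, as in this sketch:

use fabric_lib::{DomainInfo, RdmaDomainInfo};

fn describe(info: &impl RdmaDomainInfo) -> String {
    format!("{} ({} bit/s)", info.name(), info.link_speed())
}

fn sketch(infos: &[DomainInfo]) {
    for info in infos {
        // Dispatches to EfaDomainInfo or VerbsDeviceInfo internally.
        println!("{}", describe(info));
    }
}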
/tests/markers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from pplx_garden.utils.torch import has_cuda, has_tp
5 |
6 | mark_ci_2gpu = pytest.mark.ci_2gpu
7 | mark_ci_4gpu = pytest.mark.ci_4gpu
8 |
9 | mark_fabric = pytest.mark.fabric
10 | mark_kernel = pytest.mark.kernel
11 |
12 |
13 | def mark_tp(n: int) -> pytest.MarkDecorator:
14 |     return pytest.mark.skipif(not has_tp(n), reason=f"requires {n} GPUs")
15 |
16 |
17 | gpu_only = pytest.mark.skipif(not has_cuda(), reason="test requires CUDA")
18 | cpu_only = pytest.mark.cpu_only
19 |
20 | all_devices = pytest.mark.parametrize(
21 |     "device",
22 |     [
23 |         pytest.param(torch.device("cuda"), marks=gpu_only, id="cuda"),
24 |         pytest.param(torch.device("cpu"), id="cpu"),
25 |     ],
26 | )
--------------------------------------------------------------------------------
/fabric-lib/src/efa/efa_mr.rs:
--------------------------------------------------------------------------------
1 | use std::{ffi::c_void, ptr::NonNull};
2 |
3 | use libfabric_sys::fid_mr;
4 |
5 | use crate::mr::MemoryRegionLocalDescriptor;
6 |
7 | #[derive(Debug, Clone, Copy)]
8 | pub struct EfaMemDesc(pub *mut *mut c_void);
9 |
10 | impl From<NonNull<fid_mr>> for EfaMemDesc {
11 |     fn from(mr: NonNull<fid_mr>) -> Self {
12 |         EfaMemDesc(unsafe { &raw mut (*mr.as_ptr()).mem_desc })
13 |     }
14 | }
15 |
16 | impl From<MemoryRegionLocalDescriptor> for EfaMemDesc {
17 |     fn from(desc: MemoryRegionLocalDescriptor) -> Self {
18 |         EfaMemDesc(desc.0 as *mut *mut c_void)
19 |     }
20 | }
21 |
22 | impl From<EfaMemDesc> for MemoryRegionLocalDescriptor {
23 |     fn from(desc: EfaMemDesc) -> Self {
24 |         MemoryRegionLocalDescriptor(desc.0 as u64)
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
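EfaMemDesc shuttles libfabric's per-registration mem_desc pointer across the provider-neutral boundary by reinterpreting it as the u64 inside MemoryRegionLocalDescriptor and back. A crate-internal round-trip sketch (assumes the descriptor stays a plain pointer-width value, which is what the casts above rely on):

use std::ffi::c_void;

fn sketch(desc_ptr: *mut *mut c_void) {
    let efa = EfaMemDesc(desc_ptr);
    let neutral: MemoryRegionLocalDescriptor = efa.into();
    let back: EfaMemDesc = neutral.into();
    assert_eq!(back.0 as u64, desc_ptr as u64);
}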
/rust/cuda-lib/src/test_gdr.rs:
--------------------------------------------------------------------------------
1 | use crate::gdr::{GdrCopyContext, GdrFlag};
2 | use proc_lib::gpu_test;
3 |
4 | #[gpu_test]
5 | #[test]
6 | fn gdr_copy_flag() {
7 |     // Set the current device.
8 |     unsafe {
9 |         let mut device: i32 = 0;
10 |         cuda_sys::cuInit(0);
11 |         cuda_sys::cuDeviceGet(&mut device, 0);
12 |
13 |         let mut dev_ctx: cuda_sys::CUcontext = { std::ptr::null_mut() };
14 |         cuda_sys::cuDevicePrimaryCtxRetain(&mut dev_ctx, device);
15 |         cuda_sys::cuCtxSetCurrent(dev_ctx);
16 |     }
17 |
18 |     // Create the GDR copy context.
19 |     let gdr_context = GdrCopyContext::new().unwrap();
20 |
21 |     // Allocate a flag.
22 |     let flag = GdrFlag::new(&gdr_context).unwrap();
23 |
24 |     // Set the value of the flag.
25 |     flag.set(true);
26 | }
--------------------------------------------------------------------------------
/rust/torch-lib/src/torch_lib.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <memory>
4 |
5 | #include "rust/cxx.h"
6 |
7 | namespace at {
8 | class RecordFunction;
9 | }
10 |
11 | namespace torch_lib {
12 |
13 | class TorchProfilerGuard final {
14 | public:
15 |     TorchProfilerGuard(const char* name);
16 |     ~TorchProfilerGuard();
17 |
18 | private:
19 |     std::unique_ptr<at::RecordFunction> guard;
20 | };
21 |
22 | } // namespace torch_lib
23 |
24 | #include "torch-lib/src/lib.rs.h"
25 |
26 | namespace torch_lib {
27 |
28 | char *from_blob(
29 |     char *data_ptr,
30 |     rust::Slice<const int64_t> shape,
31 |     ScalarType dtype,
32 |     Device device,
33 |     rust::Box context
34 | );
35 |
36 | ScalarType torch_to_scalar_type(char *obj);
37 | char *scalar_to_torch_type(ScalarType scalar_type);
38 |
39 | uint64_t current_stream();
40 |
41 | std::unique_ptr<TorchProfilerGuard> profile_range(rust::String name);
42 |
43 | } // namespace torch_lib
--------------------------------------------------------------------------------
/p2p-all-to-all/a2a-kernels/src/core/common_utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <climits>
4 | #include <cstdint>
5 |
6 | #ifdef __CUDA_ARCH__
7 | #define ROSE_HOST_DEVICE __host__ __device__
8 | #else
9 | #define ROSE_HOST_DEVICE
10 | #endif
11 |
12 | namespace rose {
13 |
14 | /// The fixed warp size.
15 | constexpr size_t WARP_SIZE = 32;
16 |
17 | /// Return the next power of 2 following the given number.
18 | ROSE_HOST_DEVICE inline uint32_t next_pow_2(const uint32_t num) {
19 |     if (num <= 1) {
20 |         return num;
21 |     }
22 |     return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
23 | }
24 |
25 | template <typename T> ROSE_HOST_DEVICE T ceil_div(T x, T y) { return (x + y - 1) / y; }
26 |
27 | template <typename T> ROSE_HOST_DEVICE T round_up(T x, T y) { return ceil_div(x, y) * y; }
28 |
29 | template <typename T> ROSE_HOST_DEVICE T min(T x, T y) { return x < y ? x : y; }
30 |
31 | template <typename T> ROSE_HOST_DEVICE T max(T x, T y) { return x > y ? x : y; }
32 |
33 | } // namespace rose
--------------------------------------------------------------------------------
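The helpers in common_utils.h are plain integer identities. The same arithmetic, mirrored in Rust as a worked check of next_pow_2, ceil_div, and round_up:

fn next_pow_2(num: u32) -> u32 {
    // 1 << (32 - clz(num - 1)), matching the C++ above.
    if num <= 1 { num } else { 1 << (32 - (num - 1).leading_zeros()) }
}

fn main() {
    assert_eq!(next_pow_2(5), 8); // 4 = 0b100, clz = 29, so 1 << 3
    assert_eq!(next_pow_2(8), 8); // powers of two map to themselves
    assert_eq!((10 + 3 - 1) / 3, 4); // ceil_div(10, 3)
    assert_eq!(((10 + 3 - 1) / 3) * 3, 12); // round_up(10, 3)
}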
/fabric-lib/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2024"
3 | name = "fabric-lib"
4 | publish = false
5 |
6 | [features]
7 | default = []
8 | tokio = ["dep:tokio", "tokio/sync", "tokio/rt-multi-thread"]
9 |
10 | [dependencies]
11 | libfabric-sys = { path = "./libfabric-sys" }
12 | libibverbs-sys = { path = "./libibverbs-sys" }
13 |
14 | cuda-lib = { workspace = true }
15 | thread-lib = { workspace = true }
16 |
17 | anyhow = { workspace = true }
18 | bytes = { workspace = true }
19 | crossbeam-channel = { workspace = true }
20 | dashmap = { workspace = true }
21 | libc = { workspace = true }
22 | once_cell = { workspace = true }
23 | oneshot = { workspace = true }
24 | parking_lot = { workspace = true }
25 | postcard = { workspace = true }
26 | serde = { workspace = true }
27 | smallvec = { workspace = true }
28 | thiserror = { workspace = true }
29 | tracing = { workspace = true }
30 | syscalls = { workspace = true }
31 | mockall = { workspace = true }
32 |
33 | # Optional dependencies
34 | tokio = { workspace = true, optional = true }
--------------------------------------------------------------------------------
/tests/fabric_lib/test_handle.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: T201
2 |
3 | import pickle
4 |
5 | import torch
6 |
7 | from pplx_garden.fabric_lib import TransferEngine
8 | from tests.markers import gpu_only, mark_fabric
9 |
10 |
11 | @mark_fabric
12 | @gpu_only
13 | def test_pickle_unpickle_descriptor() -> None:
14 |     # Build a transfer engine.
15 |     group = TransferEngine.detect_topology()[0]
16 |     builder = TransferEngine.builder()
17 |     builder.add_gpu_domains(
18 |         group.cuda_device,
19 |         group.domains,
20 |         group.cpus[0],
21 |         group.cpus[1],
22 |     )
23 |     engine = builder.build()
24 |
25 |     # Allocate and register a CPU buffer.
26 |     src_buf = torch.ones(
27 |         (4096,),
28 |         dtype=torch.uint8,
29 |         device="cpu",
30 |     )
31 |
32 |     # Check pickle round-trip.
33 |     _, descriptor = engine.register_tensor(src_buf)
34 |     pickled = pickle.dumps(descriptor)
35 |     unpickled = pickle.loads(pickled)
36 |     assert descriptor.as_bytes() == unpickled.as_bytes()
--------------------------------------------------------------------------------
/python/pplx_garden/kernels/all_to_all.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Protocol
2 |
3 | import torch
4 |
5 |
6 | class AllToAllKernel(Protocol):
7 |     def dispatch(
8 |         self,
9 |         out_expert_num_tokens: torch.Tensor,
10 |         out_expert_x: torch.Tensor,
11 |         out_expert_x_scale: Optional[torch.Tensor],
12 |         dp_x: torch.Tensor,
13 |         dp_x_scale: Optional[torch.Tensor],
14 |         indices: torch.Tensor,
15 |         weights: torch.Tensor,
16 |         bound_m: Optional[torch.Tensor] = None,
17 |         do_send: bool = True,
18 |         do_recv: bool = True,
19 |     ) -> None: ...
20 |
21 |     def combine(
22 |         self,
23 |         out_tokens: torch.Tensor,
24 |         indices: torch.Tensor,
25 |         weights: torch.Tensor,
26 |         expert_y: torch.Tensor,
27 |         bound_m: Optional[torch.Tensor] = None,
28 |         do_send: bool = True,
29 |         do_recv: bool = True,
30 |         accumulate: bool = False,
31 |     ) -> None: ...
32 |
33 |     def destroy(self) -> None: ...
34 |
--------------------------------------------------------------------------------
/rust/cuda-lib/cuda-sys/build.rs:
--------------------------------------------------------------------------------
1 | use std::{env, path::PathBuf};
2 |
3 | use build_utils::find_package;
4 |
5 | fn main() -> Result<(), Box<dyn std::error::Error>> {
6 |     let cuda_home = find_package("CUDA_HOME", &["/usr/local/cuda"], "include/cuda.h");
7 |     let bindings = bindgen::Builder::default()
8 |         .header(cuda_home.join("include/cuda.h").to_string_lossy())
9 |         .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
10 |         .prepend_enum_name(false)
11 |         .allowlist_item(r"(cu|CU).*")
12 |         .derive_default(true)
13 |         .generate()
14 |         .expect("Unable to generate cuda driver bindings");
15 |     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
16 |     bindings
17 |         .write_to_file(out_dir.join("cuda-bindings.rs"))
18 |         .expect("Couldn't write cuda driver bindings!");
19 |
20 |     // Dynamic link dependencies
21 |     println!("cargo:rustc-link-search=native={}/lib64/stubs", cuda_home.display());
22 |     println!("cargo:rustc-link-lib=cuda");
23 |
24 |     Ok(())
25 | }
--------------------------------------------------------------------------------
/p2p-all-to-all/a2a-kernels/build.rs:
--------------------------------------------------------------------------------
1 | use std::{env, path::PathBuf};
2 |
3 | use build_utils::emit_rerun_if_changed_files;
4 |
5 | fn main() {
6 |     let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
7 |
8 |     // Generate bindings
9 |     cxx_build::bridge("src/lib.rs")
10 |         .debug(false)
11 |         .cuda(true)
12 |         .flag("-t0")
13 |         .flag("-O3")
14 |         .flag("-cudart=shared")
15 |         .flag("-gencode=arch=compute_90a,code=sm_90a")
16 |         .flag("-gencode=arch=compute_100a,code=sm_100a")
17 |         .flag(format!("-I{}/src", manifest_dir.display()))
18 |         .file("src/a2a/a2a_dispatch_recv.cu")
19 |         .file("src/a2a/a2a_combine_send.cu")
20 |         .file("src/a2a/a2a_combine_recv.cu")
21 |         .file("src/a2a/a2a_dispatch_send.cu")
22 |         .compile("liba2a_kernels.a");
23 |
24 |     emit_rerun_if_changed_files("src", &["cu", "cuh", "h"]);
25 |
26 |     println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
27 |     println!("cargo:rustc-link-lib=cudart");
28 | }
--------------------------------------------------------------------------------
/rust/cuda-lib/gdrapi-sys/build.rs:
--------------------------------------------------------------------------------
1 | use std::{env, path::PathBuf};
2 |
3 | use build_utils::find_package;
4 |
5 | fn main() -> Result<(), Box<dyn std::error::Error>> {
6 |     let gdrapi_home = find_package("GDRAPI_HOME", &["/usr"], "include/gdrapi.h");
7 |     let bindings = bindgen::Builder::default()
8 |         .header_contents("wrapper.h", "#include <gdrapi.h>")
9 |         .clang_arg(format!("-I{}/include", gdrapi_home.display()))
10 |         .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
11 |         .prepend_enum_name(false)
12 |         .allowlist_item(r"gdr.*")
13 |         .derive_default(true)
14 |         .layout_tests(false)
15 |         .generate()
16 |         .expect("Unable to generate gdrapi bindings");
17 |     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
18 |     bindings
19 |         .write_to_file(out_dir.join("gdrapi-bindings.rs"))
20 |         .expect("Couldn't write gdrapi bindings!");
21 |
22 |     // Dynamic link dependencies
23 |     println!("cargo:rustc-link-lib=gdrapi");
24 |     println!("cargo:rustc-link-search=native={}/lib", gdrapi_home.display());
25 |
26 |     Ok(())
27 | }
--------------------------------------------------------------------------------
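All of these build scripts resolve their native dependency through build_utils::find_package, whose source is not part of this dump. From the call sites, it takes an environment-variable override, a list of default install prefixes, and a marker file to probe for; a plausible sketch consistent with that usage (the real build-utils implementation may differ):

use std::{env, path::PathBuf};

// Sketch only: env override wins, else scan the default prefixes
// for the one that actually contains the marker file.
pub fn find_package(env_var: &str, defaults: &[&str], marker: &str) -> PathBuf {
    let candidates: Vec<PathBuf> = match env::var(env_var) {
        Ok(overridden) => vec![PathBuf::from(overridden)],
        Err(_) => defaults.iter().map(|p| PathBuf::from(*p)).collect(),
    };
    candidates
        .into_iter()
        .find(|prefix| prefix.join(marker).exists())
        .unwrap_or_else(|| panic!("{env_var}: no prefix containing {marker}"))
}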
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Perplexity AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/fabric-lib/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod api;
2 | mod domain_group;
3 | mod efa;
4 | mod error;
5 | mod fabric_engine;
6 | mod host_buffer;
7 | mod imm_count;
8 | mod interface;
9 | mod mr;
10 | mod provider;
11 | mod provider_dispatch;
12 | mod rdma_op;
13 | mod topo;
14 | mod transfer_engine;
15 | mod transfer_engine_builder;
16 | mod utils;
17 | mod verbs;
18 | mod worker;
19 |
20 | pub use domain_group::DomainGroup;
21 | pub use error::*;
22 | pub use fabric_engine::FabricEngine;
23 | pub use host_buffer::{HostBuffer, HostBufferAllocator};
24 | pub use interface::{
25 |     AsyncTransferEngine, BouncingErrorCallback, BouncingRecvCallback, ErrorCallback,
26 |     RdmaEngine, RecvCallback, SendBuffer, SendCallback, SendRecvEngine,
27 | };
28 | pub use provider::{RdmaDomain, RdmaDomainInfo};
29 | pub use provider_dispatch::DomainInfo;
30 | pub use topo::{TopologyGroup, detect_topology};
31 | pub use transfer_engine::{
32 |     ImmCountCallback, TransferCallback, TransferEngine, UvmWatcherCallback,
33 | };
34 | pub use transfer_engine_builder::TransferEngineBuilder;
35 | pub use worker::{InitializingWorker, Worker, WorkerHandle};
36 |
37 | pub use interface::MockTestTransferEngine;
--------------------------------------------------------------------------------
/rust/torch-lib/src/test_torch.rs:
--------------------------------------------------------------------------------
1 | use std::ffi::c_void;
2 | use std::ptr::NonNull;
3 |
4 | use cuda_lib::Device;
5 | use pyo3::{Py, PyAny, Python};
6 |
7 | use crate::{ScalarType, from_blob};
8 |
9 | #[test]
10 | fn test_from_blob() {
11 |     let data = vec![1, 2];
12 |
13 |     Python::initialize();
14 |     Python::attach(|py| {
15 |         py.import("torch").expect("Failed to import torch");
16 |
17 |         let tensor = from_blob(
18 |             NonNull::new(data.as_ptr() as *mut c_void).unwrap(),
19 |             &[1, 2],
20 |             ScalarType::I32,
21 |             Device::Host,
22 |             Box::new(data),
23 |         );
24 |
25 |         let tensor: Py<PyAny> = unsafe { Py::from_owned_ptr(py, tensor) };
26 |         let shape = tensor.getattr(py, "shape")?.extract::<Vec<usize>>(py)?;
27 |         let dtype = tensor.getattr(py, "dtype")?.bind(py).to_string();
28 |         let device = tensor.getattr(py, "device")?.bind(py).to_string();
29 |
30 |         assert_eq!(shape, vec![1, 2]);
31 |         assert_eq!(dtype, "torch.int32");
32 |         assert_eq!(device, "cpu");
33 |         Ok::<(), pyo3::PyErr>(())
34 |     })
35 |     .unwrap();
36 | }
--------------------------------------------------------------------------------
/fabric-lib/libfabric-sys/build.rs:
--------------------------------------------------------------------------------
1 | use std::{env, path::PathBuf};
2 |
3 | use build_utils::find_package;
4 |
5 | fn main() -> Result<(), Box<dyn std::error::Error>> {
6 |     let libfabric_home = find_package(
7 |         "LIBFABRIC_HOME",
8 |         &["/opt/amazon/efa", "/usr"],
9 |         "include/rdma/fabric.h",
10 |     );
11 |
12 |     // Generate bindings
13 |     // https://rust-lang.github.io/rust-bindgen/tutorial-3.html
14 |     let bindings = bindgen::Builder::default()
15 |         .header("wrapper.h")
16 |         .clang_arg(format!("-I{}/include", libfabric_home.display()))
17 |         .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
18 |         .prepend_enum_name(false)
19 |         .allowlist_item(r"(fi|FI)_.*")
20 |         .derive_default(true)
21 |         .generate()
22 |         .expect("Unable to generate libfabric bindings");
23 |     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
24 |     bindings
25 |         .write_to_file(out_dir.join("libfabric-bindings.rs"))
26 |         .expect("Couldn't write libfabric bindings!");
27 |
28 |     // Dynamic link libfabric
29 |     println!("cargo:rustc-link-search=native={}/lib", libfabric_home.display());
30 |     println!("cargo:rustc-link-lib=fabric");
31 |
32 |     Ok(())
33 | }
--------------------------------------------------------------------------------
/rust/cuda-lib/src/error.rs:
--------------------------------------------------------------------------------
1 | use crate::{driver::CudaDriverError, rt::CudartError};
2 |
3 | pub type CudaResult<T> = ::std::result::Result<T, CudaError>;
4 |
5 | #[derive(Debug, thiserror::Error)]
6 | pub enum CudaError {
7 |     #[error("{0}")]
8 |     CudaDriver(#[from] CudaDriverError),
9 |     #[error("{0}")]
10 |     Cudart(#[from] CudartError),
11 |     #[error("{0}")]
12 |     CudaError(cuda_sys::CUresult),
13 |     #[error("{0}")]
14 |     GdrCopyError(&'static str),
15 |     #[error("{0}")]
16 |     CustomError(String),
17 |     #[error("{0}")]
18 |     Errno(i32),
19 | }
20 |
21 | #[macro_export]
22 | macro_rules! cuda_check {
23 |     ($x:expr) => {{
24 |         let code = unsafe { $x } as u32;
25 |         if code != $crate::cuda_sys::CUDA_SUCCESS {
26 |             Err($crate::CudaError::Cudart($crate::rt::CudartError {
27 |                 code,
28 |                 context: "cuda_check call failed",
29 |             }))
30 |         } else {
31 |             Ok(())
32 |         }
33 |     }};
34 | }
35 |
36 | #[macro_export]
37 | macro_rules! cuda_unwrap {
38 |     ($x:expr) => {{
39 |         let ret = unsafe { $x } as u32;
40 |         if ret != $crate::cuda_sys::CUDA_SUCCESS {
41 |             panic!("cuda_unwrap call failed: {}", ret);
42 |         }
43 |     }};
44 | }
45 |
--------------------------------------------------------------------------------
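cuda_check! converts a raw CUDA status into a CudaResult (reporting it as a CudartError), while cuda_unwrap! panics on any non-success code. A call-site sketch:

use cuda_lib::{cuda_check, cuda_unwrap, cudart_sys, CudaResult};

fn sketch() -> CudaResult<()> {
    // Propagate failures to the caller via `?`.
    cuda_check!(cudart_sys::cudaSetDevice(0))?;
    // Or abort outright where failure is unrecoverable.
    cuda_unwrap!(cudart_sys::cudaDeviceSynchronize());
    Ok(())
}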
/rust/cuda-lib/src/event.rs:
--------------------------------------------------------------------------------
1 | use crate::rt::{CudaResult, CudartError};
2 |
3 | pub struct CudaEvent {
4 |     pub event: cudart_sys::cudaEvent_t,
5 | }
6 |
7 | impl CudaEvent {
8 |     pub fn new() -> CudaResult<Self> {
9 |         let mut event = std::ptr::null_mut();
10 |         let ret = unsafe { cudart_sys::cudaEventCreate(&mut event) };
11 |         if ret != 0 {
12 |             return Err(CudartError::new(ret, "cudaEventCreate"));
13 |         }
14 |         Ok(CudaEvent { event })
15 |     }
16 |
17 |     pub fn record(&self) -> CudaResult<()> {
18 |         let ret =
19 |             unsafe { cudart_sys::cudaEventRecord(self.event, std::ptr::null_mut()) };
20 |         if ret != 0 {
21 |             return Err(CudartError::new(ret, "cudaEventRecord"));
22 |         }
23 |         Ok(())
24 |     }
25 |
26 |     pub fn synchronize(&self) -> CudaResult<()> {
27 |         let ret = unsafe { cudart_sys::cudaEventSynchronize(self.event) };
28 |         if ret != 0 {
29 |             return Err(CudartError::new(ret, "cudaEventSynchronize"));
30 |         }
31 |         Ok(())
32 |     }
33 | }
34 |
35 | impl Drop for CudaEvent {
36 |     fn drop(&mut self) {
37 |         let ret = unsafe { cudart_sys::cudaEventDestroy(self.event) };
38 |         if ret != 0 {
39 |             panic!("cudaEventDestroy failed: {}", ret);
40 |         }
41 |     }
42 | }
43 |
44 | unsafe impl Send for CudaEvent {}
45 | unsafe impl Sync for CudaEvent {}
--------------------------------------------------------------------------------
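CudaEvent wraps the create/record/synchronize lifecycle, with Drop handling destruction. A host-side usage sketch (assumes a current CUDA context and work queued on the default stream, which is what the null stream passed to cudaEventRecord refers to):

use cuda_lib::event::CudaEvent;

fn sketch() -> Result<(), Box<dyn std::error::Error>> {
    let event = CudaEvent::new()?;
    // ... launch asynchronous work on the default stream ...
    event.record()?; // enqueue the event behind that work
    event.synchronize()?; // block the host until the event fires
    Ok(()) // Drop runs cudaEventDestroy
}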
/python/pplx_garden/distributed/distributed_ops.py:
--------------------------------------------------------------------------------
1 | """Definition of interfaces for distributed operation wrappers."""
2 |
3 | from abc import ABC, abstractmethod
4 | from collections.abc import Iterator
5 | from contextlib import contextmanager
6 | from typing import Optional
7 |
8 | import torch
9 | from torch.distributed import ReduceOp
10 |
11 |
12 | class Reducer(ABC):
13 |     """Wrapper around a contextual reducer."""
14 |
15 |     @property
16 |     @abstractmethod
17 |     def input(self) -> Optional[torch.Tensor]:
18 |         """Pre-allocated input tensor to be reduced."""
19 |
20 |     @abstractmethod
21 |     def reduce(
22 |         self,
23 |         x: torch.Tensor,
24 |         out: Optional[torch.Tensor] = None,
25 |     ) -> torch.Tensor:
26 |         """Run the reduction on the pre-allocated input."""
27 |
28 |
29 | class ReducerBuilder(ABC):
30 |     """Interface for all-reduce operations."""
31 |
32 |     @abstractmethod
33 |     def reducer(
34 |         self,
35 |         shape: torch.Size,
36 |         dtype: torch.dtype,
37 |         op: ReduceOp.RedOpType = ReduceOp.SUM,
38 |     ) -> Reducer:
39 |         pass
40 |
41 |     @abstractmethod
42 |     def destroy(self) -> None:
43 |         pass
44 |
45 |     @abstractmethod
46 |     def all_reduce(
47 |         self,
48 |         x: torch.Tensor,
49 |         op: ReduceOp.RedOpType = ReduceOp.SUM,
50 |     ) -> torch.Tensor:
51 |         pass
52 |
53 |     @contextmanager
54 |     @abstractmethod
55 |     def capture(self) -> Iterator[None]:
56 |         yield
--------------------------------------------------------------------------------
/python/pplx_garden/native/cumem.pyi:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Protocol
3 |
4 | import torch
5 |
6 | class CUMemHandleKind(Enum):
7 |     Local = ...
8 |     FileDescriptor = ...
9 |     Fabric = ...
10 |
11 | class CUMemMapping:
12 |     @property
13 |     def size(self) -> int: ...
14 |     def data_ptr(self) -> int: ...
15 |     def to_tensor(self, shape: tuple[int, ...], dtype: torch.dtype) -> torch.Tensor: ...
16 |     def unmap(self) -> None: ...
17 |
18 | class CUAllocHandle(Protocol):
19 |     def map(self, device: torch.device | None = None) -> CUMemMapping: ...
20 |     def map_to(self, mapping: CUMemMapping) -> None: ...
21 |
22 | class CUMemAllocHandle(CUAllocHandle):
23 |     def __init__(
24 |         self,
25 |         size: int,
26 |         device: torch.device,
27 |         handle_kind: CUMemHandleKind = CUMemHandleKind.Local,
28 |     ) -> None: ...
29 |     def export(self) -> CUMemExportHandle: ...
30 |
31 | class CUMemImportHandle(CUAllocHandle): ...
32 |
33 | class CUMemExportHandle:
34 |     def bind(self) -> CUMemImportHandle: ...
35 |
36 | class CUMulticastHandle(CUAllocHandle):
37 |     def __init__(
38 |         self,
39 |         num_devices: int,
40 |         size: int,
41 |         handle_kind: CUMemHandleKind = CUMemHandleKind.Local,
42 |     ) -> None: ...
43 |     def export(self) -> CUMulticastExportHandle: ...
44 |     def add_device(self, device_index: torch.device) -> None: ...
45 |     def bind_mem(self, alloc_handle: CUMemAllocHandle) -> None: ...
46 |
47 | class CUMulticastExportHandle:
48 |     def bind(self) -> CUMulticastHandle: ...
--------------------------------------------------------------------------------
/rust/cuda-lib/cudart-sys/build.rs:
--------------------------------------------------------------------------------
1 | use bindgen::callbacks::{ItemInfo, ParseCallbacks};
2 | use build_utils::find_package;
3 | use std::{env, path::PathBuf};
4 |
5 | #[derive(Debug)]
6 | struct RenameCallback;
7 |
8 | impl ParseCallbacks for RenameCallback {
9 |     fn item_name(&self, item_info: ItemInfo) -> Option<String> {
10 |         match item_info.name {
11 |             // CUDA 12 defines cudaGetDeviceProperties as cudaGetDeviceProperties_v2.
12 |             // CUDA 13 dropped the _v2 suffix.
13 |             "cudaGetDeviceProperties_v2" => Some("cudaGetDeviceProperties".into()),
14 |
15 |             // No rename needed.
16 |             _ => None,
17 |         }
18 |     }
19 | }
20 |
21 | fn main() -> Result<(), Box<dyn std::error::Error>> {
22 |     let cuda_home = find_package("CUDA_HOME", &["/usr/local/cuda"], "include/cuda.h");
23 |     let bindings = bindgen::Builder::default()
24 |         .header("wrapper.h")
25 |         .clang_arg(format!("-I{}/include", cuda_home.display()))
26 |         .parse_callbacks(Box::new(RenameCallback))
27 |         .prepend_enum_name(false)
28 |         .allowlist_item(r"cuda.*")
29 |         .derive_default(true)
30 |         .generate()
31 |         .expect("Unable to generate cuda runtime bindings");
32 |     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
33 |     bindings
34 |         .write_to_file(out_dir.join("cudart-bindings.rs"))
35 |         .expect("Couldn't write cuda runtime bindings!");
36 |
37 |     // Dynamic link dependencies
38 |     println!("cargo:rustc-link-search=native={}/lib64", cuda_home.display());
39 |     println!("cargo:rustc-link-lib=cudart");
40 |
41 |     Ok(())
42 | }
--------------------------------------------------------------------------------
/rust/cuda-lib/src/driver.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 |     ffi::{CStr, c_char, c_void},
3 |     ptr::{NonNull, null},
4 | };
5 |
6 | use cuda_sys::{
7 |     CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, cuGetErrorString,
8 |     cuMemGetHandleForAddressRange,
9 | };
10 |
11 | type Result<T> = std::result::Result<T, CudaDriverError>;
12 |
13 | #[derive(Clone, Debug)]
14 | pub struct CudaDriverError {
15 |     pub code: u32,
16 |     pub context: &'static str,
17 | }
18 |
19 | impl CudaDriverError {
20 |     pub fn new(code: u32, context: &'static str) -> Self {
21 |         Self { code, context }
22 |     }
23 | }
24 |
25 | impl std::fmt::Display for CudaDriverError {
26 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 |         let mut errstr: *const c_char = null();
28 |         unsafe { cuGetErrorString(self.code, &mut errstr) };
29 |
30 |         write!(
31 |             f,
32 |             "CudaDriverError: code {} ({:?}), context: {}",
33 |             self.code,
34 |             unsafe { CStr::from_ptr(errstr) },
35 |             self.context
36 |         )
37 |     }
38 | }
39 |
40 | impl std::error::Error for CudaDriverError {}
41 |
42 | pub fn cu_get_dma_buf_fd(ptr: NonNull<c_void>, len: usize) -> Result<i32> {
43 |     let mut dmabuf_fd: i32 = -1;
44 |     let ret = unsafe {
45 |         cuMemGetHandleForAddressRange(
46 |             &raw mut dmabuf_fd as *mut c_void,
47 |             ptr.as_ptr() as u64,
48 |             len,
49 |             CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
50 |             0,
51 |         )
52 |     };
53 |     match ret {
54 |         0 => Ok(dmabuf_fd),
55 |         _ => Err(CudaDriverError::new(ret, "cuMemGetHandleForAddressRange")),
56 |     }
57 | }
--------------------------------------------------------------------------------
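cu_get_dma_buf_fd exports a CUDA allocation as a dma-buf file descriptor, which RDMA stacks can register against instead of relying on peer-memory kernel modules. A usage sketch, assuming an already-allocated, dma-buf-capable device pointer:

use std::{ffi::c_void, ptr::NonNull};

use cuda_lib::driver::cu_get_dma_buf_fd;

fn sketch(device_ptr: u64, len: usize) {
    let ptr = NonNull::new(device_ptr as *mut c_void).expect("null device pointer");
    match cu_get_dma_buf_fd(ptr, len) {
        Ok(fd) => println!("dma-buf fd: {fd}"),
        Err(e) => eprintln!("export failed: {e}"),
    }
}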
/python-ext/src/py_device.rs:
--------------------------------------------------------------------------------
1 | use cuda_lib::{CudaDeviceId, Device};
2 | use pyo3::{
3 |     Borrowed, FromPyObject, PyAny, PyErr, exceptions::PyValueError, types::PyAnyMethods,
4 | };
5 |
6 | pub(crate) struct PyDevice(pub Device);
7 |
8 | impl<'py> FromPyObject<'_, 'py> for PyDevice {
9 |     type Error = PyErr;
10 |
11 |     fn extract(device: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
12 |         match device.getattr("type")?.extract::<&str>()? {
13 |             "cpu" => Ok(PyDevice(Device::Host)),
14 |             "cuda" => {
15 |                 let index =
16 |                     device.getattr("index")?.extract::<Option<u8>>()?.unwrap_or(0);
17 |                 Ok(PyDevice(Device::Cuda(CudaDeviceId(index))))
18 |             }
19 |             device_type => Err(PyValueError::new_err(format!(
20 |                 "Unknown device type: {device_type}"
21 |             ))),
22 |         }
23 |     }
24 | }
25 |
26 | #[cfg(test)]
27 | mod test {
28 |     use super::*;
29 |     use pyo3::Python;
30 |
31 |     #[test]
32 |     fn test_py_device() {
33 |         Python::initialize();
34 |         Python::attach(|py| {
35 |             let torch = py.import("torch").unwrap();
36 |             let cuda_device = torch.call_method1("device", ("cuda",)).unwrap();
37 |             let PyDevice(device) = cuda_device.extract().unwrap();
38 |             assert_eq!(device, Device::Cuda(CudaDeviceId(0)));
39 |
40 |             let cpu_device = torch.call_method1("device", ("cpu",)).unwrap();
41 |             let PyDevice(device) = cpu_device.extract().unwrap();
42 |             assert_eq!(device, Device::Host);
43 |
44 |             let cuda_device_2 = torch.call_method1("device", ("cuda:2",)).unwrap();
45 |             let PyDevice(device) = cuda_device_2.extract().unwrap();
46 |             assert_eq!(device, Device::Cuda(CudaDeviceId(2)));
47 |         });
48 |     }
49 | }
--------------------------------------------------------------------------------
/fabric-lib/libibverbs-sys/build.rs:
--------------------------------------------------------------------------------
1 | use std::{env, path::PathBuf};
2 |
3 | use build_utils::find_package;
4 |
5 | fn main() -> Result<(), Box<dyn std::error::Error>> {
6 |     let libibverbs_home =
7 |         find_package("LIBIBVERBS_HOME", &["/usr"], "include/infiniband/verbs.h");
8 |     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
9 |
10 |     // Generate bindings
11 |     // https://rust-lang.github.io/rust-bindgen/tutorial-3.html
12 |     let bindings = bindgen::Builder::default()
13 |         .header("wrapper.h")
14 |         .clang_arg(format!("-I{}/include", libibverbs_home.display()))
15 |         .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
16 |         .prepend_enum_name(false)
17 |         .allowlist_item(r"(ibv_|IBV_|ib_|IB_).*")
18 |         .derive_debug(false)
19 |         .derive_default(true)
20 |         // Some functions use static inline functions to lookup vtable.
21 |         .wrap_static_fns(true)
22 |         .wrap_static_fns_path(out_dir.join("wrap_static_fns.c"))
23 |         // Some structs include pthread types. Let's treat them as opaque.
24 |         .allowlist_item(r"pthread_.*")
25 |         .opaque_type(r"pthread_.*")
26 |         .no_default(r"pthread_.*")
27 |         .generate()
28 |         .expect("Unable to generate libibverbs bindings");
29 |     bindings
30 |         .write_to_file(out_dir.join("libibverbs-bindings.rs"))
31 |         .expect("Couldn't write libibverbs bindings!");
32 |
33 |     // Compile wrap_static_fns.c
34 |     cc::Build::new()
35 |         .file(out_dir.join("wrap_static_fns.c"))
36 |         .include(libibverbs_home.join("include"))
37 |         .include(env!("CARGO_MANIFEST_DIR"))
38 |         .compile("wrap_static_fns");
39 |
40 |     // Dynamic link dependencies
41 |     println!("cargo:rustc-link-search=native={}/lib", libibverbs_home.display());
42 |     println!("cargo:rustc-link-lib=ibverbs");
43 |
44 |     Ok(())
45 | }
--------------------------------------------------------------------------------
/python/pplx_garden/distributed/nccl_all_reduce.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterator
2 | from contextlib import contextmanager
3 | from typing import Optional
4 |
5 | import torch
6 | from torch.distributed import ProcessGroup, ReduceOp
7 | from typing_extensions import override
8 |
9 | from pplx_garden.distributed.distributed_ops import Reducer, ReducerBuilder
10 |
11 |
12 | class NcclReducer(Reducer):
13 |     def __init__(
14 |         self,
15 |         group: ProcessGroup,
16 |         op: ReduceOp.RedOpType = ReduceOp.SUM,
17 |     ) -> None:
18 |         self._group = group
19 |         self._op = op
20 |
21 |     @property
22 |     @override
23 |     def input(self) -> Optional[torch.Tensor]:
24 |         return None
25 |
26 |     def reduce(
27 |         self,
28 |         x: torch.Tensor,
29 |         out: Optional[torch.Tensor] = None,
30 |     ) -> torch.Tensor:
31 |         if out is None:
32 |             out = x
33 |         elif out is not x:
34 |             out.copy_(x)
35 |         torch.distributed.all_reduce(out, op=self._op, group=self._group)
36 |         return out
37 |
38 |
39 | class NcclAllReduce(ReducerBuilder):
40 |     def __init__(self, group: ProcessGroup) -> None:
41 |         self._group = group
42 |
43 |     @override
44 |     def reducer(
45 |         self,
46 |         shape: torch.Size,
47 |         dtype: torch.dtype,
48 |         op: ReduceOp.RedOpType = ReduceOp.SUM,
49 |     ) -> Reducer:
50 |         return NcclReducer(group=self._group, op=op)
51 |
52 |     @override
53 |     def destroy(self) -> None:
54 |         pass
55 |
56 |     @override
57 |     def all_reduce(
58 |         self,
59 |         x: torch.Tensor,
60 |         op: ReduceOp.RedOpType = ReduceOp.SUM,
61 |     ) -> torch.Tensor:
62 |         torch.distributed.all_reduce(x, op=op, group=self._group)
63 |         return x
64 |
65 |     @contextmanager
66 |     @override
67 |     def capture(self) -> Iterator[None]:
68 |         yield
--------------------------------------------------------------------------------
29 | bincode = "2.0.1" 30 | bindgen = "0.72.1" 31 | bytes = { version = "1.10.1", features = ["serde"] } 32 | cc = "1.2.44" 33 | clap = { version = "4.5.51", features = ["derive", "env"] } 34 | crossbeam-channel = "0.5.15" 35 | cxx = { version = "1.0.187" } 36 | cxx-build = { version = "1.0.187" } 37 | dashmap = { version = "6.1.0", features = ["inline"] } 38 | futures = "0.3.31" 39 | hashbrown = { version = "0.16.0", default-features = false, features = ["inline-more", "serde", "default-hasher"] } 40 | is-terminal = "0.4" 41 | libc = "0.2" 42 | mockall = "0.13.1" 43 | nvtx = "1.3.0" 44 | once_cell = "1.21.3" 45 | oneshot = "0.1.11" 46 | parking_lot = "0.12.5" 47 | pkg-config = "0.3.32" 48 | postcard = { version = "1.1.3", features = ["alloc", "use-std"] } 49 | pyo3 = { version = "0.27.1", features = ["anyhow"] } 50 | pyo3-build-config = "0.27.1" 51 | rand = "0.9.2" 52 | serde = { version = "1.0.228", features = ["derive", "rc"] } 53 | smallvec = { version = "1.15.1", features = ["serde", "union"] } 54 | syscalls = "0.7.0" 55 | thiserror = "2.0.17" 56 | tokio = "1.48.0" 57 | tracing = "0.1.41" 58 | tracing-core = "0.1.34" 59 | tracing-log = "0.2.0" 60 | tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } 61 | -------------------------------------------------------------------------------- /rust/torch-lib/build.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | use std::process::Command; 3 | 4 | fn main() { 5 | let cmake_prefix_path = match std::env::var("TORCH_CMAKE_PREFIX_PATH") { 6 | Ok(path) => path, 7 | Err(_) => { 8 | let output = Command::new("python3") 9 | .arg("-W") 10 | .arg("ignore") 11 | .arg("-c") 12 | .arg("import torch; print(torch.utils.cmake_prefix_path)") 13 | .output() 14 | .expect("failed to find Torch CMake prefix path"); 15 | 16 | if !output.stderr.is_empty() { 17 | let stderr_str = String::from_utf8_lossy(&output.stderr); 18 | println!( 19 | "cargo:warning=error getting torch prefix path: {}", 20 | stderr_str.trim() 21 | ); 22 | } 23 | 24 | String::from_utf8(output.stdout).unwrap() 25 | } 26 | }; 27 | 28 | let torch_path = Path::new(&cmake_prefix_path).parent().unwrap().parent().unwrap(); 29 | let torch_include = torch_path.join("include"); 30 | let torch_lib = torch_path.join("lib"); 31 | 32 | let config = pkg_config::Config::new().probe("python3").unwrap(); 33 | 34 | cxx_build::bridge("src/lib.rs") 35 | .file("src/torch_lib.cc") 36 | .flag("-Wno-unused-parameter") 37 | .includes(config.include_paths) 38 | .include(torch_include) 39 | .include("/usr/local/cuda/include") 40 | .std("c++20") 41 | .compile("torch-lib"); 42 | 43 | println!("cargo:rerun-if-changed=src/torch_lib.cc"); 44 | println!("cargo:rerun-if-changed=src/torch_lib.h"); 45 | 46 | println!("cargo:rustc-link-search=native={}", torch_lib.display()); 47 | println!("cargo:rustc-link-arg=-Wl,-rpath,{}", torch_lib.display()); 48 | println!("cargo:rustc-link-lib=torch_python"); 49 | println!("cargo:rustc-link-lib=torch"); 50 | println!("cargo:rustc-link-lib=torch_cuda"); 51 | println!("cargo:rustc-link-lib=torch_cpu"); 52 | println!("cargo:rustc-link-lib=c10_cuda"); 53 | println!("cargo:rustc-link-lib=c10"); 54 | } 55 | -------------------------------------------------------------------------------- /fabric-lib/src/provider.rs: -------------------------------------------------------------------------------- 1 | use std::{borrow::Cow, ffi::c_void, ptr::NonNull, sync::Arc}; 2 | 3 | use crate::{ 4 | api::{DomainAddress, 
MemoryRegionRemoteKey, PeerGroupHandle, TransferId},
5 | error::{FabricLibError, Result},
6 | imm_count::ImmCountMap,
7 | mr::{MemoryRegion, MemoryRegionLocalDescriptor},
8 | rdma_op::{GroupWriteOp, RecvOp, SendOp, WriteOp},
9 | };
10 |
11 | pub trait RdmaDomainInfo {
12 | fn name(&self) -> Cow<'_, str>;
13 | fn link_speed(&self) -> u64;
14 | }
15 |
16 | pub trait RdmaDomain {
17 | type Info: RdmaDomainInfo;
18 |
19 | fn open(info: Self::Info, imm_count_map: Arc<ImmCountMap>) -> Result<Self>
20 | where
21 | Self: Sized;
22 |
23 | fn link_speed(&self) -> u64;
24 | fn addr(&self) -> DomainAddress;
25 |
26 | fn register_mr_local(&mut self, region: &MemoryRegion) -> Result<()>;
27 | fn register_mr_allow_remote(
28 | &mut self,
29 | region: &MemoryRegion,
30 | ) -> Result<MemoryRegionRemoteKey>;
31 | fn unregister_mr(&mut self, ptr: NonNull<c_void>);
32 | fn get_mem_desc(&self, ptr: NonNull<c_void>)
33 | -> Result<MemoryRegionLocalDescriptor>;
34 |
35 | fn submit_recv(&mut self, transfer_id: TransferId, op: RecvOp);
36 | fn submit_send(
37 | &mut self,
38 | transfer_id: TransferId,
39 | dest_addr: DomainAddress,
40 | op: SendOp,
41 | );
42 | fn submit_write(
43 | &mut self,
44 | transfer_id: TransferId,
45 | dest_addr: DomainAddress,
46 | op: WriteOp,
47 | );
48 |
49 | fn add_peer_group(
50 | &mut self,
51 | handle: PeerGroupHandle,
52 | addrs: Vec<DomainAddress>,
53 | ) -> Result<()>;
54 | fn submit_group_write(
55 | &mut self,
56 | transfer_id: TransferId,
57 | handle: Option<PeerGroupHandle>,
58 | op: GroupWriteOp,
59 | );
60 |
61 | fn poll_progress(&mut self);
62 | fn get_completion(&mut self) -> Option<DomainCompletionEntry>;
63 | }
64 |
65 | pub enum DomainCompletionEntry {
66 | Recv { transfer_id: TransferId, data_len: usize },
67 | Send(TransferId),
68 | Transfer(TransferId),
69 | ImmData(u32),
70 | ImmCountReached(u32),
71 | Error(TransferId, FabricLibError),
72 | }
73 | -------------------------------------------------------------------------------- /fabric-lib/src/error.rs: --------------------------------------------------------------------------------
1 | use std::ffi::CStr;
2 |
3 | use cuda_lib::{driver::CudaDriverError, rt::CudartError};
4 | use libfabric_sys::fi_strerror;
5 | use syscalls::Errno;
6 |
7 | pub type Result<T> = std::result::Result<T, FabricLibError>;
8 |
9 | #[derive(Clone, Debug, thiserror::Error)]
10 | pub enum FabricLibError {
11 | #[error("{0}")]
12 | Libfabric(#[from] LibfabricError),
13 | #[error("DomainError: {0}")]
14 | Domain(String),
15 | #[error("{0}")]
16 | Verbs(#[from] VerbsError),
17 | #[error("VerbsCompletionError: {0}")]
18 | VerbsCompletionError(String),
19 | #[error("Libfabric CompletionError: {0}")]
20 | CompletionError(String),
21 | #[error("{0}")]
22 | CudaDriver(#[from] CudaDriverError),
23 | #[error("{0}")]
24 | Cudart(#[from] CudartError),
25 | #[error("{0}")]
26 | Errno(#[from] Errno),
27 | #[error("FabricLibError: {0}")]
28 | Custom(&'static str),
29 | }
30 |
31 | #[derive(Clone, Debug)]
32 | pub struct LibfabricError {
33 | pub code: i32,
34 | pub context: &'static str,
35 | }
36 |
37 | impl LibfabricError {
38 | pub fn new(code: i32, context: &'static str) -> Self {
39 | Self { code, context }
40 | }
41 | }
42 |
43 | impl std::fmt::Display for LibfabricError {
44 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 | write!(
46 | f,
47 | "LibfabricError: code {} ({:?}), context: {}",
48 | self.code,
49 | unsafe { CStr::from_ptr(fi_strerror(self.code)) },
50 | self.context
51 | )
52 | }
53 | }
54 |
55 | impl std::error::Error for LibfabricError {}
56 |
57 | #[derive(Clone, Debug, thiserror::Error)]
58 | #[error("VerbsError: code {code}, context: {context}")]
59 | pub struct
VerbsError {
60 | pub code: Errno,
61 | pub context: &'static str,
62 | }
63 |
64 | impl VerbsError {
65 | pub fn with_last_os_error(context: &'static str) -> Self {
66 | Self {
67 | code: Errno::new(
68 | std::io::Error::last_os_error().raw_os_error().unwrap_or(0),
69 | ),
70 | context,
71 | }
72 | }
73 |
74 | pub fn with_code(code: i32, context: &'static str) -> Self {
75 | Self { code: Errno::new(code), context }
76 | }
77 | }
78 | -------------------------------------------------------------------------------- /rust/proc-lib/src/lib.rs: --------------------------------------------------------------------------------
1 | use proc_macro::TokenStream;
2 | use proc_macro2::TokenTree;
3 | use quote::quote;
4 |
5 | /// We need to copy the implementation of cudaGetDeviceCount here because we can't use `cuda-lib` as a crate dependency
6 | /// if we want to use this macro in `cuda-lib` itself.
7 | fn cuda_get_device_count() -> Result<i32, u32> {
8 | let mut count: i32 = 0;
9 | let ret = unsafe { cudart_sys::cudaGetDeviceCount(&raw mut count) };
10 | match ret {
11 | 0 => Ok(count),
12 | _ => Err(ret),
13 | }
14 | }
15 |
16 | #[proc_macro_attribute]
17 | /// Skip the test if no GPUs are available.
18 | pub fn gpu_test(_args: TokenStream, input: TokenStream) -> TokenStream {
19 | let input = proc_macro2::TokenStream::from(input);
20 |
21 | // Instead of panicking, skip the test: an error here most likely means
22 | // that CUDA is unavailable anyway.
23 | let available_gpus: Result<(), String> = match cuda_get_device_count() {
24 | Ok(count) if count > 0 => Ok(()),
25 | Ok(_) => Err("no GPUs available".to_string()),
26 | Err(e) => Err(format!("error getting device count, cuda error code: {}", e)),
27 | };
28 |
29 | let output = if let Err(e) = available_gpus {
30 | quote! {
31 | #[ignore = #e]
32 | #input
33 | }
34 | } else {
35 | let input = append_gpu_suffix(input);
36 | quote! {
37 | #input
38 | }
39 | };
40 |
41 | TokenStream::from(output)
42 | }
43 |
44 | /// Append `_GPU` to the function name. We can use this to filter out GPU tests.
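/// For example (illustrative test name, not from this crate): `fn copies_device_memory()`
/// is rewritten to `fn copies_device_memory_GPU()`, so `cargo test GPU` selects only
/// GPU tests, while `cargo test -- --skip GPU` runs everything else.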
45 | fn append_gpu_suffix(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { 46 | let mut next_is_fn = false; 47 | let mut output = proc_macro2::TokenStream::new(); 48 | for token in input.into_iter() { 49 | let append_token = match &token { 50 | TokenTree::Ident(ident) if next_is_fn => { 51 | next_is_fn = false; 52 | let span = ident.span(); 53 | let new_fn_name = ident.to_string() + "_GPU"; 54 | TokenTree::Ident(proc_macro2::Ident::new(&new_fn_name, span)) 55 | } 56 | TokenTree::Ident(ident) if ident == "fn" => { 57 | next_is_fn = true; 58 | token 59 | } 60 | _ => token, 61 | }; 62 | output.extend([append_token]); 63 | } 64 | output 65 | } 66 | -------------------------------------------------------------------------------- /docker/dev.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 as gdrcopy-builder 2 | 3 | RUN apt-get update && apt-get install -y build-essential devscripts debhelper fakeroot pkg-config wget 4 | RUN cd /tmp && \ 5 | wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.1.tar.gz && \ 6 | tar -xf v2.5.1.tar.gz && \ 7 | cd gdrcopy-2.5.1/packages/ && \ 8 | CUDA=/usr/local/cuda ./build-deb-packages.sh -t -k 9 | 10 | 11 | 12 | 13 | 14 | FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 as final 15 | 16 | RUN apt-get update && \ 17 | apt-get install -y --no-install-recommends \ 18 | patchelf \ 19 | libclang-dev \ 20 | clang-18 \ 21 | clang-format-18 \ 22 | git \ 23 | build-essential \ 24 | cmake \ 25 | libssl-dev \ 26 | wget \ 27 | curl \ 28 | ninja-build \ 29 | pkg-config \ 30 | python3-dev \ 31 | python3-setuptools \ 32 | python3-pip \ 33 | python3-build \ 34 | python3-venv \ 35 | && \ 36 | apt-get clean && \ 37 | rm -rf /var/lib/apt/lists/* 38 | 39 | # PyTorch 40 | ENV PIP_BREAK_SYSTEM_PACKAGES=1 41 | ENV TORCH_CUDA_ARCH_LIST="9.0a;10.0a+PTX" 42 | RUN python3 -m pip install torch==2.9.0+cu129 --index-url https://download.pytorch.org/whl/cu129 43 | 44 | 45 | # EFA (including libfabric and libibverbs) 46 | RUN cd /tmp && \ 47 | curl -O https://efa-installer.amazonaws.com/aws-efa-installer-1.44.0.tar.gz && \ 48 | tar -xf aws-efa-installer-1.44.0.tar.gz && \ 49 | cd aws-efa-installer && \ 50 | apt-get update && \ 51 | ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify && \ 52 | rm -rf /var/lib/apt/lists/* && \ 53 | ldconfig && \ 54 | rm -rf /tmp/aws-efa-installer* 55 | ENV NCCL_SOCKET_IFNAME=^docker,lo,veth_def_agent 56 | 57 | 58 | # GDRCopy 59 | COPY --from=gdrcopy-builder /tmp/gdrcopy-2.5.1/packages/libgdrapi_2.5.1-1_amd64.Ubuntu24_04.deb /tmp/libgdrapi_2.5.1-1_amd64.Ubuntu24_04.deb 60 | RUN dpkg -i /tmp/libgdrapi_2.5.1-1_amd64.Ubuntu24_04.deb && \ 61 | rm -rf /tmp/libgdrapi_2.5.1-1_amd64.Ubuntu24_04.deb 62 | 63 | 64 | # Rust 65 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.91.0 --component llvm-tools-preview 66 | ENV PATH="/root/.cargo/bin:$PATH" \ 67 | CARGO_HOME="/root/.cargo" \ 68 | RUSTUP_HOME="/root/.rustup" 69 | 70 | 71 | # Python dependencies 72 | RUN python3 -m pip install numpy ninja maturin \ 73 | pytest coverage mypy pylint ruff 74 | 75 | -------------------------------------------------------------------------------- /fabric-lib/src/rdma_op.rs: -------------------------------------------------------------------------------- 1 | use std::{ffi::c_void, ptr::NonNull, sync::Arc}; 2 | 3 | use crate::{ 4 | api::{DomainAddress, MemoryRegionRemoteKey, ScatterTarget}, 5 | 
mr::MemoryRegionLocalDescriptor,
6 | };
7 |
8 | pub struct SingleWriteOp {
9 | pub src_ptr: NonNull<c_void>,
10 | pub src_desc: MemoryRegionLocalDescriptor,
11 | pub src_offset: u64,
12 | pub length: u64,
13 | pub imm_data: Option<u32>,
14 | pub dst_ptr: u64,
15 | pub dst_rkey: MemoryRegionRemoteKey,
16 | pub dst_offset: u64,
17 | }
18 |
19 | pub struct ImmWriteOp {
20 | pub imm_data: u32,
21 | pub dst_ptr: u64,
22 | pub dst_rkey: MemoryRegionRemoteKey,
23 | }
24 |
25 | pub struct PagedWriteOp {
26 | pub src_page_indices: Arc<Vec<u32>>,
27 | pub dst_page_indices: Arc<Vec<u32>>,
28 | pub page_indices_beg: usize,
29 | pub page_indices_end: usize,
30 | pub length: u64,
31 | pub src_ptr: NonNull<c_void>,
32 | pub src_desc: MemoryRegionLocalDescriptor,
33 | pub src_stride: u64,
34 | pub src_offset: u64,
35 | pub dst_ptr: u64,
36 | pub dst_rkey: MemoryRegionRemoteKey,
37 | pub dst_stride: u64,
38 | pub dst_offset: u64,
39 | pub imm_data: Option<u32>,
40 | }
41 |
42 | pub enum WriteOp {
43 | Single(SingleWriteOp),
44 | Imm(ImmWriteOp),
45 | Paged(PagedWriteOp),
46 | }
47 |
48 | pub struct ScatterGroupWriteOp {
49 | pub domain_idx: usize,
50 | pub src_ptr: NonNull<c_void>,
51 | pub src_desc: MemoryRegionLocalDescriptor,
52 | pub imm_data: Option<u32>,
53 | pub dsts: Arc<Vec<ScatterTarget>>,
54 | pub dst_beg: usize,
55 | pub dst_end: usize,
56 | pub byte_shards: u32,
57 | pub byte_shard_idx: u32,
58 | }
59 |
60 | pub enum GroupWriteOp {
61 | Scatter(ScatterGroupWriteOp),
62 | }
63 |
64 | impl GroupWriteOp {
65 | pub fn num_targets(&self) -> usize {
66 | match self {
67 | GroupWriteOp::Scatter(op) => op.dsts.len(),
68 | }
69 | }
70 |
71 | pub fn peer_addr_iter(&self) -> impl Iterator<Item = &DomainAddress> {
72 | match self {
73 | GroupWriteOp::Scatter(op) => {
74 | op.dsts.iter().map(|dst| &dst.dst_mr.addr_rkey_list[op.domain_idx].0)
75 | }
76 | }
77 | }
78 | }
79 |
80 | pub struct SendOp {
81 | pub ptr: NonNull<c_void>,
82 | pub len: usize,
83 | pub desc: MemoryRegionLocalDescriptor,
84 | }
85 |
86 | pub struct RecvOp {
87 | pub ptr: NonNull<c_void>,
88 | pub len: usize,
89 | pub desc: MemoryRegionLocalDescriptor,
90 | }
91 | -------------------------------------------------------------------------------- /fabric-lib/src/verbs/verbs_devinfo.rs: --------------------------------------------------------------------------------
1 | use std::{borrow::Cow, ffi::CStr, sync::Arc};
2 |
3 | use crate::{
4 | error::{Result, VerbsError},
5 | provider::RdmaDomainInfo,
6 | };
7 |
8 | use libibverbs_sys::{ibv_device, ibv_free_device_list, ibv_get_device_list};
9 |
10 | pub struct VerbsDeviceList {
11 | pub list: *mut *mut ibv_device,
12 | pub num_devices: usize,
13 | }
14 |
15 | unsafe impl Send for VerbsDeviceList {}
16 | unsafe impl Sync for VerbsDeviceList {}
17 |
18 | impl VerbsDeviceList {
19 | pub fn get_all_devices() -> Result<Arc<Self>> {
20 | let mut num_devices = 0;
21 | let list = unsafe { ibv_get_device_list(&raw mut num_devices) };
22 | if list.is_null() {
23 | Err(VerbsError::with_last_os_error("ibv_get_device_list").into())
24 | } else {
25 | Ok(Arc::new(Self { list, num_devices: num_devices as usize }))
26 | }
27 | }
28 | }
29 |
30 | impl Drop for VerbsDeviceList {
31 | fn drop(&mut self) {
32 | unsafe { ibv_free_device_list(self.list) };
33 | }
34 | }
35 |
36 | #[derive(Clone)]
37 | pub struct VerbsDeviceInfo {
38 | pub device_list: Arc<VerbsDeviceList>,
39 | pub device_index: usize,
40 | pub port_num: u8,
41 | pub gid_index: u8,
42 | }
43 |
44 | impl VerbsDeviceInfo {
45 | pub fn new(device_list: Arc<VerbsDeviceList>, device_index: usize) -> Self {
46 | // TODO: port_num
47 | // TODO: gid_index
48 | Self { device_list, device_index, port_num: 1,
gid_index: 0 }
49 | }
50 |
51 | pub fn device(&self) -> *mut ibv_device {
52 | unsafe { *self.device_list.list.add(self.device_index) }
53 | }
54 | }
55 |
56 | impl RdmaDomainInfo for VerbsDeviceInfo {
57 | fn name(&self) -> Cow<'_, str> {
58 | unsafe { CStr::from_ptr((*self.device()).name.as_ptr()).to_string_lossy() }
59 | }
60 |
61 | fn link_speed(&self) -> u64 {
62 | let path = format!(
63 | "{}/ports/{}/rate",
64 | unsafe {
65 | CStr::from_ptr((*self.device()).ibdev_path.as_ptr()).to_string_lossy()
66 | },
67 | self.port_num
68 | );
69 | match std::fs::read_to_string(path) {
70 | Ok(content) => {
71 | let trimmed = content.trim();
72 | let end_pos = trimmed.find(' ').unwrap_or(trimmed.len());
73 | let gbps: f64 = trimmed[..end_pos].parse().unwrap_or(0.0);
74 | (gbps * 1e9) as u64
75 | }
76 | Err(_) => 0,
77 | }
78 | }
79 | }
80 | -------------------------------------------------------------------------------- /fabric-lib/src/verbs/verbs_address.rs: --------------------------------------------------------------------------------
1 | use bytes::Bytes;
2 | use libibverbs_sys::ibv_gid;
3 | use serde::{Deserialize, Serialize};
4 |
5 | use crate::{api::DomainAddress, utils::hex::fmt_hex};
6 |
7 | #[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
8 | #[repr(transparent)]
9 | pub struct Gid {
10 | pub raw: [u8; 16],
11 | }
12 |
13 | #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14 | pub struct VerbsUDAddress {
15 | pub gid: Gid,
16 | pub lid: u16,
17 | pub qp_num: u32,
18 | pub qkey: u32,
19 | }
20 |
21 | #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
22 | pub struct VerbsRCAddress {
23 | pub gid: Gid,
24 | pub lid: u16,
25 | pub qp_num: u32,
26 | pub psn: u32,
27 | }
28 |
29 | impl std::fmt::Debug for Gid {
30 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 | fmt_hex(f, &self.raw)
32 | }
33 | }
34 |
35 | impl std::fmt::Display for Gid {
36 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 | fmt_hex(f, &self.raw)
38 | }
39 | }
40 |
41 | impl From<Gid> for ibv_gid {
42 | fn from(gid: Gid) -> Self {
43 | Self { raw: gid.raw }
44 | }
45 | }
46 |
47 | impl VerbsUDAddress {
48 | const BYTES: usize = 26;
49 | const _SIZE_CHECK: () =
50 | assert!(std::mem::size_of::<Gid>() + std::mem::size_of::<u16>() + 2 * std::mem::size_of::<u32>() == Self::BYTES);
51 |
52 | pub fn to_bytes(&self) -> [u8; Self::BYTES] {
53 | let mut bytes = [0; Self::BYTES];
54 | bytes[..16].copy_from_slice(&self.gid.raw);
55 | bytes[16..18].copy_from_slice(&self.lid.to_le_bytes());
56 | bytes[18..22].copy_from_slice(&self.qp_num.to_le_bytes());
57 | bytes[22..26].copy_from_slice(&self.qkey.to_le_bytes());
58 | bytes
59 | }
60 |
61 | pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
62 | // TODO: make it more idiomatic.
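// A more idiomatic shape (sketch only; relies on `slice::split_first_chunk`,
// stable since Rust 1.77): peel the 16-byte GID off with
// `bytes.split_first_chunk::<16>()?`, then build `lid`/`qp_num`/`qkey` with
// `u16::from_le_bytes`/`u32::from_le_bytes` over `try_into()`-checked
// sub-slices, removing the `unsafe` block at the cost of a few bounds checks.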
63 | if bytes.len() != Self::BYTES { 64 | return None; 65 | } 66 | unsafe { 67 | Some(Self { 68 | gid: Gid { raw: bytes[..16].try_into().unwrap_unchecked() }, 69 | lid: u16::from_le_bytes(bytes[16..18].try_into().unwrap_unchecked()), 70 | qp_num: u32::from_le_bytes(bytes[18..22].try_into().unwrap_unchecked()), 71 | qkey: u32::from_le_bytes(bytes[22..26].try_into().unwrap_unchecked()), 72 | }) 73 | } 74 | } 75 | } 76 | 77 | impl From<&VerbsUDAddress> for DomainAddress { 78 | fn from(addr: &VerbsUDAddress) -> Self { 79 | DomainAddress(Bytes::copy_from_slice(&addr.to_bytes())) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /python/pplx_garden/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | """Logging utilities.""" 2 | 3 | import logging 4 | import os 5 | import warnings 6 | 7 | 8 | class LoggerFormatter(logging.Formatter): 9 | LEVEL_COLOURS = [ 10 | (logging.DEBUG, "\x1b[40;1m"), 11 | (logging.INFO, "\x1b[34;1m"), 12 | (logging.WARNING, "\x1b[33;1m"), 13 | (logging.ERROR, "\x1b[31m"), 14 | (logging.CRITICAL, "\x1b[41m"), 15 | ] 16 | 17 | FORMATS = { 18 | level: logging.Formatter( 19 | f"\x1b[36;1m[%(asctime)s.%(msecs)03d]\x1b[0m %(process)d {colour}%(levelname)-8s" 20 | f"\x1b[0m \x1b[32m%(name)s\x1b[0m %(message)s", 21 | "%Y-%m-%d %H:%M:%S", 22 | ) 23 | for level, colour in LEVEL_COLOURS 24 | } 25 | 26 | def format(self, record: logging.LogRecord) -> str: 27 | """Format the log record.""" 28 | formatter = self.FORMATS.get(record.levelno) 29 | if formatter is None: 30 | formatter = self.FORMATS[logging.DEBUG] 31 | 32 | # Override the traceback to always print in red 33 | if record.exc_info: 34 | text = formatter.formatException(record.exc_info) 35 | record.exc_text = f"\x1b[31m{text}\x1b[0m" 36 | 37 | output = formatter.format(record) 38 | 39 | # Remove the cache layer 40 | record.exc_text = None 41 | return output 42 | 43 | 44 | _IS_SETUP = False 45 | 46 | 47 | def setup( 48 | *, 49 | handler: logging.Handler | None = None, 50 | level: str | int | None = None, 51 | ) -> None: 52 | """Setup the logging.""" 53 | global _IS_SETUP # pylint: disable=global-statement 54 | if _IS_SETUP: 55 | return 56 | _IS_SETUP = True 57 | 58 | level = level or logging.INFO 59 | handler = handler or logging.StreamHandler() 60 | 61 | # Check if DD_ENV is set, if so use plain formatter without colors 62 | if os.environ.get("DD_ENV"): 63 | formatter = logging.Formatter( 64 | "[%(asctime)s.%(msecs)03d] %(process)d %(levelname)-8s %(name)s %(message)s", 65 | "%Y-%m-%d %H:%M:%S", 66 | ) 67 | else: 68 | formatter = LoggerFormatter() 69 | 70 | handler.setFormatter(formatter) 71 | 72 | logger = logging.getLogger() 73 | logger.setLevel(level) 74 | logger.handlers.clear() 75 | logger.addHandler(handler) 76 | 77 | # Filter out UserWarning and FutureWarning 78 | warnings.filterwarnings("ignore", category=UserWarning) 79 | warnings.filterwarnings("ignore", category=FutureWarning) 80 | 81 | 82 | def get_logger(path: str) -> logging.Logger: 83 | """Return logger for the given path.""" 84 | return logging.getLogger(path) 85 | -------------------------------------------------------------------------------- /fabric-lib/src/utils/memory.rs: -------------------------------------------------------------------------------- 1 | use std::{mem::MaybeUninit, ptr::NonNull}; 2 | 3 | /// A simple memory pool that manages fixed-size chunks of memory. 
4 | ///
5 | /// This pool pre-allocates a contiguous buffer and manages allocation/deallocation
6 | /// of fixed-size chunks within that buffer. It uses unsafe operations for performance
7 | /// and does not perform bounds checking on freed pointers.
8 | pub struct MemoryPool {
9 | chunk_size: usize,
10 | num_chunks: usize,
11 | buffer: Vec<MaybeUninit<u8>>,
12 | free_list: Vec<NonNull<u8>>,
13 | }
14 |
15 | impl MemoryPool {
16 | /// Creates a new memory pool with the specified chunk size and number of chunks.
17 | ///
18 | /// All chunks are initially available for allocation.
19 | pub fn new(chunk_size: usize, num_chunks: usize) -> Self {
20 | let mut buffer = Vec::with_capacity(chunk_size * num_chunks);
21 | unsafe { buffer.set_len(chunk_size * num_chunks) };
22 | let ptr = unsafe { NonNull::new_unchecked(buffer.as_mut_ptr() as *mut u8) };
23 | let free_list =
24 | (0..num_chunks).map(|i| unsafe { ptr.byte_add(i * chunk_size) }).collect();
25 | Self { chunk_size, num_chunks, buffer, free_list }
26 | }
27 |
28 | /// Returns the size of each chunk in bytes.
29 | pub fn chunk_size(&self) -> usize {
30 | self.chunk_size
31 | }
32 |
33 | /// Returns the total number of chunks in the pool.
34 | #[allow(dead_code)]
35 | pub fn num_chunks(&self) -> usize {
36 | self.num_chunks
37 | }
38 |
39 | /// Returns a pointer to the start of the underlying buffer.
40 | pub fn buffer_ptr(&self) -> NonNull<u8> {
41 | unsafe { NonNull::new_unchecked(self.buffer.as_ptr() as *mut u8) }
42 | }
43 |
44 | /// Returns the length of the underlying buffer in bytes.
45 | pub fn buffer_len(&self) -> usize {
46 | self.buffer.len()
47 | }
48 |
49 | /// Allocates an uninitialized memory chunk from the pool.
50 | /// Returns `None` if no chunks are available.
51 | ///
52 | /// # Safety
53 | ///
54 | /// The caller must ensure that the returned pointer is not used after being freed.
55 | pub unsafe fn alloc(&mut self) -> Option<NonNull<u8>> {
56 | self.free_list.pop()
57 | }
58 |
59 | /// Frees a previously allocated chunk back to the pool.
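///
/// A minimal round-trip sketch (illustrative only, not a doctest):
/// ```ignore
/// let mut pool = MemoryPool::new(4096, 16);
/// let chunk = unsafe { pool.alloc() }.expect("pool exhausted");
/// // ... write at most pool.chunk_size() bytes through `chunk` ...
/// unsafe { pool.free(chunk) };
/// ```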
60 | ///
61 | /// # Safety
62 | ///
63 | /// The caller must ensure that:
64 | /// - `ptr` was previously returned by `alloc()` from this pool
65 | /// - `ptr` is not used after being freed
66 | /// - `ptr` is not freed multiple times
67 | pub unsafe fn free(&mut self, ptr: NonNull<u8>) {
68 | self.free_list.push(ptr);
69 | }
70 | }
71 | -------------------------------------------------------------------------------- /fabric-lib/src/mr.rs: --------------------------------------------------------------------------------
1 | use std::{ffi::c_void, ptr::NonNull};
2 |
3 | use cuda_lib::driver::cu_get_dma_buf_fd;
4 | use cuda_lib::rt::{cudaMemoryTypeDevice, cudaPointerGetAttributes};
5 | use cuda_lib::{CudaDeviceId, Device};
6 | use once_cell::sync::Lazy;
7 |
8 | use crate::error::{FabricLibError, Result};
9 |
10 | #[derive(Debug, PartialEq, Eq, Hash)]
11 | pub enum Mapping {
12 | Host,
13 | Device { device_id: CudaDeviceId, dmabuf_fd: Option<i32> },
14 | }
15 |
16 | #[derive(Debug, PartialEq, Eq, Hash)]
17 | pub struct MemoryRegion {
18 | ptr: NonNull<c_void>,
19 | len: usize,
20 | mapping: Mapping,
21 | }
22 |
23 | impl MemoryRegion {
24 | pub fn new(ptr: NonNull<c_void>, len: usize, device: Device) -> Result<Self> {
25 | let mapping = match device {
26 | Device::Host => Mapping::Host,
27 | Device::Cuda(device_id) => {
28 | let attrs = cudaPointerGetAttributes(ptr)?;
29 | if attrs.type_ != cudaMemoryTypeDevice {
30 | return Err(FabricLibError::Custom("not a device pointer"));
31 | }
32 | let dmabuf_fd = if linux_kernel_supports_dma_buf() {
33 | cu_get_dma_buf_fd(ptr, len).ok()
34 | } else {
35 | None
36 | };
37 | Mapping::Device { device_id, dmabuf_fd }
38 | }
39 | };
40 | Ok(MemoryRegion { ptr, len, mapping })
41 | }
42 |
43 | pub fn ptr(&self) -> NonNull<c_void> {
44 | self.ptr
45 | }
46 |
47 | pub fn len(&self) -> usize {
48 | self.len
49 | }
50 |
51 | pub fn mapping(&self) -> &Mapping {
52 | &self.mapping
53 | }
54 | }
55 |
56 | impl Drop for MemoryRegion {
57 | fn drop(&mut self) {
58 | match self.mapping {
59 | Mapping::Host => {}
60 | Mapping::Device { dmabuf_fd: None, .. } => {}
61 | Mapping::Device { dmabuf_fd: Some(dmabuf_fd), .. } => unsafe {
62 | libc::close(dmabuf_fd);
63 | },
64 | }
65 | }
66 | }
67 |
68 | /// A local descriptor for a memory region.
69 | /// For verbs, this is the MR LKEY.
70 | /// For libfabric, this is the MR descriptor.
71 | #[derive(Debug, Clone, Copy)]
72 | #[repr(transparent)]
73 | pub struct MemoryRegionLocalDescriptor(pub u64);
74 |
75 | static LINUX_KERNEL_SUPPORTS_DMA_BUF: Lazy<bool> = Lazy::new(|| {
76 | let Ok(version) = std::fs::read_to_string("/proc/sys/kernel/osrelease") else {
77 | return false;
78 | };
79 | let mut parts = version.split('.');
80 | let major: u32 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
81 | let minor: u32 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
82 |
83 | (major, minor) >= (5, 12)
84 | });
85 |
86 | fn linux_kernel_supports_dma_buf() -> bool {
87 | *LINUX_KERNEL_SUPPORTS_DMA_BUF
88 | }
89 | -------------------------------------------------------------------------------- /rust/build-utils/src/lib.rs: --------------------------------------------------------------------------------
1 | use std::{
2 | env,
3 | path::{Path, PathBuf},
4 | };
5 |
6 | /// Finds the path to a package directory by checking an environment variable and a list of default paths.
7 | ///
8 | /// The function checks if the environment variable `env_var` is set and points to a directory containing `check_file`.
9 | /// If not, it searches each path in `default_paths` for the presence of `check_file`. 10 | /// Returns the first directory containing `check_file`, or panics if none is found. 11 | /// 12 | /// # Arguments 13 | /// * `env_var` - The name of the environment variable to check. 14 | /// * `default_paths` - A slice of default directory paths to search. 15 | /// * `check_file` - The relative path to the file that must exist in the directory. 16 | /// 17 | /// # Panics 18 | /// Panics if neither the environment variable nor any of the default paths contain `check_file`. 19 | pub fn find_package( 20 | env_var: &str, 21 | default_paths: &[&str], 22 | check_file: &str, 23 | ) -> PathBuf { 24 | println!("cargo:rerun-if-env-changed={}", env_var); 25 | env::var_os(env_var) 26 | .map(PathBuf::from) 27 | .into_iter() 28 | .chain(default_paths.iter().map(PathBuf::from)) 29 | .find(|dir| dir.join(check_file).is_file()) 30 | .unwrap_or_else(|| { 31 | panic!( 32 | "find_package: {} is not set and {} is not found in the default paths", 33 | env_var, check_file 34 | ) 35 | }) 36 | } 37 | 38 | /// Recursively emits `cargo:rerun-if-changed` for all files under `src_dir` 39 | /// with one of the given `extensions`. 40 | /// 41 | /// Example: 42 | /// ```no_run 43 | /// use build_utils::emit_rerun_if_changed_files; 44 | /// emit_rerun_if_changed_files("src", &["cu", "cuh", "h"]); 45 | /// ``` 46 | pub fn emit_rerun_if_changed_files(src_dir: &str, extensions: &[&str]) { 47 | fn visit_dir(dir: &Path, extensions: &[&str]) -> std::io::Result<()> { 48 | for entry in std::fs::read_dir(dir)? { 49 | let entry = entry?; 50 | let path = entry.path(); 51 | if path.is_dir() { 52 | visit_dir(&path, extensions)?; 53 | } else if let Some(ext) = path.extension().and_then(|s| s.to_str()) 54 | && extensions.contains(&ext) 55 | { 56 | println!("cargo:rerun-if-changed={}", path.display()); 57 | } 58 | } 59 | Ok(()) 60 | } 61 | 62 | let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); 63 | let root = manifest_dir.join(src_dir); 64 | 65 | if let Err(err) = visit_dir(&root, extensions) { 66 | eprintln!("cargo:warning=Failed to scan {}: {}", root.display(), err); 67 | } 68 | 69 | // Also watch the directory itself so new files trigger rebuilds 70 | println!("cargo:rerun-if-changed={}", root.display()); 71 | } 72 | -------------------------------------------------------------------------------- /python/pplx_garden/native/p2p_all_to_all.pyi: -------------------------------------------------------------------------------- 1 | # ruff: noqa: A002 2 | 3 | import torch 4 | 5 | from pplx_garden.fabric_lib import ( 6 | DomainAddress, 7 | MemoryRegionDescriptor, 8 | MemoryRegionHandle, 9 | TransferEngine, 10 | ) 11 | 12 | class AllToAllContext: 13 | @classmethod 14 | def create( 15 | cls, 16 | hidden_dim: int, 17 | hidden_dim_scale: int | None, 18 | in_elemsize: int, 19 | out_elemsize: int, 20 | out_dtype: torch.dtype, 21 | scale_elemsize: int | None, 22 | max_num_tokens: int, 23 | max_recv_tokens: int, 24 | max_private_tokens: int, 25 | num_experts: int, 26 | expert_padding: int, 27 | num_experts_per_token: int, 28 | rank: int, 29 | dp_size: int, 30 | node_size: int, 31 | world_size: int, 32 | num_routed_ptr: int, 33 | num_routed_mr: MemoryRegionHandle, 34 | send_buffer_ptr: int, 35 | send_buffer_mr: MemoryRegionHandle, 36 | recv_buffer_ptr: int, 37 | recv_buffer_mr: MemoryRegionHandle, 38 | sync_ptrs: list[int], 39 | send_ptrs: list[int], 40 | recv_ptrs: list[int], 41 | device: int, 42 | imm_base: int, 43 | ranks: 
list[ 44 | tuple[ 45 | DomainAddress, 46 | MemoryRegionDescriptor, 47 | MemoryRegionDescriptor, 48 | ] 49 | ], 50 | transfer_engine: TransferEngine, 51 | worker_cpu: int | None, 52 | ) -> None: ... 53 | def dispatch_send( 54 | self, 55 | num_tokens: int, 56 | x_ptr: int, 57 | x_stride: int, 58 | x_scale_ptr: int | None, 59 | x_scale_stride_elem: int | None, 60 | x_scale_stride_token: int | None, 61 | indices_ptr: int, 62 | indices_stride: int, 63 | weights_ptr: int, 64 | weights_stride: int, 65 | bound_m_ptr: int | None, 66 | stream: int, 67 | ) -> None: ... 68 | def dispatch_recv( 69 | self, 70 | out_num_tokens_ptr: int, 71 | out_x_ptr: int, 72 | out_x_stride: int, 73 | out_x_scale_ptr: int | None, 74 | out_x_scale_stride_elem: int | None, 75 | out_x_scale_stride_token: int | None, 76 | stream: int, 77 | ) -> None: ... 78 | def combine_send( 79 | self, 80 | expert_x_ptr: int, 81 | expert_x_stride: int, 82 | stream: int, 83 | ) -> None: ... 84 | def combine_recv( 85 | self, 86 | num_tokens: int, 87 | num_recv_tokens: int, 88 | expert_y_dtype: torch.dtype, 89 | out_tokens_ptr: int, 90 | out_tokens_stride: int, 91 | indices_ptr: int, 92 | indices_stride: int, 93 | weights_ptr: int, 94 | weights_stride: int, 95 | bound_m_ptr: int | None, 96 | accumulate: bool, 97 | stream: int, 98 | ) -> None: ... 99 | -------------------------------------------------------------------------------- /python/pplx_garden/utils/math.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | 5 | def round_up(n: int, m: int) -> int: 6 | return (n + m - 1) // m * m 7 | 8 | 9 | def ceil_div(n: int, m: int) -> int: 10 | return (n + m - 1) // m 11 | 12 | 13 | def floor_div(n: int, m: int) -> int: 14 | return n // m 15 | 16 | 17 | def stddev(xs: list[float]) -> float: 18 | if len(xs) <= 1: 19 | return 0.0 20 | 21 | n = len(xs) 22 | m = 0.0 23 | s = 0.0 24 | for k, x in enumerate(xs): 25 | old_m = m 26 | m = m + (x - m) / (k + 1) 27 | s = s + (x - m) * (x - old_m) 28 | variance = s / (n - 1) 29 | return math.sqrt(variance) 30 | 31 | 32 | def mean(xs: list[float]) -> float: 33 | if len(xs) == 0: 34 | return 0.0 35 | return sum(xs) / len(xs) 36 | 37 | 38 | def mean_and_stddev(xs: list[float]) -> tuple[float, float]: 39 | return mean(xs), stddev(xs) 40 | 41 | 42 | @dataclass 43 | class Statistics: 44 | mean: float 45 | stddev: float 46 | min: float 47 | p01: float 48 | p25: float 49 | p50: float 50 | p75: float 51 | p95: float 52 | p99: float 53 | max: float 54 | 55 | @classmethod 56 | def create(cls, xs: list[float]) -> "Statistics": 57 | if len(xs) == 0: 58 | return cls( 59 | mean=0.0, 60 | stddev=0.0, 61 | min=0.0, 62 | p01=0.0, 63 | p25=0.0, 64 | p50=0.0, 65 | p75=0.0, 66 | p95=0.0, 67 | p99=0.0, 68 | max=0.0, 69 | ) 70 | 71 | n = len(xs) 72 | sorted_xs = sorted(xs) 73 | mean, stddev = mean_and_stddev(xs) 74 | 75 | def percentile(p: float) -> float: 76 | index = int(n * p) 77 | if n * p == index or index + 1 >= n: 78 | return sorted_xs[index] 79 | return (sorted_xs[index] + sorted_xs[index + 1]) / 2 80 | 81 | return cls( 82 | mean=mean, 83 | stddev=stddev, 84 | min=sorted_xs[0], 85 | p01=percentile(0.01), 86 | p25=percentile(0.25), 87 | p50=percentile(0.5), 88 | p75=percentile(0.75), 89 | p95=percentile(0.95), 90 | p99=percentile(0.99), 91 | max=sorted_xs[-1], 92 | ) 93 | 94 | def __str__(self) -> str: 95 | fields = [ 96 | ("min", self.min), 97 | ("p01", self.p01), 98 | ("p25", self.p25), 99 | ("p50", self.p50), 100 | ("p75", 
self.p75),
101 | ("p95", self.p95),
102 | ("p99", self.p99),
103 | ("max", self.max),
104 | ]
105 |
106 | desc = f"mean={self.mean:.1f}±{self.stddev:.1f} μs"
107 | for name, value in fields:
108 | desc += f", {name}={value:.1f} μs"
109 | return desc
110 | -------------------------------------------------------------------------------- /rust/logging-lib/src/lib.rs: --------------------------------------------------------------------------------
1 | use clap::{Parser, ValueEnum};
2 | use is_terminal::IsTerminal;
3 | use tracing::{Dispatch, dispatcher};
4 | use tracing_log::AsLog;
5 | use tracing_subscriber::EnvFilter;
6 |
7 | #[derive(Debug, Parser)]
8 | pub struct LoggingOpts {
9 | #[clap(long, env = "PPLX_LOG_FORMAT", default_value = "json")]
10 | pub log_format: LogFormat,
11 |
12 | #[clap(long, env = "PPLX_LOG_COLOR", default_value = "auto")]
13 | pub log_color: LogColor,
14 |
15 | /// Additional debug level flags in the RUST_LOG format to configure logging on a
16 | /// per-target basis. If both this and RUST_LOG set a log level for a target,
17 | /// the RUST_LOG setting will take priority.
18 | pub log_directives: Option<String>,
19 | }
20 |
21 | pub fn init(opts: &LoggingOpts) -> Result<(), anyhow::Error> {
22 | let color = match opts.log_color {
23 | // tracing_subscriber::fmt uses stdout:
24 | // https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/index.html
25 | LogColor::Auto => std::io::stdout().is_terminal(),
26 |
27 | LogColor::Always => true,
28 | LogColor::Never => false,
29 | };
30 |
31 | // Get log levels from whatever directives were passed, if any,
32 | // then override with what's in the RUST_LOG env var
33 | let mut log_filter_builder = EnvFilter::builder();
34 | if let Some(directives) = &opts.log_directives {
35 | log_filter_builder =
36 | log_filter_builder.with_default_directive(directives.parse()?);
37 | }
38 |
39 | let log_filter = log_filter_builder.from_env_lossy();
40 | let builder = tracing_subscriber::fmt().with_env_filter(log_filter);
41 |
42 | #[cfg(test)]
43 | let builder = builder.with_test_writer();
44 |
45 | #[cfg(not(test))]
46 | let builder = builder.with_writer(std::io::stderr);
47 |
48 | let dispatch: Dispatch = match opts.log_format {
49 | LogFormat::Text => {
50 | let subscriber = builder.with_ansi(color).finish();
51 | subscriber.into()
52 | }
53 | LogFormat::Json => {
54 | let subscriber = builder.json().finish();
55 | subscriber.into()
56 | }
57 | };
58 | dispatcher::set_global_default(dispatch)?;
59 |
60 | tracing_log::LogTracer::builder()
61 | // Note that we must call this *after* setting the global default
62 | // subscriber, so that we get its max level hint.
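// LogTracer forwards records emitted through the `log` crate into `tracing`;
// capping it at the subscriber's current max level lets `log` skip records
// early that the subscriber would discard anyway.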
63 | .with_max_level(tracing_core::LevelFilter::current().as_log())
64 | .init()?;
65 | Ok(())
66 | }
67 |
68 | #[derive(Debug, Clone, PartialEq, Eq, ValueEnum)]
69 | pub enum LogFormat {
70 | Text,
71 | Json,
72 | }
73 |
74 | #[derive(Debug, Clone, PartialEq, Eq, ValueEnum)]
75 | pub enum LogColor {
76 | Auto,
77 | Always,
78 | Never,
79 | }
80 |
81 | impl std::str::FromStr for LogColor {
82 | type Err = String;
83 |
84 | fn from_str(s: &str) -> Result<Self, Self::Err> {
85 | match s {
86 | "auto" => Ok(LogColor::Auto),
87 | "always" => Ok(LogColor::Always),
88 | "never" => Ok(LogColor::Never),
89 | s => Err(format!(
90 | "{s} is not a valid option, expected `auto`, `always` or `never`"
91 | )),
92 | }
93 | }
94 | }
95 | -------------------------------------------------------------------------------- /fabric-lib/src/efa/efa_devinfo.rs: --------------------------------------------------------------------------------
1 | use std::{
2 | borrow::Cow,
3 | ffi::CStr,
4 | ptr::{NonNull, null, null_mut},
5 | };
6 |
7 | use libfabric_sys::{
8 | FI_ENOMEM, FI_EP_RDM, FI_HMEM, FI_LOCAL_COMM, FI_MR_ALLOCATED, FI_MR_HMEM,
9 | FI_MR_LOCAL, FI_MR_PROV_KEY, FI_MR_VIRT_ADDR, FI_MSG, FI_REMOTE_COMM, FI_RMA,
10 | FI_THREAD_DOMAIN, fi_dupinfo, fi_freeinfo, fi_getinfo, fi_info, make_fi_version,
11 | };
12 |
13 | use crate::{
14 | error::{LibfabricError, Result},
15 | provider::RdmaDomainInfo,
16 | };
17 |
18 | pub struct EfaDomainInfo {
19 | pub fi: NonNull<fi_info>,
20 | }
21 |
22 | unsafe impl Send for EfaDomainInfo {}
23 | unsafe impl Sync for EfaDomainInfo {}
24 |
25 | impl EfaDomainInfo {
26 | pub fn dup(info: NonNull<fi_info>) -> Self {
27 | // Copy fi_info. fi_dupinfo does not copy next.
28 | let fi =
29 | NonNull::new(unsafe { fi_dupinfo(info.as_ptr()) }).expect("fi_dupinfo");
30 | EfaDomainInfo { fi }
31 | }
32 |
33 | pub fn fi(&self) -> NonNull<fi_info> {
34 | self.fi
35 | }
36 | }
37 |
38 | impl RdmaDomainInfo for EfaDomainInfo {
39 | fn name(&self) -> Cow<'_, str> {
40 | unsafe {
41 | CStr::from_ptr((*(*self.fi.as_ptr()).domain_attr).name).to_string_lossy()
42 | }
43 | }
44 |
45 | fn link_speed(&self) -> u64 {
46 | unsafe { (*(*(*self.fi.as_ptr()).nic).link_attr).speed as u64 }
47 | }
48 | }
49 |
50 | impl Clone for EfaDomainInfo {
51 | fn clone(&self) -> Self {
52 | EfaDomainInfo::dup(self.fi)
53 | }
54 | }
55 |
56 | impl Drop for EfaDomainInfo {
57 | fn drop(&mut self) {
58 | unsafe { fi_freeinfo(self.fi.as_ptr()) }
59 | }
60 | }
61 |
62 | pub fn get_efa_domains() -> Result<Vec<EfaDomainInfo>> {
63 | let mut vec = Vec::new();
64 | unsafe {
65 | let mut hints = NonNull::new(fi_dupinfo(null()))
66 | .ok_or_else(|| LibfabricError::new(FI_ENOMEM as i32, "fi_dupinfo"))?;
67 | let h = hints.as_mut();
68 | h.caps =
69 | FI_MSG as u64 | FI_RMA as u64 | FI_HMEM | FI_LOCAL_COMM | FI_REMOTE_COMM;
70 | (*h.ep_attr).type_ = FI_EP_RDM;
71 | (*h.fabric_attr).prov_name = c"efa".as_ptr() as *mut libc::c_char;
72 | (*h.domain_attr).mr_mode = (FI_MR_LOCAL
73 | | FI_MR_HMEM
74 | | FI_MR_VIRT_ADDR
75 | | FI_MR_ALLOCATED
76 | | FI_MR_PROV_KEY) as i32;
77 | (*h.domain_attr).threading = FI_THREAD_DOMAIN;
78 |
79 | let mut info = null_mut();
80 | let ret = fi_getinfo(make_fi_version(1, 22), null(), null(), 0, h, &mut info);
81 |
82 | // Avoid fi_freeinfo freeing prov_name
83 | (*h.fabric_attr).prov_name = null_mut();
84 | fi_freeinfo(h);
85 |
86 | let info =
87 | NonNull::new(info).ok_or_else(|| LibfabricError::new(ret, "fi_getinfo"))?;
88 | let mut fi = info;
89 | loop {
90 | vec.push(EfaDomainInfo::dup(fi));
91 | let Some(next) = NonNull::new((*fi.as_ptr()).next) else {
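// `next` is null here: we have walked the entire fi_info list returned
// by fi_getinfo, and each entry has been duplicated into `vec`.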
break;
93 | };
94 | fi = next;
95 | }
96 |
97 | fi_freeinfo(info.as_ptr());
98 | }
99 | Ok(vec)
100 | }
101 | -------------------------------------------------------------------------------- /rust/cuda-lib/src/rt.rs: --------------------------------------------------------------------------------
1 | use std::{
2 | ffi::{CStr, c_void},
3 | ptr::NonNull,
4 | };
5 |
6 | pub type CudaResult<T> = std::result::Result<T, CudartError>;
7 |
8 | #[derive(Clone, Debug)]
9 | pub struct CudartError {
10 | pub code: u32,
11 | pub context: &'static str,
12 | }
13 |
14 | impl CudartError {
15 | pub fn new(code: u32, context: &'static str) -> Self {
16 | Self { code, context }
17 | }
18 | }
19 |
20 | impl std::fmt::Display for CudartError {
21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 | write!(
23 | f,
24 | "CudartError: code {} ({:?}), context: {}",
25 | self.code,
26 | unsafe { CStr::from_ptr(cudart_sys::cudaGetErrorString(self.code)) },
27 | self.context
28 | )
29 | }
30 | }
31 |
32 | impl std::error::Error for CudartError {}
33 |
34 | pub use cudart_sys::{cudaMemoryTypeDevice, cudaPointerAttributes};
35 | pub fn cudaPointerGetAttributes(
36 | ptr: NonNull<c_void>,
37 | ) -> CudaResult<cudaPointerAttributes> {
38 | let mut attrs = cudaPointerAttributes::default();
39 | let ret =
40 | unsafe { cudart_sys::cudaPointerGetAttributes(&raw mut attrs, ptr.as_ptr()) };
41 | match ret {
42 | 0 => Ok(attrs),
43 | _ => Err(CudartError::new(ret, "cudaPointerGetAttributes")),
44 | }
45 | }
46 |
47 | pub use cudart_sys::cudaDeviceProp;
48 | pub fn cudaGetDeviceProperties(device: i32) -> CudaResult<cudaDeviceProp> {
49 | let mut prop = cudaDeviceProp::default();
50 | let ret = unsafe { cudart_sys::cudaGetDeviceProperties(&raw mut prop, device) };
51 | match ret {
52 | 0 => Ok(prop),
53 | _ => Err(CudartError::new(ret, "cudaGetDeviceProperties")),
54 | }
55 | }
56 |
57 | pub fn cudaGetDeviceCount() -> CudaResult<i32> {
58 | let mut count = 0;
59 | let ret = unsafe { cudart_sys::cudaGetDeviceCount(&raw mut count) };
60 | match ret {
61 | 0 => Ok(count),
62 | _ => Err(CudartError::new(ret, "cudaGetDeviceCount")),
63 | }
64 | }
65 |
66 | pub fn cudaSetDevice(device: i32) -> CudaResult<()> {
67 | let ret = unsafe { cudart_sys::cudaSetDevice(device) };
68 | match ret {
69 | 0 => Ok(()),
70 | _ => Err(CudartError::new(ret, "cudaSetDevice")),
71 | }
72 | }
73 |
74 | pub fn cudaHostAlloc(size: usize, flags: u32) -> CudaResult<NonNull<c_void>> {
75 | let mut ptr = std::ptr::null_mut();
76 | let ret = unsafe { cudart_sys::cudaHostAlloc(&raw mut ptr, size, flags) };
77 | match ret {
78 | 0 => Ok(NonNull::new(ptr).unwrap()),
79 | _ => Err(CudartError::new(ret, "cudaHostAlloc")),
80 | }
81 | }
82 |
83 | pub fn cudaFreeHost(ptr: NonNull<c_void>) -> CudaResult<()> {
84 | let ret = unsafe { cudart_sys::cudaFreeHost(ptr.as_ptr()) };
85 | match ret {
86 | 0 => Ok(()),
87 | _ => Err(CudartError::new(ret, "cudaFreeHost")),
88 | }
89 | }
90 |
91 | pub fn cudaGetNumSMs(device: u8) -> CudaResult<usize> {
92 | let mut numSMs = 0;
93 | let ret = unsafe {
94 | cuda_sys::cuDeviceGetAttribute(
95 | &mut numSMs,
96 | cudart_sys::cudaDevAttrMultiProcessorCount,
97 | device as i32,
98 | )
99 | };
100 | match ret {
101 | 0 => Ok(numSMs as usize),
102 | _ => Err(CudartError::new(ret, "cudaGetNumSMs")),
103 | }
104 | }
105 | -------------------------------------------------------------------------------- /fabric-lib/src/utils/obj_pool.rs: --------------------------------------------------------------------------------
1 | use std::mem::MaybeUninit;
2 | use std::ptr::NonNull;
3 |
4 | /// A growable, chunked object pool
with stable addresses suitable for FFI.
5 | ///
6 | /// - Grows by appending fixed-size chunks (each chunk is a separate Box allocation).
7 | /// - Previously returned pointers remain valid for the lifetime of the pool.
8 | /// - Not thread-safe. Wrap in a Mutex if needed.
9 | /// - Safety: You must call `free_and_drop` to run `T`'s destructor if `T: Drop`.
10 | pub struct ObjectPool<T> {
11 | chunks: Vec<Box<[MaybeUninit<T>]>>,
12 | chunk_size: usize,
13 |
14 | // Index of next unallocated slot within the current (last) chunk.
15 | // Ranges from 0..=chunk_size. If == chunk_size we need a new chunk.
16 | next_in_last: usize,
17 |
18 | // LIFO free list of previously freed slots (as raw pointers).
19 | free_list: Vec<NonNull<MaybeUninit<T>>>,
20 | }
21 |
22 | impl<T> ObjectPool<T> {
23 | /// Create a pool with a given chunk size.
24 | pub fn with_chunk_size(chunk_size: usize) -> Self {
25 | assert!(chunk_size > 0, "chunk_size must be > 0");
26 | Self {
27 | chunks: Vec::new(),
28 | chunk_size,
29 | next_in_last: 0,
30 | free_list: Vec::with_capacity(chunk_size),
31 | }
32 | }
33 |
34 | /// Allocate an **uninitialized** slot. You must initialize it before reading.
35 | ///
36 | /// Safety:
37 | /// - Returned pointer must be written with a valid `T` before any read or `free_and_drop`.
38 | /// - Pointer remains valid until returned to this pool or the pool is dropped.
39 | pub unsafe fn alloc_uninit(&mut self) -> NonNull<MaybeUninit<T>> {
40 | if let Some(p) = self.free_list.pop() {
41 | return p;
42 | }
43 | if self.chunks.is_empty() || self.next_in_last == self.chunk_size {
44 | // New chunk of MaybeUninit<T> (elements stay uninitialized).
45 | let new = Box::<[MaybeUninit<T>]>::new_uninit_slice(self.chunk_size);
46 | let new = unsafe { new.assume_init() };
47 | self.chunks.push(new);
48 | self.next_in_last = 0;
49 | }
50 | let c = self.chunks.len() - 1;
51 | let s = self.next_in_last;
52 | self.next_in_last += 1;
53 | unsafe { NonNull::new_unchecked(self.chunks[c].as_mut_ptr().add(s)) }
54 | }
55 |
56 | /// Deallocate a previously allocated pointer and run its destructor.
57 | ///
58 | /// Safety:
59 | /// - `p` must have been returned by `alloc_uninit` of *this* pool,
60 | /// and not already deallocated.
61 | pub unsafe fn free_and_drop(&mut self, p: NonNull<T>) {
62 | // SAFETY: caller guarantees p originated from this pool and is not freed twice.
63 | unsafe { std::ptr::drop_in_place(p.as_ptr()) };
64 | self.free_list
65 | .push(unsafe { NonNull::new_unchecked(p.as_ptr() as *mut MaybeUninit<T>) });
66 | }
67 |
68 | /// Deallocate a previously allocated pointer without running its destructor.
69 | ///
70 | /// Safety:
71 | /// - `p` must have been returned by `alloc_uninit` of *this* pool,
72 | /// and not already deallocated.
73 | /// - Only use this if `p` was never initialized or T is POD.
74 | #[allow(dead_code)]
75 | pub unsafe fn free_no_drop(&mut self, p: *mut MaybeUninit<T>) {
76 | debug_assert!(!p.is_null());
77 | // SAFETY: caller guarantees p originated from this pool and is not freed twice.
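// Skipping `drop_in_place` means an initialized `T` with a non-trivial
// destructor would leak its resources here; hence the "never initialized
// or POD" restriction documented above.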
78 | self.free_list.push(unsafe { NonNull::new_unchecked(p) }); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /tests/p2p_all_to_all/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | def rand_topk_idx( 8 | num_tokens: int, 9 | num_experts: int, 10 | num_topk: int, 11 | generator: torch.Generator, 12 | device: torch.device, 13 | ) -> torch.Tensor: 14 | scores = torch.randn( 15 | (num_tokens, num_experts), 16 | dtype=torch.float32, 17 | device=device, 18 | generator=generator, 19 | ) 20 | scores = scores.abs() + 1 21 | topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] 22 | return topk_idx.to(torch.uint32) 23 | 24 | 25 | @dataclass 26 | class RankTestData: 27 | indices: torch.Tensor 28 | weights: torch.Tensor 29 | dp_x: torch.Tensor 30 | dp_x_scale: Optional[torch.Tensor] 31 | bound_m: Optional[torch.Tensor] 32 | expected_num_tokens: torch.Tensor 33 | 34 | @classmethod 35 | def rand_indices_and_count( 36 | cls, 37 | num_experts: int, 38 | num_experts_per_token: int, 39 | max_num_tokens: int, 40 | generator: torch.Generator, 41 | device: torch.device, 42 | ) -> tuple[torch.Tensor, torch.Tensor]: 43 | indices = rand_topk_idx( 44 | max_num_tokens, 45 | num_experts, 46 | num_experts_per_token, 47 | generator, 48 | device, 49 | ) 50 | expected_num_tokens = torch.bincount( 51 | indices.flatten().long(), 52 | minlength=num_experts, 53 | ).to(torch.int32) 54 | return indices, expected_num_tokens 55 | 56 | @classmethod 57 | def create( 58 | cls, 59 | *, 60 | num_experts: int, 61 | num_experts_per_token: int, 62 | max_num_tokens: int, 63 | hidden_dim: int, 64 | hidden_dim_scale: Optional[int], 65 | in_dtype: torch.dtype, 66 | scale_dtype: Optional[torch.dtype], 67 | generator: torch.Generator, 68 | device: torch.device, 69 | ) -> "RankTestData": 70 | assert num_experts_per_token <= num_experts 71 | 72 | indices, expected_num_tokens = cls.rand_indices_and_count( 73 | num_experts, num_experts_per_token, max_num_tokens, generator, device 74 | ) 75 | dp_x = torch.randn( 76 | (max_num_tokens, hidden_dim), 77 | device=device, 78 | generator=generator, 79 | ).to(in_dtype) 80 | 81 | dp_x_scale: Optional[torch.Tensor] 82 | if hidden_dim_scale is not None or scale_dtype is not None: 83 | assert hidden_dim_scale is not None 84 | assert scale_dtype is not None 85 | dp_x_scale = torch.randn( 86 | (max_num_tokens, hidden_dim_scale), 87 | device=device, 88 | generator=generator, 89 | ).to(scale_dtype) 90 | else: 91 | dp_x_scale = None 92 | 93 | weights = torch.rand( 94 | (max_num_tokens, num_experts_per_token), 95 | dtype=torch.float32, 96 | device=device, 97 | generator=generator, 98 | ) 99 | weights = weights / torch.sum(weights, dim=-1, keepdim=True) 100 | 101 | return cls( 102 | dp_x=dp_x, 103 | dp_x_scale=dp_x_scale, 104 | indices=indices, 105 | weights=weights, 106 | expected_num_tokens=expected_num_tokens, 107 | bound_m=None, 108 | ) 109 | -------------------------------------------------------------------------------- /fabric-lib/src/transfer_engine_builder.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashSet; 2 | 3 | use anyhow::{Error, Result}; 4 | 5 | use crate::{ 6 | RdmaDomainInfo, provider_dispatch::DomainInfo, topo::detect_topology, 7 | transfer_engine::TransferEngine, worker::Worker, 8 | }; 9 | 10 | struct 
GpuDomainSpec {
11 | cuda_device: u8,
12 | domains: Vec<DomainInfo>,
13 | pin_worker_cpu: u16,
14 | pin_uvm_cpu: u16,
15 | }
16 |
17 | #[derive(Default)]
18 | pub struct TransferEngineBuilder {
19 | gpus: Vec<GpuDomainSpec>,
20 | }
21 |
22 | impl TransferEngineBuilder {
23 | pub fn add_gpu_domains(
24 | &mut self,
25 | cuda_device: u8,
26 | domains: Vec<DomainInfo>,
27 | pin_worker_cpu: u16,
28 | pin_uvm_cpu: u16,
29 | ) {
30 | self.gpus.push(GpuDomainSpec {
31 | cuda_device,
32 | domains,
33 | pin_worker_cpu,
34 | pin_uvm_cpu,
35 | })
36 | }
37 |
38 | pub fn build(&self) -> Result<TransferEngine> {
39 | let system_topo = detect_topology()?;
40 |
41 | // Validate that there are no duplicated GPUs
42 | let num_gpus =
43 | self.gpus.iter().map(|s| s.cuda_device).collect::<HashSet<_>>().len();
44 | if num_gpus != self.gpus.len() {
45 | return Err(Error::msg("Duplicated GPUs in the builder"));
46 | }
47 | if num_gpus == 0 {
48 | return Err(Error::msg("No GPUs in the builder"));
49 | }
50 |
51 | // Validate builder params and prepare workers
52 | let mut workers = Vec::with_capacity(self.gpus.len());
53 | for spec in self.gpus.iter() {
54 | let Some(topo) =
55 | system_topo.iter().find(|t| t.cuda_device == spec.cuda_device)
56 | else {
57 | return Err(Error::msg(format!(
58 | "cuda:{} not found in system topology",
59 | spec.cuda_device
60 | )));
61 | };
62 |
63 | let num_domains =
64 | spec.domains.iter().map(|d| d.name()).collect::<HashSet<_>>().len();
65 | if num_domains != spec.domains.len() {
66 | return Err(Error::msg(format!(
67 | "Duplicated domains in cuda:{}",
68 | spec.cuda_device
69 | )));
70 | }
71 |
72 | for d in spec.domains.iter() {
73 | if !topo.domains.iter().any(|t| t.name() == d.name()) {
74 | return Err(Error::msg(format!(
75 | "Domain {} not found in the topology group of cuda:{}",
76 | d.name(),
77 | spec.cuda_device
78 | )));
79 | }
80 | }
81 |
82 | for cpu in &[spec.pin_worker_cpu, spec.pin_uvm_cpu] {
83 | if !topo.cpus.contains(cpu) {
84 | return Err(Error::msg(format!(
85 | "CPU {} not found in the topology group of cuda:{}",
86 | cpu, spec.cuda_device
87 | )));
88 | }
89 | }
90 |
91 | let domain_list: Vec<_> = spec.domains.to_vec();
92 | let worker = Worker {
93 | domain_list,
94 | pin_worker_cpu: Some(spec.pin_worker_cpu),
95 | pin_uvm_cpu: Some(spec.pin_uvm_cpu),
96 | };
97 | workers.push((spec.cuda_device, worker));
98 | }
99 |
100 | // Create the transfer engine.
101 | Ok(TransferEngine::new(workers)?)
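// Typical wiring of this builder (sketch; `domains_for_gpu0` and the CPU
// ids below are illustrative placeholders, not part of this crate's API):
//   let mut builder = TransferEngineBuilder::default();
//   builder.add_gpu_domains(0, domains_for_gpu0, /*worker cpu*/ 1, /*uvm cpu*/ 2);
//   let engine = builder.build()?;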
102 | } 103 | } 104 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=80", "setuptools-rust>=1.12", "setuptools-scm>=9"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pplx_garden" 7 | description = "Perplexity AI open source garden" 8 | authors = [ 9 | { name = "Lequn Chen", email = "lequn@perplexity.ai" }, 10 | { name = "Nandor Licker", email = "nandor@perplexity.ai" }, 11 | ] 12 | dynamic = ["version"] 13 | dependencies = ["numpy"] 14 | 15 | [tool.setuptools] 16 | package-dir = {"" = "python"} 17 | 18 | [[tool.setuptools-rust.ext-modules]] 19 | target = "pplx_garden._rust" 20 | path = "python-ext/Cargo.toml" 21 | binding = "PyO3" 22 | 23 | [tool.setuptools_scm] 24 | version_file = "python/pplx_garden/_version.py" 25 | 26 | 27 | [tool.ruff] 28 | line-length = 88 29 | 30 | [tool.ruff.lint.isort] 31 | combine-as-imports = true 32 | known-first-party = ["pplx_garden", "tests", "benchmarks"] 33 | 34 | 35 | [tool.ruff.lint] 36 | select = [ 37 | "E", # pycodestyle 38 | "W", # pycodestyle warnings 39 | "F", # Pyflakes 40 | "UP", # pyupgrade 41 | "I", # isort 42 | "SIM", # flake8-simplify 43 | "C4", # flake8-comprehensions 44 | "PT", # flake8-pytest 45 | "PIE", # flake8-pie 46 | "EXE", # flake8-executable 47 | "A", # flake8-builtins 48 | "B", # flake8-bugbear 49 | "ANN", # flake8-annotations 50 | "BLE", # flake8-blind-except 51 | "LOG", # flake8-logging 52 | "G", # flake8-logging-format 53 | "TCH", # flake8-type-checking 54 | "RSE", # flake8-raise 55 | "RET", # flake8-return 56 | "T20", # flake8-print 57 | "ICN", # flake8-import-conventions 58 | "TID", # flake8-tidy-imports 59 | "INP", # flake8-no-pep420 60 | "NPY", # numpy 61 | "FURB", # refurb 62 | "TRY", # tryceratops 63 | "FLY", # flynt, 64 | ] 65 | ignore = [ 66 | "E501", # Line too long 67 | "ANN401", # Allow Any 68 | "TRY003", # Allow long messages 69 | "SIM117", # Allow nested with 70 | "TC006", # Add quotes to type expression in cast() 71 | "C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead 72 | "A005", # Shadowing builtins 73 | "UP045", # Use `X | None` for type annotations 74 | "B905", # `zip()` without an explicit `strict=` parameter 75 | ] 76 | 77 | [tool.pytest.ini_options] 78 | asyncio_default_fixture_loop_scope = "function" 79 | log_cli = true 80 | log_cli_format = "[%(asctime)s.%(msecs)03d] %(process)d %(levelname)-8s %(name)s %(message)s" 81 | log_cli_level = "DEBUG" 82 | log_date_format = "%Y-%m-%d %H:%M:%S" 83 | markers = [ 84 | "ci_2gpu: marks to run on p5-2gpu runner (To run on p5-1gpu runner, use -m 'not (ci_2gpu or ci_4gpu)')", 85 | "ci_4gpu: marks to run on p5-4gpu runner (To run on p5-1gpu runner, use -m 'not (ci_2gpu or ci_4gpu)')", 86 | "cpu_only: marks tests which do not use a GPU", 87 | "fabric: marks tests which require libfabric", 88 | "kernel: marks kernel tests (deselect with -m 'not kernel')", 89 | ] 90 | 91 | [tool.coverage.run] 92 | branch = true 93 | omit = ["tests/**/*"] 94 | concurrency = ["multiprocessing", "thread"] 95 | parallel = true 96 | 97 | [tool.coverage.report] 98 | exclude_also = [ 99 | # Don't complain about missing debug-only code: 100 | "def __repr__", 101 | 102 | # Don't complain if tests don't hit defensive assertion code: 103 | "raise AssertionError", 104 | "raise NotImplementedError", 105 | 106 | # Don't complain if non-runnable code isn't run: 107 | "if 
0:", 108 | "if __name__ == .__main__.:", 109 | 110 | # Don't complain about abstract methods, they aren't run: 111 | "@(abc\\.)?abstractmethod", 112 | ] 113 | -------------------------------------------------------------------------------- /p2p-all-to-all/a2a-kernels/src/core/combine_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "core/memory.cuh" 4 | #include "core/device_utils.cuh" 5 | #include "core/vector.cuh" 6 | 7 | #include 8 | #include 9 | 10 | namespace rose { 11 | 12 | namespace detail { 13 | 14 | __forceinline__ __device__ uint32_t pack_float_2(float a, float b) { 15 | uint32_t value; 16 | asm volatile( 17 | "{ cvt.rn.bf16x2.f32 %0, %1, %2; }" 18 | : "=r"(value) 19 | : "f"(b) 20 | , "f"(a) 21 | ); 22 | return value; 23 | } 24 | 25 | } 26 | 27 | template 28 | struct Arg { 29 | float v[SIZE]; 30 | 31 | __forceinline__ Arg() = default; 32 | 33 | __forceinline__ __device__ Arg(const T *ptr) { 34 | #pragma unroll 35 | for (unsigned i = 0; i < SIZE; ++i) { 36 | v[i] = static_cast(ptr[i]); 37 | } 38 | } 39 | 40 | template 41 | __forceinline__ __device__ void store(U *ptr) const { 42 | #pragma unroll 43 | for (unsigned i = 0; i < SIZE; ++i) { 44 | ptr[i] = static_cast(v[i]); 45 | } 46 | } 47 | 48 | __forceinline__ __device__ void store(__nv_bfloat16 *ptr) const { 49 | st_global_nc_uint4(ptr, make_uint4( 50 | detail::pack_float_2(v[0], v[1]), 51 | detail::pack_float_2(v[2], v[3]), 52 | detail::pack_float_2(v[4], v[5]), 53 | detail::pack_float_2(v[6], v[7]) 54 | )); 55 | } 56 | }; 57 | 58 | template <> 59 | __forceinline__ __device__ Arg<__nv_bfloat16, 8>::Arg(const __nv_bfloat16 *ptr) { 60 | auto from_uint32 = [](uint32_t value) -> float2{ 61 | union { 62 | uint32_t value; 63 | __nv_bfloat162 bvalue; 64 | } temp; 65 | temp.value = value; 66 | return __bfloat1622float2(temp.bvalue); 67 | }; 68 | 69 | uint4 data = ld_global_nc_uint4(ptr); 70 | auto v0 = from_uint32(data.x); 71 | v[0] = v0.x; 72 | v[1] = v0.y; 73 | auto v1 = from_uint32(data.y); 74 | v[2] = v1.x; 75 | v[3] = v1.y; 76 | 77 | auto v2 = from_uint32(data.z); 78 | v[4] = v2.x; 79 | v[5] = v2.y; 80 | auto v3 = from_uint32(data.w); 81 | v[6] = v3.x; 82 | v[7] = v3.y; 83 | }; 84 | 85 | template 86 | struct Acc { 87 | Arg v; 88 | 89 | __forceinline__ __device__ Acc() { 90 | #pragma unroll 91 | for (unsigned i = 0; i < SIZE; ++i) { 92 | v.v[i] = 0.0f; 93 | } 94 | } 95 | 96 | template 97 | __forceinline__ __device__ Acc(const Arg &arg) { 98 | #pragma unroll 99 | for (unsigned i = 0; i < SIZE; ++i) { 100 | v.v[i] = arg.v[i]; 101 | } 102 | } 103 | 104 | __forceinline__ __device__ void store(T *ptr) { 105 | v.store(ptr); 106 | } 107 | 108 | template 109 | __forceinline__ __device__ void add(float weight, const Arg &arg) { 110 | #pragma unroll 111 | for (unsigned i = 0; i < SIZE; ++i) { 112 | v.v[i] += arg.v[i] * weight; 113 | } 114 | } 115 | 116 | template 117 | __forceinline__ __device__ void add(const Arg &arg) { 118 | #pragma unroll 119 | for (unsigned i = 0; i < SIZE; ++i) { 120 | v.v[i] += arg.v[i]; 121 | } 122 | } 123 | }; 124 | 125 | template struct CombineVec { 126 | static constexpr size_t SIZE = VecStorageSize::SIZE; 127 | using DstTy = Arg; 128 | using SrcTy = Arg; 129 | using AccTy = Acc; 130 | }; 131 | 132 | } // namespace rose 133 | -------------------------------------------------------------------------------- /p2p-all-to-all/a2a-kernels/src/a2a/a2a_kernels.h: -------------------------------------------------------------------------------- 1 | 
#pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "torch-lib/src/torch_lib.h" 8 | 9 | namespace a2a_kernels { 10 | using ScalarType = torch_lib::ScalarType; 11 | } // namespace a2a_kernels 12 | 13 | #include "a2a-kernels/src/lib.rs.h" 14 | 15 | namespace a2a_kernels { 16 | 17 | int a2a_dispatch_send( 18 | size_t num_blocks, 19 | size_t hidden_dim, 20 | size_t hidden_dim_scale, 21 | size_t num_experts, 22 | size_t num_experts_per_token, 23 | size_t max_private_tokens, 24 | size_t rank, 25 | size_t dp_size, 26 | size_t node_size, 27 | size_t world_size, 28 | size_t num_tokens, 29 | const int32_t *bound_m_ptr, 30 | const uint8_t *x_ptr, 31 | size_t x_elemsize, 32 | size_t x_stride, 33 | const uint8_t *x_scale_ptr, 34 | size_t x_scale_elemsize, 35 | size_t x_scale_stride_elem, 36 | size_t x_scale_stride_token, 37 | const int32_t *indices, 38 | size_t indices_stride, 39 | const float *weights, 40 | size_t weights_stride, 41 | uint32_t *token_offset, 42 | uint32_t *num_routed, 43 | uint32_t *expert_offsets, 44 | uint8_t *dispatch_route_done, 45 | uint8_t *dispatch_send_done, 46 | uint8_t *tx_ready, 47 | uint8_t *send_buffer, 48 | uint32_t *grid_counter, 49 | uint32_t *sync_counter, 50 | uint32_t **sync_ptrs, 51 | uint8_t **recv_ptrs, 52 | uint64_t stream 53 | ); 54 | 55 | int a2a_dispatch_recv( 56 | size_t num_blocks, 57 | size_t hidden_dim, 58 | size_t hidden_dim_scale, 59 | size_t x_elemsize, 60 | size_t x_scale_elemsize, 61 | size_t num_experts, 62 | size_t rank, 63 | size_t node_size, 64 | size_t world_size, 65 | int32_t *out_num_tokens_ptr, 66 | uint8_t *out_x_ptr, 67 | size_t out_x_stride, 68 | uint8_t *out_x_scale_ptr, 69 | size_t out_x_scale_stride_elem, 70 | size_t out_x_scale_stride_token, 71 | uint32_t *tokens_per_expert, 72 | uint8_t *send_buffer, 73 | uint8_t *recv_buffer, 74 | uint32_t *source_rank, 75 | uint32_t *source_offset, 76 | uint32_t *padded_index, 77 | uint32_t *num_routed, 78 | uint32_t *num_recv_tokens_ptr, 79 | uint8_t *num_recv_tokens_flag, 80 | uint8_t *dispatch_recv_flag, 81 | uint8_t *dispatch_recv_done, 82 | uint32_t *grid_counter, 83 | uint32_t *sync_counter, 84 | uint32_t **sync_ptrs, 85 | uint8_t **send_ptrs, 86 | uint64_t stream 87 | ); 88 | 89 | int a2a_combine_send( 90 | size_t num_blocks, 91 | size_t hidden_dim, 92 | size_t x_elemsize, 93 | size_t rank, 94 | size_t node_size, 95 | size_t dp_size, 96 | const uint8_t *expert_x_ptr, 97 | size_t expert_x_stride, 98 | uint8_t *tx_ready, 99 | uint8_t *send_buffer, 100 | uint8_t *recv_buffer, 101 | uint32_t *source_rank, 102 | uint32_t *combine_send_offset, 103 | uint32_t *padded_index, 104 | uint32_t *num_recv_tokens_ptr, 105 | uint8_t *combine_send_done, 106 | uint32_t *token_counter, 107 | uint32_t *sync_counter, 108 | uint32_t **sync_ptrs, 109 | uint8_t **recv_ptrs, 110 | uint64_t stream 111 | ); 112 | 113 | int a2a_combine_recv( 114 | size_t num_blocks, 115 | size_t hidden_dim, 116 | size_t x_elemsize, 117 | ScalarType in_dtype, 118 | ScalarType out_dtype, 119 | size_t num_experts, 120 | size_t num_experts_per_token, 121 | size_t rank, 122 | size_t node_size, 123 | size_t world_size, 124 | size_t num_tokens, 125 | const int32_t *bound_m_ptr, 126 | const int32_t *indices_ptr, 127 | size_t indices_stride, 128 | const float *weights_ptr, 129 | size_t weights_stride, 130 | uint8_t *out_tokens_ptr, 131 | size_t out_tokens_stride, 132 | bool accumulate, 133 | uint8_t *recv_buffer, 134 | uint32_t *token_offset, 135 | uint32_t *expert_offsets, 136 | uint8_t *combine_recv_flag, 137 | uint8_t 
*combine_recv_done, 138 | uint32_t *sync_counter, 139 | uint32_t **sync_ptrs, 140 | uint64_t stream 141 | ); 142 | 143 | } // namespace a2a_kernels 144 | -------------------------------------------------------------------------------- /python/pplx_garden/distributed/parallel_group.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Iterator 3 | from contextlib import contextmanager 4 | from typing import TypeVar 5 | 6 | import torch 7 | from torch.distributed import ReduceOp, Work 8 | 9 | from pplx_garden.distributed.distributed_ops import Reducer 10 | from pplx_garden.utils import logging_utils 11 | 12 | logger = logging_utils.get_logger(__name__) 13 | 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | class ParallelGroup(ABC): 19 | """Abstract base for parallel configurations.""" 20 | 21 | @property 22 | @abstractmethod 23 | def device(self) -> torch.device: 24 | """Device assigned to the current rank.""" 25 | ... 26 | 27 | @property 28 | @abstractmethod 29 | def rank(self) -> int: 30 | """Current rank within the current group.""" 31 | ... 32 | 33 | @property 34 | @abstractmethod 35 | def global_rank(self) -> int: 36 | """Current rank within the global group.""" 37 | ... 38 | 39 | @property 40 | @abstractmethod 41 | def node_rank(self) -> int: 42 | """The rank of the node.""" 43 | ... 44 | 45 | @property 46 | @abstractmethod 47 | def local_rank(self) -> int: 48 | """The rank within the current node.""" 49 | ... 50 | 51 | @property 52 | @abstractmethod 53 | def size(self) -> int: 54 | """The size of the parallel group.""" 55 | ... 56 | 57 | @property 58 | @abstractmethod 59 | def is_inter_node(self) -> bool: 60 | """Returns true if the group spans multiple nodes.""" 61 | ... 62 | 63 | @abstractmethod 64 | def broadcast_object(self, obj: T | None, root: int) -> T: 65 | """Broadcast an object across the CPU interconnect.""" 66 | ... 67 | 68 | @abstractmethod 69 | def broadcast_cpu_tensor_async(self, tensor: torch.Tensor, root: int) -> Work: 70 | """Broadcast a CPU tensor across the CPU interconnect.""" 71 | ... 72 | 73 | @abstractmethod 74 | def reducer( 75 | self, 76 | shape: torch.Size, 77 | dtype: torch.dtype, 78 | op: ReduceOp.RedOpType = ReduceOp.SUM, 79 | ) -> Reducer: ... 80 | 81 | @abstractmethod 82 | def all_reduce( 83 | self, 84 | x: torch.Tensor, 85 | op: ReduceOp.RedOpType = ReduceOp.SUM, 86 | ) -> torch.Tensor: ... 87 | 88 | @abstractmethod 89 | def all_reduce_cpu_async( 90 | self, 91 | x: torch.Tensor, 92 | op: ReduceOp.RedOpType = ReduceOp.SUM, 93 | ) -> Work: ... 94 | 95 | def all_reduce_cpu( 96 | self, 97 | x: torch.Tensor, 98 | op: ReduceOp.RedOpType = ReduceOp.SUM, 99 | ) -> None: 100 | self.all_reduce_cpu_async(x, op).wait() 101 | 102 | @abstractmethod 103 | def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: ... 104 | 105 | @abstractmethod 106 | def all_gather_object(self, obj: T) -> list[T]: ... 107 | 108 | @abstractmethod 109 | def broadcast(self, tensor: torch.Tensor, root: int) -> torch.Tensor: ... 110 | 111 | @abstractmethod 112 | def all_to_all(self, tensor: torch.Tensor) -> torch.Tensor: ... 113 | 114 | @abstractmethod 115 | def barrier(self) -> None: ... 116 | 117 | @abstractmethod 118 | @contextmanager 119 | def capture(self) -> Iterator[None]: ... 120 | 121 | @abstractmethod 122 | def destroy(self) -> None: ...
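# Note (illustrative summary, derived from the docstrings below): the slicing helpers partition a group into subgroups and return the subgroup containing the current rank; e.g. on an 8-rank group, slice_by_count(2) returns the caller's 4-rank subgroup, while slice_by_lens([2, 6]) supports unequal splits.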
123 | 124 | @abstractmethod 125 | def slice_by_count(self, slice_count: int) -> "ParallelGroup": 126 | """ 127 | Slice the group into equal-sized `slice_count` subgroups. 128 | Return the subgroup that the current rank belongs to. 129 | """ 130 | ... 131 | 132 | @abstractmethod 133 | def slice_by_lens(self, slice_lens: list[int]) -> "ParallelGroup": 134 | """ 135 | Slice the group into subgroups of the given lengths. 136 | Return the subgroup that the current rank belongs to. 137 | Require: sum(slice_lens) == self.size 138 | """ 139 | ... 140 | 141 | @property 142 | def has_nvshmem(self) -> bool: 143 | return False 144 | -------------------------------------------------------------------------------- /fabric-lib/src/host_buffer.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | use std::{ffi::c_void, ptr::NonNull, sync::Arc}; 3 | 4 | use cuda_lib::Device; 5 | use parking_lot::Mutex; 6 | 7 | use crate::{RdmaEngine, api::MemoryRegionHandle, error::Result}; 8 | 9 | pub struct HostBuffer { 10 | index: usize, 11 | ptr: *mut u8, 12 | length: usize, 13 | mr_handle: MemoryRegionHandle, 14 | cache: Arc<Mutex<HostBufferCache>>, 15 | } 16 | 17 | unsafe impl Send for HostBuffer {} 18 | unsafe impl Sync for HostBuffer {} 19 | 20 | impl HostBuffer { 21 | pub fn mr_handle(&self) -> MemoryRegionHandle { 22 | self.mr_handle 23 | } 24 | 25 | pub fn as_nonnull(&self) -> NonNull<c_void> { 26 | unsafe { NonNull::new_unchecked(self.ptr as *mut c_void) } 27 | } 28 | 29 | pub fn as_mut_slice(&mut self) -> &mut [u8] { 30 | unsafe { std::slice::from_raw_parts_mut(self.ptr, self.length) } 31 | } 32 | } 33 | 34 | impl Drop for HostBuffer { 35 | fn drop(&mut self) { 36 | self.cache.lock().free(self.index, self.length); 37 | } 38 | } 39 | 40 | struct HostBufferEntry { 41 | #[allow(clippy::box_collection)] 42 | storage: Box<Vec<u8>>, 43 | mr_handle: MemoryRegionHandle, 44 | } 45 | 46 | impl HostBufferEntry { 47 | fn new(engine: Arc<dyn RdmaEngine>, size: usize) -> Result<Self> { 48 | let mut storage = Box::new(Vec::with_capacity(size)); 49 | unsafe { storage.set_len(size) }; 50 | 51 | // Register the memory region 52 | let buf_base = 53 | unsafe { NonNull::new_unchecked(storage.as_mut_ptr() as *mut c_void) }; 54 | let mr_handle = engine.register_memory_local(buf_base, size, Device::Host)?; 55 | 56 | Ok(Self { storage, mr_handle }) 57 | } 58 | } 59 | 60 | struct HostBufferCache { 61 | buffers: Vec<HostBufferEntry>, 62 | free_bufs: HashMap<usize, HashSet<usize>>, 63 | } 64 | 65 | impl HostBufferCache { 66 | fn free(&mut self, index: usize, length: usize) { 67 | self.free_bufs.entry(length).or_default().insert(index); 68 | } 69 | } 70 | 71 | #[derive(Clone)] 72 | pub struct HostBufferAllocator { 73 | cache: Arc<Mutex<HostBufferCache>>, 74 | engine: Arc<dyn RdmaEngine>, 75 | } 76 | 77 | impl HostBufferAllocator { 78 | pub fn new(engine: Arc<dyn RdmaEngine>) -> Self { 79 | Self { 80 | cache: Arc::new(Mutex::new(HostBufferCache { 81 | buffers: Vec::new(), 82 | free_bufs: HashMap::new(), 83 | })), 84 | engine, 85 | } 86 | } 87 | 88 | pub fn allocate(&self, size: usize) -> Result<HostBuffer> { 89 | // Try to find a free buffer 90 | // NOTE(lequn): I don't expect there to be many free buffers, so 91 | // doing a linear scan for now. If this becomes a problem, we can 92 | // introduce a proper memory allocator.
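// Allocation strategy, in short: reuse any cached free buffer whose registered size is at least the requested size; otherwise fall through below and register a fresh buffer, sized up to the next power of two.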
93 | let mut cache = self.cache.lock(); 94 | { 95 | for (&allocated_size, entries) in cache.free_bufs.iter_mut() { 96 | if allocated_size < size { 97 | continue; 98 | } 99 | match entries.iter().next().cloned() { 100 | None => continue, 101 | Some(index) => { 102 | entries.remove(&index); 103 | let entry = &cache.buffers[index]; 104 | return Ok(HostBuffer { 105 | index, 106 | ptr: entry.storage.as_ptr() as *mut u8, 107 | length: allocated_size, 108 | mr_handle: entry.mr_handle, 109 | cache: self.cache.clone(), 110 | }); 111 | } 112 | } 113 | } 114 | } 115 | 116 | // Otherwise allocate a new buffer. Size up to the next power of two. 117 | let size = size.next_power_of_two(); 118 | let index = cache.buffers.len(); 119 | cache.buffers.push(HostBufferEntry::new(self.engine.clone(), size)?); 120 | let entry = &cache.buffers[index]; 121 | Ok(HostBuffer { 122 | index, 123 | ptr: entry.storage.as_ptr() as *mut u8, 124 | length: size, 125 | mr_handle: entry.mr_handle, 126 | cache: self.cache.clone(), 127 | }) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /p2p-all-to-all/a2a-kernels/src/core/memory.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstdint> 4 | 5 | namespace rose { 6 | 7 | __forceinline__ __device__ void st_volatile_u32(uint32_t *flag_addr, uint32_t flag) { 8 | asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); 9 | } 10 | 11 | __forceinline__ __device__ uint32_t ld_volatile_u32(uint32_t *flag_addr) { 12 | uint32_t flag; 13 | asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); 14 | return flag; 15 | } 16 | 17 | __forceinline__ __device__ uint32_t ld_acquire_u32(uint32_t *flag_addr) { 18 | uint32_t flag; 19 | asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); 20 | return flag; 21 | } 22 | 23 | __forceinline__ __device__ uint8_t ld_mmio_b8(uint8_t *flag_addr) { 24 | uint32_t tmp; 25 | asm volatile( 26 | "{ ld.mmio.relaxed.sys.global.b8 %0, [%1]; }" 27 | : "=r"(tmp) 28 | : "l"(flag_addr) 29 | : 30 | ); 31 | return static_cast<uint8_t>(tmp); 32 | } 33 | 34 | __forceinline__ __device__ void st_mmio_b8(uint8_t *flag_addr, uint8_t flag) { 35 | uint32_t tmp = static_cast<uint32_t>(flag); 36 | asm volatile( 37 | "{ st.mmio.relaxed.sys.global.b8 [%1], %0; }" 38 | : 39 | : "r"(tmp), "l"(flag_addr) 40 | : 41 | ); 42 | } 43 | 44 | __forceinline__ __device__ void st_release_u32(uint32_t *flag_addr, uint32_t flag) { 45 | asm volatile("st.release.sys.global.u32 [%1], %0;" :: "r"(flag), "l"(flag_addr)); 46 | } 47 | 48 | __forceinline__ __device__ void st_relaxed_u32(uint32_t *flag_addr, uint32_t flag) { 49 | asm volatile("st.relaxed.sys.global.u32 [%1], %0;" :: "r"(flag), "l"(flag_addr)); 50 | } 51 | 52 | __forceinline__ __device__ uint32_t add_release_sys_u32(uint32_t *addr, uint32_t val) { 53 | uint32_t flag; 54 | asm volatile("atom.release.sys.global.add.u32 %0, [%1], %2;" : "=r"(flag) : "l"(addr), "r"(val)); 55 | return flag; 56 | } 57 | 58 | __forceinline__ __device__ uint32_t add_release_gpu_u32(uint32_t *addr, uint32_t val) { 59 | uint32_t flag; 60 | asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(flag) : "l"(addr), "r"(val)); 61 | return flag; 62 | } 63 | 64 | __forceinline__ __device__ void fence_acq_rel_gpu() { 65 | asm volatile("{ fence.acq_rel.gpu; }":: : "memory"); 66 | } 67 | 68 | __forceinline__ __device__ void fence_acquire_gpu() { 69 | asm volatile("{ fence.acquire.gpu; }":: : 
"memory"); 70 | } 71 | 72 | __forceinline__ __device__ void fence_release_gpu() { 73 | asm volatile("{ fence.release.gpu; }":: : "memory"); 74 | } 75 | 76 | __forceinline__ __device__ void fence_acquire_system() { 77 | asm volatile("{ fence.acquire.sys; }":: : "memory"); 78 | } 79 | 80 | __forceinline__ __device__ void fence_release_system() { 81 | asm volatile("{ fence.release.sys; }":: : "memory"); 82 | } 83 | 84 | __forceinline__ __device__ uint4 ld_global_uint4(const void *ptr) { 85 | uint4 v; 86 | asm volatile( 87 | "{ ld.global.v4.u32 {%0, %1, %2, %3}, [%4]; }" 88 | : "=r"(v.x) 89 | , "=r"(v.y) 90 | , "=r"(v.z) 91 | , "=r"(v.w) 92 | : "l"(ptr) 93 | ); 94 | return v; 95 | } 96 | 97 | __forceinline__ __device__ void st_global_uint4(void *ptr, uint4 v) { 98 | asm volatile( 99 | "{ st.global.v4.u32 [%0], {%1, %2, %3, %4}; }" 100 | : 101 | : "l"(ptr) 102 | , "r"(v.x) 103 | , "r"(v.y) 104 | , "r"(v.z) 105 | , "r"(v.w) 106 | ); 107 | } 108 | 109 | __forceinline__ __device__ uint4 ld_global_nc_uint4(const void *ptr) { 110 | uint4 v; 111 | asm volatile( 112 | "{ ld.global.nc.L1::no_allocate.L2::256B.v4.u32 {%0, %1, %2, %3}, [%4]; }" 113 | : "=r"(v.x) 114 | , "=r"(v.y) 115 | , "=r"(v.z) 116 | , "=r"(v.w) 117 | : "l"(ptr) 118 | ); 119 | return v; 120 | } 121 | 122 | __forceinline__ __device__ void st_global_nc_uint4(void *ptr, uint4 v) { 123 | asm volatile( 124 | "{ st.global.L1::no_allocate.v4.u32 [%0], {%1, %2, %3, %4}; }" 125 | : 126 | : "l"(ptr) 127 | , "r"(v.x) 128 | , "r"(v.y) 129 | , "r"(v.z) 130 | , "r"(v.w) 131 | ); 132 | } 133 | 134 | __forceinline__ __device__ uint4 ld_shared_uint4(const void *ptr) { 135 | uint4 v; 136 | asm volatile( 137 | "{ ld.shared.v4.u32 {%0, %1, %2, %3}, [%4]; }" 138 | : "=r"(v.x) 139 | , "=r"(v.y) 140 | , "=r"(v.z) 141 | , "=r"(v.w) 142 | : "l"(ptr) 143 | ); 144 | return v; 145 | } 146 | 147 | __forceinline__ __device__ void st_shared_uint4(void *ptr, uint4 v) { 148 | asm volatile( 149 | "{ st.shared.v4.u32 [%0], {%1, %2, %3, %4}; }" 150 | : 151 | : "l"(ptr) 152 | , "r"(v.x) 153 | , "r"(v.y) 154 | , "r"(v.z) 155 | , "r"(v.w) 156 | ); 157 | } 158 | 159 | } // namespace rose 160 | -------------------------------------------------------------------------------- /rust/cuda-lib/src/mem.rs: -------------------------------------------------------------------------------- 1 | use std::{ffi::c_void, ptr::NonNull}; 2 | 3 | use cudart_sys::{cudaHostAllocMapped, cudaHostAllocPortable, cudaMemAttachGlobal}; 4 | use libc::memset; 5 | 6 | use crate::rt::CudartError; 7 | use crate::rt::{cudaFreeHost, cudaHostAlloc}; 8 | 9 | /// Owned Cuda memory. It will be freed when dropped. 10 | pub struct CudaDeviceMemory { 11 | ptr: NonNull, 12 | size: usize, 13 | } 14 | 15 | impl CudaDeviceMemory { 16 | /// Allocate a device-only CUDA buffer. 17 | pub fn device(size: usize) -> Result { 18 | let mut ptr = std::ptr::null_mut(); 19 | let ret = unsafe { cudart_sys::cudaMalloc(&raw mut ptr, size) }; 20 | let ptr = 21 | NonNull::new(ptr).ok_or_else(|| CudartError::new(ret, "cudaMalloc"))?; 22 | Ok(Self { ptr, size }) 23 | } 24 | 25 | /// Allocate a CUDA buffer visible to both host and device. 26 | pub fn alloc(size: usize) -> Result { 27 | let mut ptr = std::ptr::null_mut(); 28 | let ret = unsafe { 29 | cudart_sys::cudaMallocManaged(&raw mut ptr, size, cudaMemAttachGlobal) 30 | }; 31 | let ptr = NonNull::new(ptr) 32 | .ok_or_else(|| CudartError::new(ret, "cudaMallocManaged"))?; 33 | Ok(Self { ptr, size }) 34 | } 35 | 36 | /// Create a device buffer, initialized from some host values. 
37 | pub fn from_vec<T>(data: &[T]) -> Result<Self, CudartError> { 38 | let size = std::mem::size_of_val(data); 39 | let mem = Self::device(size)?; 40 | unsafe { 41 | cudart_sys::cudaMemcpy( 42 | mem.ptr.as_ptr(), 43 | data.as_ptr() as *const c_void, 44 | size, 45 | cudart_sys::cudaMemcpyHostToDevice, 46 | ); 47 | } 48 | Ok(mem) 49 | } 50 | 51 | pub fn ptr(&self) -> NonNull<c_void> { 52 | self.ptr 53 | } 54 | 55 | pub fn size(&self) -> usize { 56 | self.size 57 | } 58 | 59 | pub fn leak(self) -> NonNull<c_void> { 60 | let ptr = self.ptr; 61 | std::mem::forget(self); 62 | ptr 63 | } 64 | 65 | pub fn zero(&self) { 66 | unsafe { 67 | cudart_sys::cudaMemset(self.ptr.as_ptr(), 0, self.size); 68 | } 69 | } 70 | 71 | pub fn get_ptr<T>(&self) -> *const T { 72 | self.ptr.as_ptr() as *const T 73 | } 74 | 75 | pub fn get_mut_ptr<T>(&mut self) -> *mut T { 76 | self.ptr.as_ptr() as *mut T 77 | } 78 | 79 | pub fn as_mut_slice<T>(&mut self) -> &mut [T] { 80 | unsafe { 81 | std::slice::from_raw_parts_mut( 82 | self.ptr.as_ptr() as *mut T, 83 | self.size / std::mem::size_of::<T>(), 84 | ) 85 | } 86 | } 87 | } 88 | 89 | impl Drop for CudaDeviceMemory { 90 | fn drop(&mut self) { 91 | unsafe { cudart_sys::cudaFree(self.ptr.as_ptr()) }; 92 | } 93 | } 94 | 95 | unsafe impl Send for CudaDeviceMemory {} 96 | unsafe impl Sync for CudaDeviceMemory {} 97 | 98 | pub struct CudaHostMemory { 99 | pub ptr: NonNull<c_void>, 100 | pub size: usize, 101 | } 102 | 103 | impl CudaHostMemory { 104 | pub fn alloc(size: usize) -> Result<Self, CudartError> { 105 | let ptr = cudaHostAlloc(size, cudaHostAllocPortable | cudaHostAllocMapped)?; 106 | unsafe { memset(ptr.as_ptr(), 0, size) }; 107 | Ok(CudaHostMemory { ptr, size }) 108 | } 109 | 110 | pub fn size(&self) -> usize { 111 | self.size 112 | } 113 | 114 | pub fn get_ptr<T>(&self) -> *const T { 115 | self.ptr.as_ptr() as *const T 116 | } 117 | 118 | pub fn get_mut_ptr<T>(&self) -> *mut T { 119 | self.ptr.as_ptr() as *mut T 120 | } 121 | 122 | pub fn get_ref(&self, index: usize) -> &u64 { 123 | unsafe { &*((self.ptr.as_ptr() as *const u64).add(index)) } 124 | } 125 | 126 | pub fn get_mut(&mut self, index: usize) -> &mut u64 { 127 | unsafe { &mut *((self.ptr.as_ptr() as *mut u64).add(index)) } 128 | } 129 | 130 | pub fn zero(&self) { 131 | unsafe { 132 | let slice = 133 | std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut u8, self.size); 134 | slice.fill(0); 135 | } 136 | } 137 | 138 | pub fn as_slice<T>(&self) -> &[T] { 139 | let elemsize = std::mem::size_of::<T>(); 140 | assert!(self.size.is_multiple_of(elemsize)); 141 | unsafe { std::slice::from_raw_parts(self.get_ptr::<T>(), self.size / elemsize) } 142 | } 143 | } 144 | 145 | impl Drop for CudaHostMemory { 146 | fn drop(&mut self) { 147 | if let Err(error) = cudaFreeHost(self.ptr) { 148 | panic!("Failed to free UVM memory: {}", error); 149 | } 150 | } 151 | } 152 | 153 | unsafe impl Send for CudaHostMemory {} 154 | unsafe impl Sync for CudaHostMemory {} 155 | -------------------------------------------------------------------------------- /p2p-all-to-all/a2a-kernels/src/lib.rs: -------------------------------------------------------------------------------- 1 | #[cxx::bridge(namespace = "a2a_kernels")] 2 | #[allow(clippy::missing_safety_doc)] 3 | #[allow(clippy::too_many_arguments)] 4 | mod ffi { 5 | unsafe extern "C++" { 6 | include!("a2a/a2a_kernels.h"); 7 | 8 | #[namespace = "torch_lib"] 9 | type ScalarType = torch_lib::ScalarType; 10 | 11 | unsafe fn a2a_dispatch_send( 12 | num_blocks: usize, 13 | hidden_dim: usize, 14 | hidden_dim_scale: usize, 15 | num_experts: usize, 16 | num_experts_per_token: usize, 
17 | max_private_tokens: usize, 18 | rank: usize, 19 | dp_size: usize, 20 | node_size: usize, 21 | world_size: usize, 22 | num_tokens: usize, 23 | bound_m_ptr: *const i32, 24 | x_ptr: *const u8, 25 | x_elemsize: usize, 26 | x_stride: usize, 27 | x_scale_ptr: *const u8, 28 | x_scale_elemsize: usize, 29 | x_scale_stride_elem: usize, 30 | x_scale_stride_token: usize, 31 | indices: *const i32, 32 | indices_stride: usize, 33 | weights: *const f32, 34 | weights_stride: usize, 35 | token_offset: *mut u32, 36 | num_routed: *mut u32, 37 | expert_offsets: *mut u32, 38 | dispatch_route_done: *mut u8, 39 | dispatch_send_done: *mut u8, 40 | tx_ready: *mut u8, 41 | send_buffer: *mut u8, 42 | grid_counter: *mut u32, 43 | sync_counter: *mut u32, 44 | sync_ptrs: *mut *mut u32, 45 | recv_ptrs: *mut *mut u8, 46 | stream: u64, 47 | ) -> i32; 48 | 49 | unsafe fn a2a_dispatch_recv( 50 | num_blocks: usize, 51 | hidden_dim: usize, 52 | hidden_dim_scale: usize, 53 | x_elemsize: usize, 54 | x_scale_elemsize: usize, 55 | num_experts: usize, 56 | rank: usize, 57 | node_size: usize, 58 | world_size: usize, 59 | out_num_tokens_ptr: *mut i32, 60 | out_x_ptr: *mut u8, 61 | out_x_stride: usize, 62 | out_x_scale_ptr: *mut u8, 63 | out_x_scale_stride_elem: usize, 64 | out_x_scale_stride_token: usize, 65 | tokens_per_expert: *mut u32, 66 | send_buffer: *mut u8, 67 | recv_buffer: *mut u8, 68 | source_rank: *mut u32, 69 | source_offset: *mut u32, 70 | padded_index: *mut u32, 71 | num_routed: *mut u32, 72 | num_recv_tokens_ptr: *mut u32, 73 | num_recv_tokens_flag: *mut u8, 74 | dispatch_recv_flag: *mut u8, 75 | dispatch_recv_done: *mut u8, 76 | grid_counter: *mut u32, 77 | sync_counter: *mut u32, 78 | sync_ptrs: *mut *mut u32, 79 | send_ptrs: *mut *mut u8, 80 | stream: u64, 81 | ) -> i32; 82 | 83 | unsafe fn a2a_combine_send( 84 | num_blocks: usize, 85 | hidden_dim: usize, 86 | x_elemsize: usize, 87 | rank: usize, 88 | node_size: usize, 89 | dp_size: usize, 90 | expert_x_ptr: *const u8, 91 | expert_x_stride: usize, 92 | tx_ready: *mut u8, 93 | send_buffer: *mut u8, 94 | recv_buffer: *mut u8, 95 | source_rank: *mut u32, 96 | combine_send_offset: *mut u32, 97 | padded_index: *mut u32, 98 | num_recv_tokens_ptr: *mut u32, 99 | combine_send_done: *mut u8, 100 | token_counter: *mut u32, 101 | sync_counter: *mut u32, 102 | sync_ptrs: *mut *mut u32, 103 | recv_ptrs: *mut *mut u8, 104 | stream: u64, 105 | ) -> i32; 106 | 107 | unsafe fn a2a_combine_recv( 108 | num_blocks: usize, 109 | hidden_dim: usize, 110 | x_elemsize: usize, 111 | in_dtype: ScalarType, 112 | out_dtype: ScalarType, 113 | num_experts: usize, 114 | num_experts_per_token: usize, 115 | rank: usize, 116 | node_size: usize, 117 | world_size: usize, 118 | num_tokens: usize, 119 | bound_m_ptr: *const i32, 120 | indices_ptr: *const i32, 121 | indices_stride: usize, 122 | weights_ptr: *const f32, 123 | weights_stride: usize, 124 | out_tokens_ptr: *mut u8, 125 | out_tokens_stride: usize, 126 | accumulate: bool, 127 | recv_buffer: *mut u8, 128 | token_offset: *mut u32, 129 | expert_offsets: *mut u32, 130 | combine_recv_flag: *mut u8, 131 | combine_recv_done: *mut u8, 132 | sync_counter: *mut u32, 133 | sync_ptrs: *mut *mut u32, 134 | stream: u64, 135 | ) -> i32; 136 | } 137 | } 138 | 139 | pub use ffi::{ 140 | a2a_combine_recv, a2a_combine_send, a2a_dispatch_recv, a2a_dispatch_send, 141 | }; 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
pplx-garden 2 | =========== 3 | 4 | Perplexity AI open source garden for inference technology 5 | 6 | ## Research Paper 7 | 8 | [RDMA Point-to-Point Communication for LLM Systems](https://arxiv.org/abs/2510.27656) 9 | 10 | ## P2P MoE dispatch/combine kernel 11 | 12 | * Support both NVIDIA ConnectX-7 and AWS EFA (potentially other RDMA NICs as well) 13 | * Use NVLink for intra-node data transfer and RDMA for inter-node 14 | * Optimize for decode, while also supporting prefill 15 | * Split send and recv stages for both dispatch and combine, allowing micro-batching 16 | * SM-free during RDMA transfer 17 | * Support CUDA Graphs 18 | 19 | ## RDMA TransferEngine library 20 | 21 | * Support both NVIDIA ConnectX-7 and AWS EFA (potentially other RDMA NICs as well) 22 | * Support aggregation of multiple NICs per GPU 23 | * Support reliable unordered transport protocol 24 | 25 | # System requirements 26 | 27 | * (Recommended) Linux Kernel 5.12 or higher (for DMA-BUF support) 28 | * CUDA 12.8 or higher 29 | * libfabric 30 | * libibverbs 31 | * GDRCopy 32 | * `SYS_PTRACE` and `SYS_ADMIN` Linux capabilities for `pidfd_getfd`. You can obtain these by running as root, with sudo, or inside docker with `--cap-add=SYS_PTRACE --cap-add=SYS_ADMIN`. 33 | * RDMA network with GPUDirect RDMA support. Each GPU should have at least one dedicated RDMA NIC. 34 | 35 | # Docker dev image 36 | 37 | We provide a docker image for development convenience. You can build it with the following command: 38 | 39 | ```bash 40 | docker build -t pplx-garden-dev - < docker/dev.Dockerfile 41 | ``` 42 | 43 | Run the container with the following command: 44 | 45 | ```bash 46 | ./scripts/run-docker.sh 47 | ``` 48 | 49 | # Run fabric-debug 50 | 51 | This is the benchmark for our network library. 52 | 53 | Build the benchmark binary: 54 | 55 | ```bash 56 | cd /app 57 | cargo build --release --bin fabric-debug 58 | ``` 59 | 60 | Usage: 61 | 62 | * Server: `fabric-debug [comma-separated GPUs] [NICs per GPU]` 63 | * Client: `fabric-debug [comma-separated GPUs] [NICs per GPU] [server address]` where the server address is the one printed by the server. 64 | 65 | 66 | Example: 67 | 68 | ``` 69 | server$ /app/target/release/fabric-debug 0,1,2,3,4,5,6,7 2 70 | client$ /app/target/release/fabric-debug 0,1,2,3,4,5,6,7 2 fe80xxxx 71 | ``` 72 | 73 | # Build and Install Python Wheel 74 | 75 | ```bash 76 | cd /app 77 | export TORCH_CMAKE_PREFIX_PATH=$(python3 -c "import torch; print(torch.utils.cmake_prefix_path)") 78 | python3 -m build --wheel 79 | python3 -m pip install /app/dist/*.whl 80 | ``` 81 | 82 | # Run All-to-All Benchmark 83 | 84 | ```bash 85 | # Environment variables 86 | NUM_NODES=... 87 | NODE_RANK=... # [0, NUM_NODES) 88 | MASTER_IP=... 89 | 90 | # Run on all nodes 91 | cd /app 92 | python3 -m benchmarks.bench_all_to_all \ 93 | --world-size $((NUM_NODES * 8)) --nets-per-gpu 2 --init-method=tcp://$MASTER_IP:29500 \ 94 | --node-rank=$NODE_RANK --nvlink=8 95 | ``` 96 | 97 | Notes: 98 | 99 | * Remove the `--nvlink` flag if you want to use RDMA only. 100 | * Set `--nets-per-gpu` according to the VM instance type.
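For a quick single-node smoke test, an illustrative invocation using only the flags documented above (assumes one machine with 8 GPUs and 2 NICs per GPU) is:

```bash
cd /app
python3 -m benchmarks.bench_all_to_all \
    --world-size 8 --nets-per-gpu 2 --init-method=tcp://127.0.0.1:29500 \
    --node-rank=0 --nvlink=8
```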
101 | 102 | # All-to-All Performance Results 103 | 104 | Decode (128 tokens), dispatch (left columns) and combine (right columns): 105 | 106 | | | Dispatch: pplx-EFA | Dispatch: pplx-CX7 | Dispatch: DeepEP-CX7 | | Combine: pplx-EFA | Combine: pplx-CX7 | Combine: DeepEP-CX7 | 107 | |------|---------:|---------:|-----------:|---|---------:|---------:|-----------:| 108 | | EP64 | 266.7 μs | 187.5 μs | 177.9 μs | | 391.2 μs | 309.1 μs | 325.0 μs | 109 | | EP32 | 229.1 μs | 153.9 μs | 159.1 μs | | 335.0 μs | 266.3 μs | 285.0 μs | 110 | | EP16 | 214.8 μs | 110.2 μs | 123.9 μs | | 241.5 μs | 185.5 μs | 203.0 μs | 111 | | EP8 | 49.7 μs | 50.5 μs | 42.6 μs | | 64.2 μs | 65.3 μs | 72.0 μs | 112 | 113 | 114 | Prefill (4096 tokens), dispatch (left columns) and combine (right columns): 115 | 116 | | | Dispatch: pplx-EFA | Dispatch: pplx-CX7 | Dispatch: DeepEP-CX7 | | Combine: pplx-EFA | Combine: pplx-CX7 | Combine: DeepEP-CX7 | 117 | |------|----------:|----------:|-----------:|---|----------:|----------:|-----------:| 118 | | EP64 | 5334.3 μs | 4665.2 μs | 5071.6 μs | | 9779.3 μs | 8771.1 μs | 5922.7 μs | 119 | | EP32 | 4619.0 μs | 4011.8 μs | 3680.2 μs | | 8271.5 μs | 7526.8 μs | 3565.4 μs | 120 | | EP16 | 3196.7 μs | 2734.8 μs | 2481.9 μs | | 5379.1 μs | 1062.2 μs | 1863.9 μs | 121 | | EP8 | 1052.4 μs | 5071.1 μs | 1810.3 μs | | 1396.7 μs | 1405.1 μs | 962.9 μs | 122 | 123 | 124 | # Directory Structure 125 | 126 | * `fabric-lib/`: RDMA TransferEngine library 127 | * `p2p-all-to-all/`: P2P MoE All-to-All implementation 128 | * `python-ext/`: Python extension module from Rust code 129 | * `python/pplx_garden/`: Python code for the `pplx_garden` package 130 | * `rust/`: Rust utility libraries 131 | 132 | # Acknowledgments 133 | 134 | Our RDMA library is inspired by [MoonCake](https://www.usenix.org/conference/fast25/presentation/qin). 135 | Our MoE kernel is inspired by [DeepEP](https://github.com/deepseek-ai/DeepEP).
136 | 137 | # Citation 138 | 139 | If you find this work useful, please cite: 140 | 141 | ``` 142 | @misc{pplx-rdma-p2p, 143 | title={RDMA Point-to-Point Communication for LLM Systems}, 144 | author={Nandor Licker and Kevin Hu and Vladimir Zaytsev and Lequn Chen}, 145 | year={2025}, 146 | eprint={2510.27656}, 147 | archivePrefix={arXiv}, 148 | primaryClass={cs.DC}, 149 | url={https://arxiv.org/abs/2510.27656}, 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /p2p-all-to-all/a2a-kernels/src/core/vector.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "core/memory.cuh" 4 | 5 | #include <cuda_bf16.h> 6 | #include <cuda_fp16.h> 7 | #include <cstdint> 8 | 9 | namespace rose { 10 | 11 | template <typename T> 12 | struct VecStorageSize; 13 | 14 | template <> 15 | struct VecStorageSize<float> { 16 | static constexpr int SIZE = 4; 17 | }; 18 | 19 | template <> 20 | struct VecStorageSize<half> { 21 | static constexpr int SIZE = 8; 22 | }; 23 | 24 | template <> 25 | struct VecStorageSize<__nv_bfloat16> { 26 | static constexpr int SIZE = 8; 27 | }; 28 | 29 | template <typename T> 30 | struct Vec { 31 | static constexpr int SIZE = VecStorageSize<T>::SIZE; 32 | float value[SIZE]; 33 | 34 | __forceinline__ __device__ Vec() {} 35 | 36 | __forceinline__ __device__ explicit Vec(const float (&vals)[SIZE]) { 37 | #pragma unroll 38 | for (int i = 0; i < SIZE; ++i) { 39 | value[i] = vals[i]; 40 | } 41 | } 42 | 43 | __forceinline__ __device__ static Vec load(const T *ptr); 44 | 45 | __forceinline__ __device__ void store(T *ptr) const; 46 | }; 47 | 48 | namespace detail { 49 | 50 | struct HalfPairConverter { 51 | __forceinline__ __device__ static half2 apply(float x, float y) { 52 | return __floats2half2_rn(x, y); 53 | } 54 | }; 55 | 56 | struct BFloatPairConverter { 57 | __forceinline__ __device__ static __nv_bfloat162 apply(float x, float y) { 58 | return __floats2bfloat162_rn(x, y); 59 | } 60 | }; 61 | 62 | template <typename PairType, typename Converter, int NUM_PAIRS> 63 | __forceinline__ __device__ uint4 pack_float_pairs(const float *values) { 64 | uint4 data; 65 | uint32_t *raw = reinterpret_cast<uint32_t *>(&data); 66 | union { 67 | uint32_t raw; 68 | PairType pair; 69 | } convert; 70 | 71 | #pragma unroll 72 | for (int i = 0; i < NUM_PAIRS; ++i) { 73 | convert.pair = Converter::apply(values[2 * i], values[2 * i + 1]); 74 | raw[i] = convert.raw; 75 | } 76 | 77 | return data; 78 | } 79 | 80 | } // namespace detail 81 | 82 | template <> 83 | __forceinline__ __device__ Vec<float> Vec<float>::load(const float *ptr) { 84 | Vec vec; 85 | float x, y, z, w; 86 | asm volatile( 87 | "{ ld.global.v4.f32 {%0, %1, %2, %3}, [%4]; }" 88 | : "=f"(x), "=f"(y), "=f"(z), "=f"(w) 89 | : "l"(ptr) 90 | ); 91 | vec.value[0] = x; 92 | vec.value[1] = y; 93 | vec.value[2] = z; 94 | vec.value[3] = w; 95 | return vec; 96 | } 97 | 98 | template <> 99 | __forceinline__ __device__ void Vec<float>::store(float *ptr) const { 100 | asm volatile( 101 | "{ st.global.v4.f32 [%0], {%1, %2, %3, %4}; }" 102 | : 103 | : "l"(ptr) 104 | , "f"(value[0]) 105 | , "f"(value[1]) 106 | , "f"(value[2]) 107 | , "f"(value[3]) 108 | ); 109 | } 110 | 111 | template <> 112 | __forceinline__ __device__ Vec<half> Vec<half>::load(const half *ptr) { 113 | Vec vec; 114 | uint4 data = ld_global_uint4(ptr); 115 | 116 | union { 117 | uint32_t raw; 118 | half2 h; 119 | } convert; 120 | 121 | const uint32_t *raw = reinterpret_cast<const uint32_t *>(&data); 122 | 123 | #pragma unroll 124 | for (int i = 0; i < SIZE / 2; ++i) { 125 | convert.raw = raw[i]; 126 | float2 f = __half22float2(convert.h); 127 | vec.value[2 * i] = f.x; 128 | 
vec.value[2 * i + 1] = f.y; 129 | } 130 | 131 | return vec; 132 | } 133 | 134 | template <> 135 | __forceinline__ __device__ void Vec<half>::store(half *ptr) const { 136 | uint4 data = detail::pack_float_pairs<half2, detail::HalfPairConverter, SIZE / 2>(value); 137 | st_global_uint4(ptr, data); 138 | } 139 | 140 | template <> 141 | __forceinline__ __device__ Vec<__nv_bfloat16> Vec<__nv_bfloat16>::load(const __nv_bfloat16 *ptr) { 142 | Vec<__nv_bfloat16> vec; 143 | uint4 data = ld_global_uint4(ptr); 144 | 145 | union { 146 | uint32_t raw; 147 | __nv_bfloat162 b; 148 | } convert; 149 | 150 | const uint32_t *raw = reinterpret_cast<const uint32_t *>(&data); 151 | 152 | #pragma unroll 153 | for (int i = 0; i < SIZE / 2; ++i) { 154 | convert.raw = raw[i]; 155 | float2 f = __bfloat1622float2(convert.b); 156 | vec.value[2 * i] = f.x; 157 | vec.value[2 * i + 1] = f.y; 158 | } 159 | 160 | return vec; 161 | } 162 | 163 | template <> 164 | __forceinline__ __device__ void Vec<__nv_bfloat16>::store(__nv_bfloat16 *ptr) const { 165 | uint4 data = detail::pack_float_pairs<__nv_bfloat162, detail::BFloatPairConverter, SIZE / 2>(value); 166 | st_global_uint4(ptr, data); 167 | } 168 | 169 | template <typename T> 170 | struct FloatConvert; 171 | 172 | template <> 173 | struct FloatConvert<float> { 174 | __forceinline__ __device__ static float apply(float value) { 175 | return value; 176 | } 177 | }; 178 | 179 | template <> 180 | struct FloatConvert<half> { 181 | __forceinline__ __device__ static half apply(float value) { 182 | return __float2half_rn(value); 183 | } 184 | }; 185 | 186 | template <> 187 | struct FloatConvert<__nv_bfloat16> { 188 | __forceinline__ __device__ static __nv_bfloat16 apply(float value) { 189 | return __float2bfloat16_rn(value); 190 | } 191 | }; 192 | 193 | } // namespace rose 194 | -------------------------------------------------------------------------------- /rust/torch-lib/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::{ 3 | ffi::{c_char, c_void}, 4 | ptr::NonNull, 5 | }; 6 | 7 | use cuda_lib::{CudaDeviceId, Device}; 8 | use cxx::UniquePtr; 9 | use pyo3::{ 10 | Borrowed, Bound, FromPyObject, IntoPyObject, PyAny, PyErr, PyResult, Python, 11 | exceptions::PyValueError, 12 | }; 13 | 14 | #[cxx::bridge(namespace = "torch_lib")] 15 | mod ffi { 16 | extern "Rust" { 17 | type FromBlobContext; 18 | } 19 | 20 | enum DeviceType { 21 | Cpu, 22 | Cuda, 23 | } 24 | 25 | #[allow(dead_code)] 26 | struct Device { 27 | device_type: DeviceType, 28 | device_index: u8, 29 | } 30 | 31 | #[allow(non_camel_case_types)] 32 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 33 | pub enum ScalarType { 34 | BOOL, 35 | I8, 36 | U8, 37 | I16, 38 | U16, 39 | I32, 40 | U32, 41 | I64, 42 | U64, 43 | F8_E4M3, 44 | F8_E5M2, 45 | F16, 46 | BF16, 47 | F32, 48 | F64, 49 | } 50 | 51 | unsafe extern "C++" { 52 | include!("torch-lib/src/torch_lib.h"); 53 | unsafe fn from_blob( 54 | data_ptr: *mut c_char, 55 | shape: &[i64], 56 | dtype: ScalarType, 57 | device: Device, 58 | context: Box<FromBlobContext>, 59 | ) -> *mut c_char; 60 | 61 | unsafe fn torch_to_scalar_type(obj: *mut c_char) -> Result<ScalarType>; 62 | 63 | unsafe fn scalar_to_torch_type(scalar_type: ScalarType) -> Result<*mut c_char>; 64 | 65 | unsafe fn current_stream() -> u64; 66 | 67 | type TorchProfilerGuard; 68 | unsafe fn profile_range(name: String) -> UniquePtr<TorchProfilerGuard>; 69 | } 70 | } 71 | 72 | impl From<Device> for ffi::Device { 73 | fn from(device: Device) -> Self { 74 | match device { 75 | Device::Host => { 76 | ffi::Device { device_type: ffi::DeviceType::Cpu, device_index: 0 } 77 | } 78 | 
Device::Cuda(CudaDeviceId(device_id)) => ffi::Device { 79 | device_type: ffi::DeviceType::Cuda, 80 | device_index: device_id, 81 | }, 82 | } 83 | } 84 | } 85 | 86 | #[allow(dead_code)] 87 | struct FromBlobContext(Box<dyn Any>); 88 | 89 | pub use ffi::ScalarType; 90 | 91 | impl ScalarType { 92 | pub fn element_size(self) -> usize { 93 | match self { 94 | ScalarType::BOOL => 1, 95 | ScalarType::U8 => 1, 96 | ScalarType::I8 => 1, 97 | ScalarType::I16 => 2, 98 | ScalarType::U16 => 2, 99 | ScalarType::I32 => 4, 100 | ScalarType::U32 => 4, 101 | ScalarType::I64 => 8, 102 | ScalarType::U64 => 8, 103 | ScalarType::F8_E4M3 => 1, 104 | ScalarType::F8_E5M2 => 1, 105 | ScalarType::F16 => 2, 106 | ScalarType::BF16 => 2, 107 | ScalarType::F32 => 4, 108 | ScalarType::F64 => 8, 109 | _ => panic!("Unsupported scalar type"), 110 | } 111 | } 112 | } 113 | 114 | /// Attempts to convert a PyTorch DType object to a scalar dtype. 115 | impl<'py> FromPyObject<'_, 'py> for ScalarType { 116 | type Error = PyErr; 117 | 118 | fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> { 119 | unsafe { ffi::torch_to_scalar_type(obj.as_ptr() as *mut c_char) }.map_err(|e| { 120 | PyValueError::new_err(format!( 121 | "Failed to convert PyTorch dtype to ScalarType: {:?}", 122 | e 123 | )) 124 | }) 125 | } 126 | } 127 | 128 | /// Wraps a scalar dtype into a PyTorch DType object. 129 | impl<'py> IntoPyObject<'py> for ScalarType { 130 | type Target = PyAny; 131 | type Output = Bound<'py, PyAny>; 132 | type Error = PyErr; 133 | 134 | fn into_pyobject(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> { 135 | let ptr = unsafe { ffi::scalar_to_torch_type(self) }.map_err(|e| { 136 | PyValueError::new_err(format!( 137 | "Failed to convert ScalarType to PyTorch dtype: {:?}", 138 | e 139 | )) 140 | })?; 141 | let py_ptr = ptr as *mut pyo3::ffi::PyObject; 142 | Ok(unsafe { Bound::from_borrowed_ptr(py, py_ptr) }) 143 | } 144 | } 145 | 146 | pub fn from_blob( 147 | data_ptr: NonNull<c_void>, 148 | shape: &[i64], 149 | dtype: ScalarType, 150 | device: Device, 151 | context: Box<dyn Any>, 152 | ) -> *mut pyo3::ffi::PyObject { 153 | unsafe { 154 | ffi::from_blob( 155 | data_ptr.as_ptr() as *mut c_char, 156 | shape, 157 | dtype, 158 | device.into(), 159 | Box::new(FromBlobContext(context)), 160 | ) as *mut pyo3::ffi::PyObject 161 | } 162 | } 163 | 164 | pub fn current_stream() -> u64 { 165 | unsafe { ffi::current_stream() } 166 | } 167 | 168 | #[allow(dead_code)] 169 | pub struct TorchProfilerGuard(UniquePtr<ffi::TorchProfilerGuard>); 170 | 171 | unsafe impl Send for TorchProfilerGuard {} 172 | unsafe impl Sync for TorchProfilerGuard {} 173 | 174 | pub fn torch_profile_range(name: String) -> TorchProfilerGuard { 175 | TorchProfilerGuard(unsafe { ffi::profile_range(name) }) 176 | } 177 | 178 | #[cfg(test)] 179 | mod test_torch; 180 | -------------------------------------------------------------------------------- /fabric-lib/src/imm_count.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | num::NonZeroU32, 4 | sync::{ 5 | Arc, 6 | atomic::{AtomicI64, Ordering::Relaxed}, 7 | }, 8 | }; 9 | 10 | use cuda_lib::gdr::GdrFlag; 11 | use parking_lot::RwLock; 12 | 13 | use crate::api::{GdrCounter, ImmCounter}; 14 | 15 | pub enum ImmCount { 16 | Expected { counter: Arc<AtomicI64>, expected: NonZeroU32 }, 17 | Imm { counter: Arc<AtomicI64> }, 18 | Gdr { counter: Arc<AtomicI64>, flag: Arc<GdrFlag> }, 19 | } 20 | 21 | impl ImmCount { 22 | /// Consume self and return the current value and the expected value.
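/// For an `Expected` counter this yields `(current, Some(expected))`; for `Imm` and `Gdr` counters it yields `(current, None)`.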
23 | pub fn consume(self) -> (u32, Option<NonZeroU32>) { 24 | match self { 25 | ImmCount::Expected { counter, expected } => { 26 | (counter.load(Relaxed) as u32, Some(expected)) 27 | } 28 | ImmCount::Imm { counter } => (counter.load(Relaxed) as u32, None), 29 | ImmCount::Gdr { counter, .. } => (counter.load(Relaxed) as u32, None), 30 | } 31 | } 32 | 33 | /// Returns true if the counter has reached the exact expected value. 34 | /// When reached, the counter is subtracted by the expected value (likely reset to 0). 35 | pub fn inc(&self) -> bool { 36 | match &self { 37 | ImmCount::Expected { counter, expected } => { 38 | let prev = counter.fetch_add(1, Relaxed); 39 | let reached = prev as u32 + 1 == expected.get(); 40 | if reached { 41 | counter.fetch_sub(expected.get() as i64, Relaxed); 42 | } 43 | reached 44 | } 45 | ImmCount::Imm { counter } => { 46 | counter.fetch_add(1, Relaxed); 47 | false 48 | } 49 | ImmCount::Gdr { counter, flag } => { 50 | let value = counter.fetch_add(1, Relaxed) + 1; 51 | if value == 0 { 52 | flag.set(true); 53 | counter.store(0, Relaxed); 54 | } 55 | false 56 | } 57 | } 58 | } 59 | } 60 | 61 | #[derive(Debug, Clone, Copy)] 62 | pub enum ImmCountStatus { 63 | Vacant, 64 | NotReached, 65 | Reached, 66 | } 67 | 68 | pub struct ImmCountMap { 69 | map: RwLock<HashMap<u32, ImmCount>>, 70 | } 71 | 72 | impl ImmCountMap { 73 | pub fn new() -> Self { 74 | Self { map: RwLock::new(HashMap::new()) } 75 | } 76 | 77 | /// Use the imm as a counter. Reset the counter to 0. 78 | /// Returns the previous counter if any. 79 | pub fn set_expected(&self, imm: u32, expected: NonZeroU32) -> Option<ImmCount> { 80 | self.map.write().insert( 81 | imm, 82 | ImmCount::Expected { counter: Arc::new(AtomicI64::new(0)), expected }, 83 | ) 84 | } 85 | 86 | /// Return an exposed imm counter. 87 | pub fn get_imm_counter(&self, imm: u32) -> ImmCounter { 88 | let counter = Arc::new(AtomicI64::new(0)); 89 | let imm_counter = ImmCounter::new(counter.clone()); 90 | self.map.write().insert(imm, ImmCount::Imm { counter }); 91 | imm_counter 92 | } 93 | 94 | /// Return an exposed gdr counter. 95 | pub fn get_gdr_counter(&self, imm: u32, flag: Arc<GdrFlag>) -> GdrCounter { 96 | let counter = Arc::new(AtomicI64::new(0)); 97 | let imm_counter = GdrCounter::new(counter.clone(), flag.clone()); 98 | self.map.write().insert(imm, ImmCount::Gdr { counter, flag }); 99 | imm_counter 100 | } 101 | 102 | /// Stop treating the imm as a counter. 103 | /// Return the previous counter if any. 104 | pub fn remove(&self, imm: u32) -> Option<ImmCount> { 105 | self.map.write().remove(&imm) 106 | } 107 | 108 | /// Get the current value of the counter. 109 | /// If the imm is not used as a counter, returns None. 110 | pub fn get(&self, imm: u32) -> Option<u32> { 111 | self.map.read().get(&imm).map(|v| match v { 112 | ImmCount::Expected { counter, .. } => counter.load(Relaxed) as u32, 113 | ImmCount::Imm { counter } => counter.load(Relaxed) as u32, 114 | ImmCount::Gdr { counter, .. } => counter.load(Relaxed) as u32, 115 | }) 116 | } 117 | 118 | /// Get the expected value of the counter. 119 | /// If the imm is not used as a counter, returns None. 120 | pub fn get_expected(&self, imm: u32) -> Option<NonZeroU32> { 121 | let counters = self.map.read(); 122 | match counters.get(&imm)? { 123 | ImmCount::Expected { expected, .. } => Some(*expected), 124 | ImmCount::Gdr { .. } => None, 125 | ImmCount::Imm { .. } => None, 126 | } 127 | } 128 | 129 | /// Increment the counter. 130 | /// Returns the status of the counter. 131 | /// If the imm is not used as a counter, returns Vacant. 132 |
/// When reached, the counter is subtracted by the expected value (likely reset to 0). 133 | pub fn inc(&self, imm: u32) -> ImmCountStatus { 134 | if let Some(imm_count) = self.map.read().get(&imm) { 135 | if imm_count.inc() { 136 | ImmCountStatus::Reached 137 | } else { 138 | ImmCountStatus::NotReached 139 | } 140 | } else { 141 | ImmCountStatus::Vacant 142 | } 143 | } 144 | } 145 | 146 | impl Default for ImmCountMap { 147 | fn default() -> Self { 148 | Self::new() 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /fabric-lib/src/interface.rs: -------------------------------------------------------------------------------- 1 | use mockall::{automock, mock}; 2 | 3 | use std::{ 4 | ffi::c_void, 5 | num::{NonZeroU8, NonZeroU32}, 6 | ptr::NonNull, 7 | sync::Arc, 8 | }; 9 | 10 | use cuda_lib::Device; 11 | 12 | use crate::{ 13 | api::{DomainAddress, MemoryRegionDescriptor, MemoryRegionHandle, TransferRequest}, 14 | error::{FabricLibError, Result}, 15 | }; 16 | 17 | pub trait RdmaEngine { 18 | fn main_address(&self) -> DomainAddress; 19 | 20 | fn nets_per_gpu(&self) -> NonZeroU8; 21 | 22 | fn register_memory_local( 23 | &self, 24 | ptr: NonNull<c_void>, 25 | len: usize, 26 | device: Device, 27 | ) -> Result<MemoryRegionHandle>; 28 | 29 | fn register_memory_allow_remote( 30 | &self, 31 | ptr: NonNull<c_void>, 32 | len: usize, 33 | device: Device, 34 | ) -> Result<(MemoryRegionHandle, MemoryRegionDescriptor)>; 35 | 36 | fn unregister_memory(&self, ptr: NonNull<c_void>) -> Result<()>; 37 | } 38 | 39 | #[derive(Clone, Copy)] 40 | pub struct SendBuffer { 41 | pub(crate) ptr: NonNull<c_void>, 42 | pub(crate) len: usize, 43 | pub(crate) mr_handle: MemoryRegionHandle, 44 | } 45 | 46 | unsafe impl Send for SendBuffer {} 47 | unsafe impl Sync for SendBuffer {} 48 | 49 | impl SendBuffer { 50 | pub fn new( 51 | ptr: NonNull<c_void>, 52 | len: usize, 53 | mr_handle: MemoryRegionHandle, 54 | ) -> Self { 55 | Self { ptr, len, mr_handle } 56 | } 57 | } 58 | pub type CallbackResult = std::result::Result<(), String>; 59 | 60 | pub type SendCallback = Box<dyn FnOnce(Result<()>) -> CallbackResult + Send + Sync>; 61 | 62 | pub type RecvCallback = Box<dyn FnOnce(usize) -> CallbackResult + Send + Sync>; 63 | 64 | pub type ErrorCallback = 65 | Box<dyn FnOnce(FabricLibError) -> CallbackResult + Send + Sync>; 66 | 67 | pub type BouncingRecvCallback = Arc<Box<dyn Fn(&[u8]) -> CallbackResult + Send + Sync>>; 68 | 69 | pub type BouncingErrorCallback = 70 | Arc<Box<dyn Fn(FabricLibError) -> CallbackResult + Send + Sync>>; 71 | 72 | #[automock] 73 | pub trait SendRecvEngine { 74 | fn submit_send( 75 | &self, 76 | addr: DomainAddress, 77 | buffer: SendBuffer, 78 | callback: SendCallback, 79 | ) -> Result<()>; 80 | 81 | fn submit_recv( 82 | &self, 83 | mr: MemoryRegionHandle, 84 | ptr: NonNull<c_void>, 85 | len: usize, 86 | on_recv: RecvCallback, 87 | on_error: ErrorCallback, 88 | ) -> Result<()>; 89 | 90 | fn submit_bouncing_recvs( 91 | &self, 92 | len: usize, 93 | count: usize, 94 | on_recv: BouncingRecvCallback, 95 | on_error: BouncingErrorCallback, 96 | ) -> Result<()>; 97 | } 98 | 99 | #[automock] 100 | pub trait AsyncTransferEngine { 101 | fn wait_for_imm_count( 102 | &self, 103 | imm: u32, 104 | expected_count: NonZeroU32, 105 | ) -> impl Future<Output = Result<()>> + Send + Sync; 106 | 107 | fn submit_send_async( 108 | &self, 109 | addr: DomainAddress, 110 | buffer: SendBuffer, 111 | ) -> impl Future<Output = Result<()>> + Send + Sync; 112 | 113 | fn submit_transfer_async( 114 | &self, 115 | request: TransferRequest, 116 | ) -> impl Future<Output = Result<()>> + Send + Sync; 117 | } 118 | 119 | mock!
{ 120 | pub TestTransferEngine {} 121 | 122 | impl RdmaEngine for TestTransferEngine { 123 | fn main_address(&self) -> DomainAddress; 124 | 125 | fn nets_per_gpu(&self) -> NonZeroU8; 126 | 127 | fn register_memory_local( 128 | &self, 129 | ptr: NonNull<c_void>, 130 | len: usize, 131 | device: Device, 132 | ) -> Result<MemoryRegionHandle>; 133 | 134 | fn register_memory_allow_remote( 135 | &self, 136 | ptr: NonNull<c_void>, 137 | len: usize, 138 | device: Device, 139 | ) -> Result<(MemoryRegionHandle, MemoryRegionDescriptor)>; 140 | 141 | fn unregister_memory(&self, ptr: NonNull<c_void>) -> Result<()>; 142 | } 143 | 144 | impl SendRecvEngine for TestTransferEngine { 145 | fn submit_send( 146 | &self, 147 | addr: DomainAddress, 148 | buffer: SendBuffer, 149 | callback: SendCallback, 150 | ) -> Result<()>; 151 | 152 | fn submit_recv( 153 | &self, 154 | mr: MemoryRegionHandle, 155 | ptr: NonNull<c_void>, 156 | len: usize, 157 | on_recv: RecvCallback, 158 | on_error: ErrorCallback, 159 | ) -> Result<()>; 160 | 161 | fn submit_bouncing_recvs( 162 | &self, 163 | len: usize, 164 | count: usize, 165 | on_recv: BouncingRecvCallback, 166 | on_error: BouncingErrorCallback, 167 | ) -> Result<()>; 168 | } 169 | 170 | impl AsyncTransferEngine for TestTransferEngine { 171 | fn wait_for_imm_count( 172 | &self, 173 | imm: u32, 174 | expected_count: NonZeroU32, 175 | ) -> impl Future<Output = Result<()>> + Send + Sync; 176 | 177 | fn submit_send_async( 178 | &self, 179 | addr: DomainAddress, 180 | buffer: SendBuffer, 181 | ) -> impl Future<Output = Result<()>> + Send + Sync; 182 | 183 | fn submit_transfer_async( 184 | &self, 185 | request: TransferRequest, 186 | ) -> impl Future<Output = Result<()>> + Send + Sync; 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /rust/torch-lib/src/torch_lib.cc: -------------------------------------------------------------------------------- 1 | #include "torch_lib.h" 2 | 3 | #include <ATen/ATen.h> 4 | #include <ATen/cuda/CUDAContext.h> 5 | #include <ATen/record_function.h> 6 | #include <torch/csrc/Dtype.h> 7 | #include <torch/csrc/DynamicTypes.h> 8 | #include <torch/csrc/autograd/python_variable.h> 9 | 10 | 11 | namespace torch_lib { 12 | 13 | TorchProfilerGuard::~TorchProfilerGuard() = default; 14 | 15 | char *from_blob( 16 | char *data_ptr, 17 | rust::Slice<const int64_t> shape, 18 | ScalarType dtype, 19 | Device device, 20 | rust::Box<FromBlobContext> context) 21 | { 22 | // from_blob needs a copy-constructible lambda, so create a shared heap 23 | // reference to the supposedly uniquely-owned context box. 24 | auto shared_ctx = std::make_shared<rust::Box<FromBlobContext>>(std::move(context)); 25 | 26 | // Convert the ScalarType enum from Rust to the at::ScalarType enum from C++.
27 | at::ScalarType scalar_type; 28 | switch (dtype) { 29 | case ScalarType::BOOL: scalar_type = at::ScalarType::Bool; break; 30 | case ScalarType::U8: scalar_type = at::ScalarType::Byte; break; 31 | case ScalarType::I8: scalar_type = at::ScalarType::Char; break; 32 | case ScalarType::I16: scalar_type = at::ScalarType::Short; break; 33 | case ScalarType::U16: scalar_type = at::ScalarType::UInt16; break; 34 | case ScalarType::I32: scalar_type = at::ScalarType::Int; break; 35 | case ScalarType::U32: scalar_type = at::ScalarType::UInt32; break; 36 | case ScalarType::I64: scalar_type = at::ScalarType::Long; break; 37 | case ScalarType::U64: scalar_type = at::ScalarType::UInt64; break; 38 | case ScalarType::F8_E4M3: scalar_type = at::ScalarType::Float8_e4m3fn; break; 39 | case ScalarType::F8_E5M2: scalar_type = at::ScalarType::Float8_e5m2; break; 40 | case ScalarType::F16: scalar_type = at::ScalarType::Half; break; 41 | case ScalarType::F32: scalar_type = at::ScalarType::Float; break; 42 | case ScalarType::F64: scalar_type = at::ScalarType::Double; break; 43 | case ScalarType::BF16: scalar_type = at::ScalarType::BFloat16; break; 44 | default: abort(); 45 | } 46 | 47 | auto options = at::TensorOptions().dtype(scalar_type); 48 | 49 | // Set the device. 50 | switch (device.device_type) { 51 | case DeviceType::Cpu: options = options.device(at::kCPU); break; 52 | case DeviceType::Cuda: options = options.device(at::kCUDA, device.device_index); break; 53 | default: abort(); 54 | } 55 | 56 | // Create a torch tensor. 57 | auto tensor = at::from_blob( 58 | (void *)data_ptr, 59 | at::IntArrayRef(shape.data(), shape.size()), 60 | [shared_ctx](void *) { (void)shared_ctx; }, 61 | options 62 | ); 63 | 64 | // Wrap it into a PyObject. 65 | return reinterpret_cast<char *>(THPVariable_Wrap(tensor)); 66 | } 67 | 68 | ScalarType torch_to_scalar_type(char *dtype_ptr) { 69 | PyObject* dtype = reinterpret_cast<PyObject*>(dtype_ptr); 70 | switch (at::ScalarType scalar_type = ((THPDtype*)dtype)->scalar_type) { 71 | case at::ScalarType::Bool: return ScalarType::BOOL; 72 | case at::ScalarType::Char: return ScalarType::I8; 73 | case at::ScalarType::Byte: return ScalarType::U8; 74 | case at::ScalarType::Short: return ScalarType::I16; 75 | case at::ScalarType::UInt16: return ScalarType::U16; 76 | case at::ScalarType::Int: return ScalarType::I32; 77 | case at::ScalarType::UInt32: return ScalarType::U32; 78 | case at::ScalarType::Long: return ScalarType::I64; 79 | case at::ScalarType::UInt64: return ScalarType::U64; 80 | case at::ScalarType::Float8_e4m3fn: return ScalarType::F8_E4M3; 81 | case at::ScalarType::Float8_e5m2: return ScalarType::F8_E5M2; 82 | case at::ScalarType::Half: return ScalarType::F16; 83 | case at::ScalarType::Float: return ScalarType::F32; 84 | case at::ScalarType::Double: return ScalarType::F64; 85 | case at::ScalarType::BFloat16: return ScalarType::BF16; 86 | default: { 87 | throw std::runtime_error("Unsupported scalar type: " + std::to_string((int)scalar_type)); 88 | } 89 | } 90 | } 91 | 92 | char *scalar_to_torch_type(ScalarType scalar_type) { 93 | at::ScalarType dtype; 94 | switch (scalar_type) { 95 | case ScalarType::BOOL: dtype = at::ScalarType::Bool; break; 96 | case ScalarType::U8: dtype = at::ScalarType::Byte; break; 97 | case ScalarType::I8: dtype = at::ScalarType::Char; break; 98 | case ScalarType::I16: dtype = at::ScalarType::Short; break; 99 | case ScalarType::U16: dtype = at::ScalarType::UInt16; break; 100 | case ScalarType::I32: dtype = at::ScalarType::Int; break; 101 | case ScalarType::U32: dtype = 
at::ScalarType::UInt32; break; 102 | case ScalarType::I64: dtype = at::ScalarType::Long; break; 103 | case ScalarType::U64: dtype = at::ScalarType::UInt64; break; 104 | case ScalarType::F8_E4M3: dtype = at::ScalarType::Float8_e4m3fn; break; 105 | case ScalarType::F8_E5M2: dtype = at::ScalarType::Float8_e5m2; break; 106 | case ScalarType::F16: dtype = at::ScalarType::Half; break; 107 | case ScalarType::F32: dtype = at::ScalarType::Float; break; 108 | case ScalarType::F64: dtype = at::ScalarType::Double; break; 109 | case ScalarType::BF16: dtype = at::ScalarType::BFloat16; break; 110 | default: { 111 | throw std::runtime_error("Unsupported scalar type: " + std::to_string((int)scalar_type)); 112 | } 113 | } 114 | return reinterpret_cast<char *>(torch::getTHPDtype(dtype)); 115 | } 116 | 117 | uint64_t current_stream() { 118 | return (int64_t)(cudaStream_t)at::cuda::getCurrentCUDAStream(); 119 | } 120 | 121 | TorchProfilerGuard::TorchProfilerGuard(const char* name) { 122 | guard = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE); 123 | if (guard->isActive()) { 124 | guard->before(name); 125 | } 126 | } 127 | 128 | std::unique_ptr<TorchProfilerGuard> profile_range(rust::String name) { 129 | return std::make_unique<TorchProfilerGuard>(name.c_str()); 130 | } 131 | 132 | } // namespace torch_lib 133 | -------------------------------------------------------------------------------- /p2p-all-to-all/a2a-kernels/src/core/launch_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cassert> 4 | #include <cuda_bf16.h> 5 | #include <cuda_fp16.h> 6 | 7 | template <size_t V> 8 | class Fixed { 9 | public: 10 | __device__ Fixed(size_t value) {} 11 | 12 | __device__ operator size_t() const { return V; } 13 | 14 | static constexpr size_t Value = V; 15 | }; 16 | 17 | class NotFixed { 18 | public: 19 | __device__ NotFixed(size_t value) : value_(value) {} 20 | 21 | __device__ operator size_t() const { return value_; } 22 | private: 23 | size_t value_; 24 | }; 25 | 26 | #define _LAUNCH_TYPE(kind, var, value, ...) \ 27 | case kind: { \ 28 | using var = value; \ 29 | { __VA_ARGS__; } \ 30 | break; \ 31 | } 32 | 33 | #define _LAUNCH_VAL(kind, var, value, ...) \ 34 | case kind: { \ 35 | static constexpr decltype(value) var = value; \ 36 | { __VA_ARGS__; } \ 37 | break; \ 38 | } 39 | 40 | // Generic dtype dispatch macro 41 | #ifndef LAUNCH_DTYPE 42 | #define LAUNCH_DTYPE(dtype, var, ...) \ 43 | switch (dtype) { \ 44 | _LAUNCH_TYPE(DTYPE_FLOAT16, var, half, __VA_ARGS__) \ 45 | _LAUNCH_TYPE(DTYPE_BFLOAT16, var, __nv_bfloat16, __VA_ARGS__) \ 46 | _LAUNCH_TYPE(DTYPE_FLOAT32, var, float, __VA_ARGS__) \ 47 | default: { \ 48 | assert(false && "Unsupported dtype"); \ 49 | break; \ 50 | } \ 51 | } 52 | #endif 53 | 54 | // Static specialization for hidden dimension 55 | #ifndef LAUNCH_HIDDEN_DIM 56 | #define LAUNCH_HIDDEN_DIM(dtype, var, ...) \ 57 | switch (dtype) { \ 58 | _LAUNCH_TYPE(2048, var, Fixed<2048>, __VA_ARGS__) \ 59 | _LAUNCH_TYPE(4096, var, Fixed<4096>, __VA_ARGS__) \ 60 | _LAUNCH_TYPE(7168, var, Fixed<7168>, __VA_ARGS__) \ 61 | default: { \ 62 | using var = NotFixed; \ 63 | { __VA_ARGS__; } \ 64 | break; \ 65 | } \ 66 | } 67 | #endif 68 | 69 | // Static specialization for the token dimension 70 | #ifndef LAUNCH_TOKEN_DIM_DISPATCH 71 | #define LAUNCH_TOKEN_DIM_DISPATCH(dim, var, ...) 
\ 72 | switch (dim) { \ 73 | _LAUNCH_TYPE(2048, var, Fixed<2048>, __VA_ARGS__) \ 74 | _LAUNCH_TYPE(4096, var, Fixed<4096>, __VA_ARGS__) \ 75 | _LAUNCH_TYPE(7168, var, Fixed<7168>, __VA_ARGS__) \ 76 | default: { \ 77 | using var = NotFixed; \ 78 | { __VA_ARGS__; } \ 79 | break; \ 80 | } \ 81 | } 82 | #endif 83 | 84 | 85 | // Static specialization for the token dimension 86 | #ifndef LAUNCH_TOKEN_DIM_COMBINE 87 | #define LAUNCH_TOKEN_DIM_COMBINE(dim, var, ...) \ 88 | switch (dim) { \ 89 | _LAUNCH_TYPE(7168 * 2, var, Fixed<7168 * 2>, __VA_ARGS__) \ 90 | default: { \ 91 | using var = NotFixed; \ 92 | { __VA_ARGS__; } \ 93 | break; \ 94 | } \ 95 | } 96 | #endif 97 | 98 | 99 | // Static specialization for number of experts per token 100 | #ifndef LAUNCH_NUM_EXPERTS_PER_TOKEN 101 | #define LAUNCH_NUM_EXPERTS_PER_TOKEN(dtype, var, ...) \ 102 | switch (dtype) { \ 103 | _LAUNCH_TYPE(8, var, Fixed<8>, __VA_ARGS__) \ 104 | default: { \ 105 | using var = NotFixed; \ 106 | { __VA_ARGS__; } \ 107 | break; \ 108 | } \ 109 | } 110 | #endif 111 | 112 | // Static specialization for the hidden dim scale. 113 | #ifndef LAUNCH_HIDDEN_DIM_SCALE 114 | #define LAUNCH_HIDDEN_DIM_SCALE(dtype, var, ...) \ 115 | switch (dtype) { \ 116 | _LAUNCH_TYPE(8, var, Fixed<8>, __VA_ARGS__) \ 117 | _LAUNCH_TYPE(32, var, Fixed<32>, __VA_ARGS__) \ 118 | _LAUNCH_TYPE(56, var, Fixed<56>, __VA_ARGS__) \ 119 | default: { \ 120 | using var = NotFixed; \ 121 | { __VA_ARGS__; } \ 122 | break; \ 123 | } \ 124 | } 125 | #endif 126 | 127 | // Static specialization for the hidden dim scale. 128 | #ifndef LAUNCH_HIDDEN_DIM_SCALE_BYTES 129 | #define LAUNCH_HIDDEN_DIM_SCALE_BYTES(dtype, var, ...) \ 130 | switch (dtype) { \ 131 | _LAUNCH_TYPE(32, var, Fixed<32>, __VA_ARGS__) \ 132 | _LAUNCH_TYPE(128, var, Fixed<128>, __VA_ARGS__) \ 133 | _LAUNCH_TYPE(224, var, Fixed<224>, __VA_ARGS__) \ 134 | default: { \ 135 | using var = NotFixed; \ 136 | { __VA_ARGS__; } \ 137 | break; \ 138 | } \ 139 | } 140 | #endif 141 | 142 | // Static specialization for the world size. 143 | #ifndef LAUNCH_WORLD_SIZE 144 | #define LAUNCH_WORLD_SIZE(world_size, var, ...) \ 145 | switch (world_size) { \ 146 | _LAUNCH_VAL(1, var, 1, __VA_ARGS__) \ 147 | _LAUNCH_VAL(2, var, 2, __VA_ARGS__) \ 148 | _LAUNCH_VAL(4, var, 4, __VA_ARGS__) \ 149 | _LAUNCH_VAL(8, var, 8, __VA_ARGS__) \ 150 | default: { \ 151 | assert(false && "Unsupported world size"); \ 152 | break; \ 153 | } \ 154 | } 155 | #endif 156 | 157 | // Static specialization for the DP group size 158 | #ifndef LAUNCH_DP_SIZE 159 | #define LAUNCH_DP_SIZE(dp_size, var, ...) \ 160 | switch (dp_size) { \ 161 | _LAUNCH_VAL(1, var, 1, __VA_ARGS__) \ 162 | _LAUNCH_VAL(2, var, 2, __VA_ARGS__) \ 163 | _LAUNCH_VAL(4, var, 4, __VA_ARGS__) \ 164 | _LAUNCH_VAL(8, var, 8, __VA_ARGS__) \ 165 | default: { \ 166 | assert(false && "Unsupported DP size"); \ 167 | break; \ 168 | } \ 169 | } 170 | #endif 171 | 172 | // Static specialization for the DP group size 173 | #ifndef LAUNCH_ACCUMULATE 174 | #define LAUNCH_ACCUMULATE(accumulate, var, ...) \ 175 | switch (accumulate) { \ 176 | _LAUNCH_VAL(true, var, true, __VA_ARGS__) \ 177 | _LAUNCH_VAL(false, var, false, __VA_ARGS__) \ 178 | } 179 | #endif 180 | 181 | 182 | 183 | #ifndef LAUNCH_BASIC_FLOAT 184 | #define LAUNCH_BASIC_FLOAT(dtype, var, ...) 
\
185 |   switch (dtype) { \
186 |     _LAUNCH_TYPE(torch_lib::ScalarType::F16, var, half, __VA_ARGS__) \
187 |     _LAUNCH_TYPE(torch_lib::ScalarType::BF16, var, __nv_bfloat16, __VA_ARGS__) \
188 |     _LAUNCH_TYPE(torch_lib::ScalarType::F32, var, float, __VA_ARGS__) \
189 |     default: { \
190 |       assert(false && "Unsupported dtype"); \
191 |       break; \
192 |     } \
193 |   }
194 | #endif
195 | 
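// Illustrative sketch (not from the original source): how these dispatch macros
// compose. `my_kernel` and its launch parameters are hypothetical; the DTYPE_*
// constants come from the surrounding headers. The pattern turns runtime values
// into compile-time template arguments where a specialization exists, falling
// back to `NotFixed` otherwise.
//
//   template <typename T, typename HiddenDim>
//   __global__ void my_kernel(const T *x, HiddenDim hidden_dim);
//
//   void launch(int dtype, size_t hidden_dim, const void *x, cudaStream_t stream) {
//     LAUNCH_DTYPE(dtype, T, {
//       LAUNCH_HIDDEN_DIM(hidden_dim, HiddenDim, {
//         my_kernel<T, HiddenDim><<<1, 256, 0, stream>>>(
//             static_cast<const T *>(x), HiddenDim(hidden_dim));
//       });
//     });
//   }
//
// Inside the kernel, `size_t(hidden_dim)` is a compile-time constant for the
// Fixed<N> specializations, so loops over the hidden dimension can fully unroll.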
--------------------------------------------------------------------------------
/rust/cuda-lib/src/gdr.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 |     ffi::c_void,
3 |     mem::{size_of, size_of_val},
4 |     ptr::null_mut,
5 |     sync::{
6 |         Arc,
7 |         atomic::{AtomicU8, Ordering},
8 |     },
9 | };
10 | 
11 | use crate::CudaError;
12 | 
13 | type GdrResult<T> = Result<T, CudaError>;
14 | 
15 | /// Reference-counted GDR context handle.
16 | struct GdrContextHandle {
17 |     handle: gdrapi_sys::gdr_t,
18 | }
19 | unsafe impl Send for GdrContextHandle {}
20 | unsafe impl Sync for GdrContextHandle {}
21 | 
22 | impl Drop for GdrContextHandle {
23 |     fn drop(&mut self) {
24 |         unsafe { gdrapi_sys::gdr_close(self.handle) };
25 |     }
26 | }
27 | 
28 | /// Public wrapper around the GDRCopy context.
29 | pub struct GdrCopyContext {
30 |     context: Arc<GdrContextHandle>,
31 | }
32 | 
33 | fn align_to(ptr: u64, alignment: usize) -> u64 {
34 |     ptr.div_ceil(alignment as u64) * alignment as u64
35 | }
36 | 
37 | impl GdrCopyContext {
38 |     pub fn new() -> GdrResult<Self> {
39 |         let handle = unsafe { gdrapi_sys::gdr_open() };
40 |         if handle.is_null() {
41 |             return Err(CudaError::GdrCopyError("Failed to create GDR copy handle"));
42 |         }
43 |         Ok(GdrCopyContext { context: Arc::new(GdrContextHandle { handle }) })
44 |     }
45 | 
46 |     fn alloc_buffer(&self, nbytes: usize) -> GdrResult<GdrBuffer> {
47 |         let mut device_ptr: u64 = 0;
48 |         let page_size: usize = 1 << 16; // 64KB page size
49 |         let bytesize = nbytes.div_ceil(page_size) * page_size;
50 | 
51 |         if unsafe { cuda_sys::cuMemAlloc(&mut device_ptr, bytesize + page_size) }
52 |             != cuda_sys::CUDA_SUCCESS
53 |         {
54 |             return Err(CudaError::GdrCopyError("Failed to allocate GDR buffer"));
55 |         }
56 | 
57 |         let aligned_device_ptr = align_to(device_ptr, page_size);
58 | 
59 |         let context = self.context.clone();
60 | 
61 |         let g = context.handle;
62 |         let mut mh = gdrapi_sys::gdr_mh_t { h: 0 };
63 | 
64 |         let ret = unsafe {
65 |             gdrapi_sys::gdr_pin_buffer(g, aligned_device_ptr, bytesize, 0, 0, &mut mh)
66 |         };
67 |         if ret != 0 {
68 |             unsafe { cuda_sys::cuMemFree(device_ptr) };
69 |             return Err(CudaError::GdrCopyError("Failed to pin GDR buffer"));
70 |         }
71 | 
72 |         let mut mapped_ptr: *mut c_void = null_mut();
73 |         let ret = unsafe { gdrapi_sys::gdr_map(g, mh, &mut mapped_ptr, bytesize) };
74 |         if ret != 0 {
75 |             unsafe {
76 |                 gdrapi_sys::gdr_unpin_buffer(g, mh);
77 |                 cuda_sys::cuMemFree(device_ptr);
78 |             };
79 |             return Err(CudaError::GdrCopyError("Failed to map GDR buffer"));
80 |         }
81 | 
82 |         Ok(GdrBuffer {
83 |             device_ptr,
84 |             aligned_device_ptr,
85 |             mapped_ptr,
86 |             bytesize,
87 |             mh,
88 |             context,
89 |         })
90 |     }
91 | }
92 | 
93 | /// Raw device buffer mapped into the host address space via GDRCopy.
94 | struct GdrBuffer {
95 |     device_ptr: u64,
96 |     aligned_device_ptr: u64,
97 |     mapped_ptr: *mut c_void,
98 |     mh: gdrapi_sys::gdr_mh_t,
99 |     bytesize: usize,
100 |     context: Arc<GdrContextHandle>,
101 | }
102 | 
103 | unsafe impl Send for GdrBuffer {}
104 | unsafe impl Sync for GdrBuffer {}
105 | 
106 | impl Drop for GdrBuffer {
107 |     fn drop(&mut self) {
108 |         let g = self.context.handle;
109 |         unsafe {
110 |             gdrapi_sys::gdr_unmap(g, self.mh, self.mapped_ptr, self.bytesize);
111 |             gdrapi_sys::gdr_unpin_buffer(g, self.mh);
112 |             cuda_sys::cuMemFree(self.device_ptr);
113 |         };
114 |     }
115 | }
116 | 
117 | trait GdrRead {
118 |     fn read(mapped_ptr: *mut c_void) -> Self;
119 | }
120 | 
121 | impl GdrRead for u8 {
122 |     #[inline(always)]
123 |     fn read(mapped_ptr: *mut c_void) -> Self {
124 |         let flag = unsafe { AtomicU8::from_ptr(mapped_ptr as *mut u8) };
125 |         flag.load(Ordering::Acquire)
126 |     }
127 | }
128 | 
129 | trait GdrWrite {
130 |     fn write(mapped_ptr: *mut c_void, value: Self);
131 | }
132 | 
133 | impl GdrWrite for u8 {
134 |     #[inline(always)]
135 |     fn write(mapped_ptr: *mut c_void, value: Self) {
136 |         let flag = unsafe { AtomicU8::from_ptr(mapped_ptr as *mut u8) };
137 |         flag.store(value, Ordering::Release);
138 |     }
139 | }
140 | 
141 | impl GdrBuffer {
142 |     fn get_device_ptr(&self) -> *mut c_void {
143 |         self.aligned_device_ptr as *mut c_void
144 |     }
145 | 
146 |     #[inline(always)]
147 |     fn read<T: GdrRead>(&self) -> T {
148 |         T::read(self.mapped_ptr)
149 |     }
150 | 
151 |     #[inline(always)]
152 |     fn write<T: GdrWrite>(&self, value: T) {
153 |         T::write(self.mapped_ptr, value);
154 |     }
155 | 
156 |     fn copy_to(&self, src: *const c_void, nbytes: usize) {
157 |         unsafe {
158 |             gdrapi_sys::gdr_copy_to_mapping(self.mh, self.mapped_ptr, src, nbytes);
159 |         }
160 |     }
161 | }
162 | 
163 | /// Byte-flag implemented using GDRCopy.
164 | pub struct GdrFlag {
165 |     buffer: GdrBuffer,
166 | }
167 | 
168 | impl GdrFlag {
169 |     pub fn new(context: &GdrCopyContext) -> GdrResult<Self> {
170 |         let buffer = context.alloc_buffer(size_of::<u8>())?;
171 |         Ok(GdrFlag { buffer })
172 |     }
173 | 
174 |     pub fn wait(&self) {
175 |         while !self.is_set() {
176 |             std::hint::spin_loop();
177 |         }
178 |         self.set(false);
179 |     }
180 | 
181 |     pub fn get_device_ptr(&self) -> *mut u8 {
182 |         self.buffer.get_device_ptr() as *mut u8
183 |     }
184 | 
185 |     pub fn set(&self, value: bool) {
186 |         self.buffer.write(value as u8);
187 |     }
188 | 
189 |     fn is_set(&self) -> bool {
190 |         self.buffer.read::<u8>() != 0
191 |     }
192 | }
193 | 
194 | pub struct GdrVec<T> {
195 |     buffer: GdrBuffer,
196 |     len: usize,
197 |     _marker: std::marker::PhantomData<T>,
198 | }
199 | 
200 | impl<T> GdrVec<T> {
201 |     pub fn new(context: &GdrCopyContext, len: usize) -> GdrResult<Self> {
202 |         let buffer = context.alloc_buffer(len * size_of::<T>())?;
203 |         Ok(GdrVec { buffer, len, _marker: std::marker::PhantomData })
204 |     }
205 | 
206 |     pub fn get_device_ptr(&self) -> *mut T {
207 |         self.buffer.get_device_ptr().cast::<T>()
208 |     }
209 | 
210 |     pub fn copy(&self, value: &[T]) {
211 |         debug_assert!(value.len() <= self.len);
212 |         self.buffer.copy_to(value.as_ptr() as *const c_void, size_of_val(value));
213 |     }
214 | }
215 | 
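// Illustrative sketch (not from the original source): typical host-side usage of
// the wrappers above. Error handling is elided; GDRCopy requires a live CUDA
// context on the calling thread.
//
//     let ctx = GdrCopyContext::new()?;
//
//     // A device-visible byte flag: a kernel stores to `flag.get_device_ptr()`,
//     // while the host spins on the BAR1-mapped view without a cudaMemcpy.
//     let flag = GdrFlag::new(&ctx)?;
//     let dev_ptr: *mut u8 = flag.get_device_ptr();
//     // ... launch a kernel that eventually writes 1 through `dev_ptr` ...
//     flag.wait(); // returns once the kernel set the flag, then resets it to false
//
//     // A small device array updated from the host via mapped writes:
//     let vec: GdrVec<u32> = GdrVec::new(&ctx, 4)?;
//     vec.copy(&[1, 2, 3, 4]);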
--------------------------------------------------------------------------------
/python/pplx_garden/fabric_lib.pyi:
--------------------------------------------------------------------------------
1 | # ruff: noqa: A002
2 | 
3 | from collections.abc import Callable, Sequence
4 | from typing import Any
5 | 
6 | import torch
7 | 
8 | class DomainInfo:
9 |     @property
10 |     def name(self) -> str: ...
11 |     @property
12 |     def link_speed(self) -> int: ...
13 | 
14 | class DomainAddress:
15 |     @classmethod
16 |     def from_bytes(cls, bytes: bytes) -> DomainAddress: ...
17 |     def as_bytes(self) -> bytes: ...
18 |     @classmethod
19 |     def from_str(cls, s: str) -> DomainAddress: ...
20 |     def __str__(self) -> str: ...
21 |     def __repr__(self) -> str: ...
22 |     def __eq__(self, other: object) -> bool: ...
23 |     def __hash__(self) -> int: ...
24 | 
25 | class TopologyGroup:
26 |     @property
27 |     def cuda_device(self) -> int: ...
28 |     @property
29 |     def numa(self) -> int: ...
30 |     @property
31 |     def domains(self) -> list[DomainInfo]: ...
32 |     @property
33 |     def cpus(self) -> list[int]: ...
34 | 
35 | class MemoryRegionHandle:
36 |     def capsule(self) -> Any: ...
37 |     def debug_str(self) -> str: ...
38 | 
39 | class MemoryRegionDescriptor:
40 |     @classmethod
41 |     def from_bytes(cls, bytes_: bytes) -> MemoryRegionDescriptor: ...
42 |     def as_bytes(self) -> bytes: ...
43 |     def debug_str(self) -> str: ...
44 | 
45 | class PageIndices:
46 |     def __init__(self, indices: Sequence[int]) -> None: ...
47 | 
48 | class UvmWatcher:
49 |     @property
50 |     def ptr(self) -> int: ...
51 | 
52 | class TransferEngineBuilder:
53 |     def add_gpu_domains(
54 |         self,
55 |         cuda_device: int,
56 |         domains: list[DomainInfo],
57 |         pin_worker_cpu: int,
58 |         pin_uvm_cpu: int,
59 |     ) -> TransferEngineBuilder: ...
60 |     def build(self) -> TransferEngine: ...
61 | 
62 | class FabricEngine: ...
63 | 
64 | class TransferEngine:
65 |     def __init__(self, nets_per_gpu: int, cuda_devices: list[int]) -> None: ...
66 |     @staticmethod
67 |     def detect_topology() -> list[TopologyGroup]: ...
68 |     @staticmethod
69 |     def builder() -> TransferEngineBuilder: ...
70 |     @property
71 |     def main_address(self) -> DomainAddress: ...
72 |     @property
73 |     def num_domains(self) -> int: ...
74 |     @property
75 |     def aggregated_link_speed(self) -> int: ...
76 |     @property
77 |     def nets_per_gpu(self) -> int: ...
78 |     @property
79 |     def fabric_engine(self) -> FabricEngine: ...
80 |     def stop(self) -> None: ...
81 |     def register_tensor(
82 |         self,
83 |         tensor: torch.Tensor,
84 |     ) -> tuple[MemoryRegionHandle, MemoryRegionDescriptor]: ...
85 |     def register_memory(
86 |         self,
87 |         ptr: int,
88 |         len: int,
89 |         device: torch.device,
90 |     ) -> tuple[MemoryRegionHandle, MemoryRegionDescriptor]: ...
91 |     def unregister_memory(self, ptr: int) -> None: ...
92 |     def alloc_scalar_watcher(
93 |         self,
94 |         callback: Callable[[int, int], bool],
95 |     ) -> UvmWatcher:
96 |         """
97 |         Allocates a watcher for a scalar value in Unified Memory.
98 |         The returned watcher holds a pointer to a 64-bit value.
99 |         The value is initialized to 0.
100 |         Callback signature: (old_value, new_value) -> continue_watching
101 |         """
102 | 
103 |     def set_imm_callback(self, callback: Callable[[int], None]) -> None:
104 |         """
105 |         Sets a callback invoked upon receiving an immediate value that is not used as a counter.
106 |         Callback signature: (imm: int) -> None
107 |         """
108 |     def set_imm_count_expected(
109 |         self,
110 |         imm: int,
111 |         expected_count: int,
112 |         on_reached: Callable[[], None],
113 |     ) -> tuple[int, int] | None:
114 |         """
115 |         Uses imm as a counter and sets the expected count.
116 |         Once the expected count is reached, on_reached is invoked
117 |         and imm is no longer used as a counter.
118 | 
119 |         If imm was not previously used as a counter, returns None.
120 |         Otherwise, returns the previous count and the previous expected count;
121 |         the previous counter and callback are discarded.
122 | """ 123 | def remove_imm_count(self, imm: int) -> tuple[int, int] | None: 124 | """ 125 | If imm is not used as a counter, return None. 126 | Otherwise, return the previous counter and the previous expected count. 127 | The previous counter and callback will be discarded. 128 | 129 | Normally you don't need to call this function because set_imm_count_expected 130 | removes the counter after reaching the expected count. 131 | 132 | This function is useful if you know that the count will not reach the 133 | expected count, for example, when a transfer is cancelled. 134 | """ 135 | 136 | def submit_bouncing_recvs( 137 | self, 138 | count: int, 139 | len: int, 140 | on_recv: Callable[[bytes], None], 141 | on_error: Callable[[str], None], 142 | ) -> None: ... 143 | def submit_send( 144 | self, 145 | addr: DomainAddress, 146 | data: bytes, 147 | on_done: Callable[[], None], 148 | on_error: Callable[[str], None], 149 | ) -> None: ... 150 | def submit_imm( 151 | self, 152 | imm_data: int, 153 | dst_mr: MemoryRegionDescriptor, 154 | on_done: Callable[[], None], 155 | on_error: Callable[[str], None], 156 | ) -> None: ... 157 | def submit_write( 158 | self, 159 | src_mr: MemoryRegionHandle, 160 | offset: int, 161 | length: int, 162 | imm_data: int | None, 163 | dst_mr: MemoryRegionDescriptor, 164 | dst_offset: int, 165 | on_done: Callable[[], None], 166 | on_error: Callable[[str], None], 167 | num_shards: int | None = None, 168 | ) -> None: 169 | """ 170 | Args: 171 | num_shards: If None, shard the transfer across all domains. \ 172 | Otherwise, shard the transfer across the specified number of domains. 173 | """ 174 | def submit_paged_writes( 175 | self, 176 | length: int, 177 | src_mr: MemoryRegionHandle, 178 | src_page_indices: PageIndices, 179 | src_stride: int, 180 | src_offset: int, 181 | dst_mr: MemoryRegionDescriptor, 182 | dst_page_indices: PageIndices, 183 | dst_stride: int, 184 | dst_offset: int, 185 | imm_data: int | None, 186 | on_done: Callable[[], None], 187 | on_error: Callable[[str], None], 188 | ) -> None: ... 
189 | 
--------------------------------------------------------------------------------
/python-ext/src/py_p2p_all_to_all.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 |     ffi::c_void,
3 |     ptr::{null, null_mut},
4 | };
5 | 
6 | use p2p_all_to_all::{AllToAllContext, AllToAllRankHandle};
7 | use pyo3::{
8 |     Bound, PyResult, exceptions::PyRuntimeError, pyclass, pymethods, types::PyModule,
9 |     types::PyModuleMethods,
10 | };
11 | use torch_lib::ScalarType;
12 | 
13 | use crate::py_fabric_lib::{
14 |     PyDomainAddress, PyMemoryRegionDescriptor, PyMemoryRegionHandle, PyTransferEngine,
15 | };
16 | 
17 | #[pyclass(name = "AllToAllContext", module = "pplx_garden._rust")]
18 | pub(crate) struct PyAllToAllContext {
19 |     ctx: AllToAllContext,
20 | }
21 | 
22 | #[pymethods]
23 | impl PyAllToAllContext {
24 |     #[staticmethod]
25 |     #[allow(clippy::too_many_arguments)]
26 |     fn create(
27 |         hidden_dim: usize,
28 |         hidden_dim_scale: Option<usize>,
29 |         in_elemsize: usize,
30 |         out_elemsize: usize,
31 |         out_dtype: ScalarType,
32 |         scale_elemsize: Option<usize>,
33 |         max_num_tokens: usize,
34 |         max_recv_tokens: usize,
35 |         max_private_tokens: usize,
36 |         num_experts: usize,
37 |         expert_padding: usize,
38 |         num_experts_per_token: usize,
39 |         rank: usize,
40 |         dp_size: usize,
41 |         node_size: usize,
42 |         world_size: usize,
43 |         num_routed_ptr: u64,
44 |         num_routed_mr: PyMemoryRegionHandle,
45 |         send_buffer_ptr: u64,
46 |         send_buffer_mr: PyMemoryRegionHandle,
47 |         recv_buffer_ptr: u64,
48 |         recv_buffer_mr: PyMemoryRegionHandle,
49 |         sync_ptrs: Vec<u64>,
50 |         send_ptrs: Vec<u64>,
51 |         recv_ptrs: Vec<u64>,
52 |         device: u8,
53 |         imm_base: u32,
54 |         ranks: Vec<(
55 |             PyDomainAddress,
56 |             PyMemoryRegionDescriptor,
57 |             PyMemoryRegionDescriptor,
58 |         )>,
59 |         transfer_engine: &PyTransferEngine,
60 |         worker_cpu: Option<usize>,
61 |     ) -> PyResult<Self> {
62 |         let rank_handles = ranks
63 |             .into_iter()
64 |             .map(|data| AllToAllRankHandle::new(data.0.0, data.1.0, data.2.0))
65 |             .collect();
66 | 
67 |         let ctx = AllToAllContext::new(
68 |             hidden_dim,
69 |             hidden_dim_scale.unwrap_or(0),
70 |             in_elemsize,
71 |             out_elemsize,
72 |             out_dtype,
73 |             scale_elemsize.unwrap_or(0),
74 |             max_num_tokens,
75 |             max_recv_tokens,
76 |             max_private_tokens,
77 |             num_experts,
78 |             expert_padding,
79 |             num_experts_per_token,
80 |             rank,
81 |             dp_size,
82 |             node_size,
83 |             world_size,
84 |             num_routed_ptr as *mut u32,
85 |             num_routed_mr.0,
86 |             send_buffer_ptr as *mut c_void,
87 |             send_buffer_mr.0,
88 |             recv_buffer_ptr as *mut c_void,
89 |             recv_buffer_mr.0,
90 |             sync_ptrs,
91 |             send_ptrs,
92 |             recv_ptrs,
93 |             device,
94 |             imm_base,
95 |             rank_handles,
96 |             transfer_engine.get_fabric_engine(),
97 |             worker_cpu,
98 |         )?;
99 |         Ok(Self { ctx })
100 |     }
101 | 
102 |     #[allow(clippy::too_many_arguments)]
103 |     fn dispatch_send(
104 |         &mut self,
105 |         num_tokens: usize,
106 |         x_ptr: u64,
107 |         x_stride: usize,
108 |         x_scale_ptr: Option<u64>,
109 |         x_scale_stride_elem: Option<usize>,
110 |         x_scale_stride_token: Option<usize>,
111 |         indices_ptr: u64,
112 |         indices_stride: usize,
113 |         weights_ptr: u64,
114 |         weights_stride: usize,
115 |         bound_m_ptr: Option<u64>,
116 |         stream: u64,
117 |     ) -> PyResult<()> {
118 |         self.ctx
119 |             .dispatch_send(
120 |                 num_tokens,
121 |                 x_ptr as *const c_void,
122 |                 x_stride,
123 |                 x_scale_ptr.map(|ptr| ptr as *const c_void).unwrap_or(null()),
124 |                 x_scale_stride_elem.unwrap_or(0),
125 |                 x_scale_stride_token.unwrap_or(0),
126 |                 indices_ptr as *const i32,
127 |                 indices_stride,
128 |                 weights_ptr as *const f32,
129 |                 weights_stride,
130 |                 bound_m_ptr.map(|ptr| ptr as *const i32).unwrap_or(null()),
131 |                 stream,
132 |             )
133 |             .map_err(|e| PyRuntimeError::new_err(e.to_string()))
134 |     }
135 | 
136 |     #[allow(clippy::too_many_arguments)]
137 |     fn dispatch_recv(
138 |         &mut self,
139 |         out_num_tokens_ptr: u64,
140 |         out_x_ptr: u64,
141 |         out_x_stride: usize,
142 |         out_x_scale_ptr: Option<u64>,
143 |         out_x_scale_stride_elem: Option<usize>,
144 |         out_x_scale_stride_token: Option<usize>,
145 |         stream: u64,
146 |     ) -> PyResult<()> {
147 |         self.ctx
148 |             .dispatch_recv(
149 |                 out_num_tokens_ptr as *mut i32,
150 |                 out_x_ptr as *mut c_void,
151 |                 out_x_stride,
152 |                 out_x_scale_ptr.map(|ptr| ptr as *mut c_void).unwrap_or(null_mut()),
153 |                 out_x_scale_stride_elem.unwrap_or(0),
154 |                 out_x_scale_stride_token.unwrap_or(0),
155 |                 stream,
156 |             )
157 |             .map_err(|e| PyRuntimeError::new_err(e.to_string()))
158 |     }
159 | 
160 |     #[allow(clippy::too_many_arguments)]
161 |     fn combine_send(
162 |         &mut self,
163 |         expert_x_ptr: u64,
164 |         expert_x_stride: usize,
165 |         stream: u64,
166 |     ) -> PyResult<()> {
167 |         self.ctx
168 |             .combine_send(expert_x_ptr as *const c_void, expert_x_stride, stream)
169 |             .map_err(|e| PyRuntimeError::new_err(e.to_string()))
170 |     }
171 | 
172 |     #[allow(clippy::too_many_arguments)]
173 |     fn combine_recv(
174 |         &mut self,
175 |         num_tokens: usize,
176 |         num_recv_tokens: usize,
177 |         expert_y_dtype: ScalarType,
178 |         out_tokens_ptr: u64,
179 |         out_tokens_stride: usize,
180 |         indices_ptr: u64,
181 |         indices_stride: usize,
182 |         weights_ptr: u64,
183 |         weights_stride: usize,
184 |         bound_m_ptr: Option<u64>,
185 |         accumulate: bool,
186 |         stream: u64,
187 |     ) -> PyResult<()> {
188 |         self.ctx
189 |             .combine_recv(
190 |                 num_tokens,
191 |                 num_recv_tokens,
192 |                 expert_y_dtype,
193 |                 out_tokens_ptr as *mut c_void,
194 |                 out_tokens_stride,
195 |                 indices_ptr as *const i32,
196 |                 indices_stride,
197 |                 weights_ptr as *const f32,
198 |                 weights_stride,
199 |                 bound_m_ptr.map(|ptr| ptr as *const i32).unwrap_or(null()),
200 |                 accumulate,
201 |                 stream,
202 |             )
203 |             .map_err(|e| PyRuntimeError::new_err(e.to_string()))
204 |     }
205 | }
206 | 
207 | pub fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
208 |     m.add_class::<PyAllToAllContext>()?;
209 |     Ok(())
210 | }
211 | 
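// Illustrative sketch (not from the original source): the call sequence these
// bindings are built around, shown with the Rust-side methods. All pointer and
// stride values are hypothetical; every call targets the same CUDA stream.
//
//     ctx.dispatch_send(num_tokens, x_ptr, x_stride, /* ... */ stream)?;
//     ctx.dispatch_recv(out_num_tokens_ptr, out_x_ptr, out_x_stride, /* ... */ stream)?;
//     // ... run the expert MLPs on the received tokens ...
//     ctx.combine_send(expert_y_ptr, expert_y_stride, stream)?;
//     ctx.combine_recv(num_tokens, num_recv_tokens, dtype, out_ptr, /* ... */ stream)?;
//
// Splitting each phase into send/recv halves lets the caller overlap the
// network transfers with unrelated GPU work between the two calls.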
--------------------------------------------------------------------------------
/p2p-all-to-all/a2a-kernels/src/core/device_utils.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include "core/common_utils.h"
4 | #include <math_constants.h>
5 | 
6 | #define ROSE_ENABLE_DEVICE_ASSERT 0
7 | 
8 | #if ROSE_ENABLE_DEVICE_ASSERT == 1
9 | #define ROSE_DEVICE_ASSERT(cond) \
10 |   do { \
11 |     if (!(cond)) { \
12 |       printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
13 |       asm("trap;"); \
14 |     } \
15 |   } while (0)
16 | #else
17 | #define ROSE_DEVICE_ASSERT(cond)
18 | #endif
19 | 
20 | namespace rose {
21 | namespace device {
22 | 
23 | // A kernel wrapper used to guard against compilation on
24 | // architectures that will never run the kernel.
25 | template <typename Kernel> struct enable_sm90_or_later : Kernel {
26 |   template <typename... Args> __device__ void operator()(Args &&...args) {
27 | #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
28 |     Kernel::operator()(std::forward<Args>(args)...);
29 | #endif
30 |   }
31 | };
32 | 
33 | __forceinline__ __device__ unsigned warp_sum(unsigned value) {
34 |   value += __shfl_xor_sync(0xffffffff, value, 16);
35 |   value += __shfl_xor_sync(0xffffffff, value, 8);
36 |   value += __shfl_xor_sync(0xffffffff, value, 4);
37 |   value += __shfl_xor_sync(0xffffffff, value, 2);
38 |   value += __shfl_xor_sync(0xffffffff, value, 1);
39 |   return value;
40 | }
41 | 
42 | __forceinline__ __device__ bool warp_and(bool value) {
43 |   value &= __shfl_xor_sync(0xffffffff, value, 16);
44 |   value &= __shfl_xor_sync(0xffffffff, value, 8);
45 |   value &= __shfl_xor_sync(0xffffffff, value, 4);
46 |   value &= __shfl_xor_sync(0xffffffff, value, 2);
47 |   value &= __shfl_xor_sync(0xffffffff, value, 1);
48 |   return value;
49 | }
50 | 
51 | __forceinline__ __device__ float half_warp_reduce_max(float value) {
52 |   auto mask = __activemask();
53 |   value = max(value, __shfl_xor_sync(mask, value, 8));
54 |   value = max(value, __shfl_xor_sync(mask, value, 4));
55 |   value = max(value, __shfl_xor_sync(mask, value, 2));
56 |   value = max(value, __shfl_xor_sync(mask, value, 1));
57 |   return value;
58 | }
59 | 
60 | __forceinline__ __device__ int get_lane_id() {
61 |   int lane_id;
62 |   asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
63 |   return lane_id;
64 | }
65 | 
66 | __forceinline__ __device__ uint32_t elect_one_sync() {
67 | #if __CUDA_ARCH__ >= 900
68 |   uint32_t pred = 0;
69 |   asm volatile(
70 |     "{\n"
71 |     ".reg .b32 %%rx;\n"
72 |     ".reg .pred %%px;\n"
73 |     " elect.sync %%rx|%%px, %1;\n"
74 |     "@%%px mov.s32 %0, 1;\n"
75 |     "}\n"
76 |     : "+r"(pred)
77 |     : "r"(0xffffffff));
78 |   return pred;
79 | #else
80 |   return get_lane_id() == 0;
81 | #endif
82 | }
83 | 
84 | __forceinline__ __device__ int last_active_lane(uint32_t mask) {
85 |   return mask ? (31 - __clz(mask)) : 0;
86 | }
87 | 
88 | __forceinline__ __device__ float warp_reduce_max(float val) {
89 | #pragma unroll
90 |   for (int offset = 16; offset > 0; offset >>= 1) {
91 |     val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
92 |   }
93 |   return val;
94 | }
95 | 
96 | __forceinline__ __device__ float warp_reduce_sum(float val) {
97 | #pragma unroll
98 |   for (int offset = 16; offset > 0; offset >>= 1) {
99 |     val += __shfl_down_sync(0xffffffff, val, offset);
100 |   }
101 |   return val;
102 | }
103 | 
104 | __device__ inline float block_reduce_max(float val, float* smem, int tid, int block_size) {
105 |   const int warp_id = tid / 32;
106 |   const int lane = tid & 31;
107 |   const int num_warps = (block_size + 31) / 32;
108 | 
109 |   val = warp_reduce_max(val);
110 | 
111 |   if (lane == 0) {
112 |     smem[warp_id] = val;
113 |   }
114 |   __syncthreads();
115 | 
116 |   if (warp_id == 0) {
117 |     val = (lane < num_warps) ? smem[lane] : -CUDART_INF_F;
118 |     val = warp_reduce_max(val);
119 |   }
120 | 
121 |   if (tid == 0) {
122 |     smem[0] = val;
123 |   }
124 |   __syncthreads();
125 | 
126 |   return smem[0];
127 | }
128 | 
129 | __device__ inline float block_reduce_sum(float val, float* smem, int tid, int block_size) {
130 |   const int warp_id = tid / 32;
131 |   const int lane = tid & 31;
132 |   const int num_warps = (block_size + 31) / 32;
133 | 
134 |   val = warp_reduce_sum(val);
135 | 
136 |   if (lane == 0) {
137 |     smem[warp_id] = val;
138 |   }
139 |   __syncthreads();
140 | 
141 |   if (warp_id == 0) {
142 |     val = (lane < num_warps) ?
smem[lane] : 0.0f; 143 | val = warp_reduce_sum(val); 144 | } 145 | 146 | if (tid == 0) { 147 | smem[0] = val; 148 | } 149 | __syncthreads(); 150 | 151 | return smem[0]; 152 | } 153 | 154 | __device__ inline void build_cdf_tiled( 155 | const float* probs, 156 | float* cdf, 157 | size_t vocab_size, 158 | float* smem_workspace, 159 | int tid, 160 | int block_size 161 | ) { 162 | const int num_warps = (block_size + 31) / 32; 163 | const int warp_id = tid / 32; 164 | const int lane = tid & 31; 165 | 166 | const size_t chunk_size = (vocab_size + num_warps - 1) / num_warps; 167 | const size_t chunk_start = static_cast(warp_id) * chunk_size; 168 | const size_t chunk_end = rose::min(chunk_start + chunk_size, vocab_size); 169 | 170 | float running_sum = 0.0f; 171 | 172 | for (size_t idx = chunk_start + lane; idx < chunk_end; idx += 32) { 173 | uint32_t mask = __activemask(); 174 | float warp_sum = probs[idx]; 175 | 176 | #pragma unroll 177 | for (int offset = 1; offset < 32; offset *= 2) { 178 | float n = __shfl_up_sync(mask, warp_sum, offset); 179 | if (lane >= offset) { 180 | warp_sum += n; 181 | } 182 | } 183 | 184 | warp_sum += running_sum; 185 | cdf[idx] = warp_sum; 186 | 187 | int last_lane = last_active_lane(mask); 188 | running_sum = __shfl_sync(mask, warp_sum, last_lane); 189 | } 190 | 191 | if (lane == 0) { 192 | smem_workspace[warp_id] = running_sum; 193 | } 194 | __syncthreads(); 195 | 196 | if (warp_id == 0) { 197 | uint32_t active_mask = __ballot_sync(0xffffffff, lane < num_warps); 198 | if (lane < num_warps) { 199 | float warp_total = smem_workspace[lane]; 200 | 201 | #pragma unroll 202 | for (int offset = 1; offset < 32; offset *= 2) { 203 | float n = __shfl_up_sync(active_mask, warp_total, offset); 204 | if (lane >= offset) { 205 | warp_total += n; 206 | } 207 | } 208 | 209 | smem_workspace[lane] = warp_total; 210 | } 211 | } 212 | __syncthreads(); 213 | 214 | if (warp_id > 0) { 215 | float prefix = smem_workspace[warp_id - 1]; 216 | for (size_t idx = chunk_start + lane; idx < chunk_end; idx += 32) { 217 | cdf[idx] += prefix; 218 | } 219 | } 220 | __syncthreads(); 221 | } 222 | 223 | __device__ inline int binary_search_cdf( 224 | const float* cdf, 225 | size_t vocab_size, 226 | float sample 227 | ) { 228 | int left = 0; 229 | int right = static_cast(vocab_size) - 1; 230 | 231 | while (left < right) { 232 | int mid = left + (right - left) / 2; 233 | 234 | if (cdf[mid] < sample) { 235 | left = mid + 1; 236 | } else { 237 | right = mid; 238 | } 239 | } 240 | 241 | return rose::min(rose::max(left, 0), static_cast(vocab_size) - 1); 242 | } 243 | 244 | } // namespace device 245 | } // namespace rose 246 | -------------------------------------------------------------------------------- /fabric-lib/src/api.rs: -------------------------------------------------------------------------------- 1 | //! 
Types used in public API
2 | 
3 | use std::{
4 |     ffi::c_void,
5 |     num::NonZeroU8,
6 |     ptr::NonNull,
7 |     sync::{
8 |         Arc,
9 |         atomic::{AtomicI64, Ordering},
10 |     },
11 | };
12 | 
13 | use bytes::Bytes;
14 | use cuda_lib::gdr::GdrFlag;
15 | use serde::{Deserialize, Serialize};
16 | 
17 | use crate::{
18 |     error::FabricLibError,
19 |     utils::hex::{fmt_hex, from_hex},
20 | };
21 | 
22 | pub type SmallVec<T> = ::smallvec::SmallVec<[T; 4]>;
23 | 
24 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
25 | #[repr(transparent)]
26 | pub struct MemoryRegionHandle {
27 |     pub ptr: NonNull<c_void>,
28 | }
29 | 
30 | impl MemoryRegionHandle {
31 |     pub fn new(ptr: NonNull<c_void>) -> Self {
32 |         MemoryRegionHandle { ptr }
33 |     }
34 | }
35 | 
36 | unsafe impl Send for MemoryRegionHandle {}
37 | unsafe impl Sync for MemoryRegionHandle {}
38 | 
39 | /// A remote key for a memory region.
40 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
41 | #[repr(transparent)]
42 | pub struct MemoryRegionRemoteKey(pub u64);
43 | 
44 | #[derive(Debug, Clone, Serialize, Deserialize)]
45 | pub struct MemoryRegionDescriptor {
46 |     pub ptr: u64,
47 |     pub addr_rkey_list: SmallVec<(DomainAddress, MemoryRegionRemoteKey)>,
48 | }
49 | 
50 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
51 | pub struct TransferId(pub u64);
52 | 
53 | #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
54 | pub struct DomainAddress(pub Bytes);
55 | 
56 | impl std::fmt::Debug for DomainAddress {
57 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 |         fmt_hex(f, &self.0)
59 |     }
60 | }
61 | 
62 | impl std::fmt::Display for DomainAddress {
63 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 |         fmt_hex(f, &self.0)
65 |     }
66 | }
67 | 
68 | impl std::str::FromStr for DomainAddress {
69 |     type Err = FabricLibError;
70 | 
71 |     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
72 |         if !s.len().is_multiple_of(2) || s.is_empty() {
73 |             return Err(FabricLibError::Custom("Invalid address length"));
74 |         }
75 |         Ok(Self(from_hex(s).map_err(|_| FabricLibError::Custom("Invalid address"))?))
76 |     }
77 | }
78 | 
79 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
80 | pub struct UvmWatcherId(pub(crate) NonNull<c_void>);
81 | unsafe impl Send for UvmWatcherId {}
82 | unsafe impl Sync for UvmWatcherId {}
83 | impl UvmWatcherId {
84 |     pub fn as_non_null(&self) -> NonNull<c_void> {
85 |         self.0
86 |     }
87 | 
88 |     pub fn as_u64(&self) -> u64 {
89 |         self.0.as_ptr() as u64
90 |     }
91 | }
92 | 
93 | /// Determines how to shard the transfer across domains.
94 | /// Relevant if the number of NICs per GPU is greater than 1.
95 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
96 | pub enum DomainGroupRouting {
97 |     /// Shard the transfer across `num_shards` domains.
98 |     /// Domains are selected in a round-robin manner.
99 |     RoundRobinSharded { num_shards: NonZeroU8 },
100 |     /// Send the transfer via a specific domain.
101 |     Pinned { domain_idx: u8 },
102 | }
103 | 
104 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
105 | pub enum GroupTransferRouting {
106 |     /// Use all domains. Each domain handles a subset of peers.
107 |     AllDomainsShardPeers,
108 |     /// Use all domains. Each domain handles all peers but a subset of bytes.
109 |     AllDomainsShardBytes,
110 |     /// Use a single domain of the given index.
111 |     Single { domain_idx: u8 },
112 | }
113 | 
114 | #[derive(Clone, Debug)]
115 | pub struct ImmTransferRequest {
116 |     pub imm_data: u32,
117 |     pub dst_mr: MemoryRegionDescriptor,
118 |     pub domain: DomainGroupRouting,
119 | }
120 | 
121 | #[derive(Clone, Debug)]
122 | pub struct BarrierTransferRequest {
123 |     pub imm_data: u32,
124 |     pub dst_mrs: Vec<MemoryRegionDescriptor>,
125 |     pub domain: DomainGroupRouting,
126 | }
127 | 
128 | #[derive(Clone, Debug)]
129 | pub struct SingleTransferRequest {
130 |     pub src_mr: MemoryRegionHandle,
131 |     pub src_offset: u64,
132 |     pub length: u64,
133 |     pub imm_data: Option<u32>,
134 |     pub dst_mr: MemoryRegionDescriptor,
135 |     pub dst_offset: u64,
136 |     pub domain: DomainGroupRouting,
137 | }
138 | 
139 | #[derive(Clone, Debug)]
140 | pub struct PagedTransferRequest {
141 |     pub length: u64,
142 |     pub src_mr: MemoryRegionHandle,
143 |     pub src_page_indices: Arc<Vec<u64>>,
144 |     pub src_stride: u64,
145 |     pub src_offset: u64,
146 |     pub dst_mr: MemoryRegionDescriptor,
147 |     pub dst_page_indices: Arc<Vec<u64>>,
148 |     pub dst_stride: u64,
149 |     pub dst_offset: u64,
150 |     pub imm_data: Option<u32>,
151 | }
152 | 
153 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
154 | pub struct PeerGroupHandle(pub(crate) u32);
155 | 
156 | #[derive(Clone, Debug)]
157 | pub struct ScatterTarget {
158 |     pub dst_mr: MemoryRegionDescriptor,
159 |     pub length: u64,
160 |     pub src_offset: u64,
161 |     pub dst_offset: u64,
162 | }
163 | 
164 | #[derive(Clone, Debug)]
165 | pub struct ScatterTransferRequest {
166 |     pub src_mr: MemoryRegionHandle,
167 |     /// Providing a PeerGroupHandle lowers the submission overhead.
168 |     /// When a PeerGroupHandle is provided, the order of `dsts` needs to
169 |     /// match the order of the peer group.
170 |     pub dst_handle: Option<PeerGroupHandle>,
171 |     pub dsts: Arc<Vec<ScatterTarget>>,
172 |     pub imm_data: Option<u32>,
173 |     pub domain: GroupTransferRouting,
174 | }
175 | 
176 | /// Use static dispatch for performance.
177 | #[derive(Clone)]
178 | pub enum TransferRequest {
179 |     Imm(ImmTransferRequest),
180 |     Single(SingleTransferRequest),
181 |     Paged(PagedTransferRequest),
182 |     Scatter(ScatterTransferRequest),
183 |     Barrier(BarrierTransferRequest),
184 | }
185 | 
186 | #[derive(Debug)]
187 | pub enum TransferCompletionEntry {
188 |     Recv { transfer_id: TransferId, data_len: usize },
189 |     Send(TransferId),
190 |     Transfer(TransferId),
191 |     ImmData(u32),
192 |     ImmCountReached(u32),
193 |     UvmWatch { id: UvmWatcherId, old: u64, new: u64 },
194 |     Error(TransferId, FabricLibError),
195 | }
196 | 
197 | /// A free-range immediate counter exposed to users.
198 | #[derive(Clone)]
199 | pub struct ImmCounter {
200 |     counter: Arc<AtomicI64>,
201 | }
202 | 
203 | impl ImmCounter {
204 |     pub fn new(counter: Arc<AtomicI64>) -> Self {
205 |         Self { counter }
206 |     }
207 | 
208 |     pub fn wait(&self, target: u32) {
209 |         let old = self.counter.fetch_sub(target as i64, Ordering::Relaxed);
210 |         if old >= target as i64 {
211 |             return;
212 |         }
213 |         while self.counter.load(Ordering::Relaxed) < 0 {
214 |             std::hint::spin_loop();
215 |         }
216 |     }
217 | }
218 | 
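// Illustrative sketch (not from the original source): how the counter types
// here are meant to be driven. The completion path calls `fetch_add(1)` on the
// shared atomic for every arriving immediate; a waiter then claims `target`
// arrivals (names below are illustrative):
//
//     let arrivals = Arc::new(AtomicI64::new(0));
//     let counter = ImmCounter::new(arrivals.clone());
//     // ... completion handler: arrivals.fetch_add(1, Ordering::Relaxed) ...
//     counter.wait(8); // returns once 8 immediates have been observed
//
// `wait` subtracts the target up front, so surplus arrivals carry over to the
// next call instead of being lost.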
219 | /// An immediate counter that sets a flag via GdrCopy.
220 | #[derive(Clone)]
221 | pub struct GdrCounter {
222 |     counter: Arc<AtomicI64>,
223 |     flag: Arc<GdrFlag>,
224 | }
225 | 
226 | impl GdrCounter {
227 |     pub fn new(counter: Arc<AtomicI64>, flag: Arc<GdrFlag>) -> Self {
228 |         Self { counter, flag }
229 |     }
230 | 
231 |     pub fn wait(&self, target: u32) {
232 |         let old = self.counter.fetch_sub(target as i64, Ordering::Relaxed);
233 |         if old >= target as i64 {
234 |             self.flag.set(true);
235 |         }
236 |     }
237 | }
238 | 
239 | /// Transfer counter exposing a pollable interface to transfer completion.
240 | pub struct TransferCounter {
241 |     counter: Arc<AtomicI64>,
242 |     err_counter: Arc<AtomicI64>,
243 | }
244 | 
245 | impl TransferCounter {
246 |     pub fn new(counter: Arc<AtomicI64>, err_counter: Arc<AtomicI64>) -> Self {
247 |         Self { counter, err_counter }
248 |     }
249 | 
250 |     pub(crate) fn error(&self) {
251 |         self.err_counter.fetch_add(1, Ordering::Release);
252 |         self.counter.fetch_add(1, Ordering::Release);
253 |     }
254 | 
255 |     pub(crate) fn done(&self) {
256 |         self.counter.fetch_add(1, Ordering::Release);
257 |     }
258 | }
259 | 
--------------------------------------------------------------------------------