├── docs ├── requirements.txt ├── tutorials.md ├── get_started.md ├── figs │ ├── tracer.png │ ├── timeline.png │ ├── comm-pattern.png │ └── comm-requiremnt.png ├── README.md ├── index.md ├── sphinx_util.py ├── efa.md ├── env.md ├── api.md ├── how_to.md ├── c++.md ├── backend.md └── python.md ├── fserver ├── __init__.py └── csrc │ ├── ops.cc │ ├── util.hpp │ ├── kernel.hpp │ ├── private.hpp │ └── wait_kernel.cu ├── tests ├── travis │ ├── travis_before_cache.sh │ ├── travis_setup_env.sh │ └── travis_script.sh ├── CMakeLists.txt ├── utests │ ├── ut_scheduler.cc │ ├── CMakeLists.txt │ ├── ut_server.cc │ ├── run_single_gpu_ut.sh │ ├── run_multi_gpu_ut.sh │ ├── ut_common.h │ ├── stepmesh_push_test.cc │ ├── ut_tensor_worker.cc │ ├── stepmesh_pull_test.cc │ ├── stepmesh_echo_test.cc │ ├── test_common.h │ └── stepmesh_register_test.cc ├── vllm │ ├── cycle.py │ ├── ffn.py │ └── attn.py ├── fserver │ ├── test_fserver.py │ ├── run_single_gpu.sh │ ├── test_utils.py │ ├── run_multi_gpu.sh │ ├── test_fserver_diff_stage.py │ └── test_fserver_dynamic.py └── run.sh ├── .clang-format ├── pyproject.toml ├── tools ├── check_diff.sh ├── format_code.sh ├── install_deps.sh └── cpplint.sh ├── src ├── backend │ ├── backend.cc │ └── cpu_backend.cc ├── windows │ └── unistd.h ├── meta.h ├── network_utils.cc ├── ibvwarp.h ├── van_common.h ├── customer.cc ├── resender.h └── fabric_utils.h ├── .gitignore ├── include ├── ps │ ├── internal │ │ ├── multi_qp.h │ │ ├── trace.h │ │ ├── cpu_backend.h │ │ ├── gpu_backend.h │ │ ├── parallel_sort.h │ │ ├── assign_op.h │ │ ├── env.h │ │ ├── backend.h │ │ ├── customer.h │ │ ├── parallel_kv_match.h │ │ ├── threadsafe_queue.h │ │ └── utils.h │ ├── range.h │ ├── base.h │ └── simple_app.h └── dmlc │ └── base.h ├── make ├── ps.mk └── deps.mk ├── cmake ├── Modules │ └── FindZMQ.cmake ├── External │ └── zmq.cmake └── ProtoBuf.cmake ├── tracker ├── README.md ├── dmlc_mpi.py ├── dmlc_local.py └── dmlc_ssh.py ├── .travis.yml ├── .github └── workflows │ └── main.yml ├── Makefile ├── CMakeLists.txt ├── setup.py └── README.md /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | breathe 2 | -------------------------------------------------------------------------------- /docs/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | -------------------------------------------------------------------------------- /docs/get_started.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | -------------------------------------------------------------------------------- /fserver/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import fserver_lib as f 5 | -------------------------------------------------------------------------------- /docs/figs/tracer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/StepMesh/HEAD/docs/figs/tracer.png -------------------------------------------------------------------------------- /docs/figs/timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/StepMesh/HEAD/docs/figs/timeline.png -------------------------------------------------------------------------------- /tests/travis/travis_before_cache.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # do nothing for now 3 | ls -alLR ${CACHE_PREFIX} -------------------------------------------------------------------------------- /docs/figs/comm-pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/StepMesh/HEAD/docs/figs/comm-pattern.png -------------------------------------------------------------------------------- /docs/figs/comm-requiremnt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/StepMesh/HEAD/docs/figs/comm-requiremnt.png -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | --- 4 | # To format your code, run `clang-format -i -style=file my-van.cc` 5 | Language: Cpp 6 | ColumnLimit: 80 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=40.8.0", 4 | "wheel", 5 | "torch>=2.0.0" 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -Wall -ldl -lpthread -g") 2 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -std=c++17 -ldl -lpthread -g") 3 | 4 | add_subdirectory(./utests) -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # StepMesh Docs 2 | 3 | - [Environment Variables](./env.md) 4 | - [Backends](./backend.md) 5 | - [APIs](./api.md) 6 | - [StepMesh C++ API](./c++.md) 7 | - [StepMesh Python API](./python.md) 8 | -------------------------------------------------------------------------------- /tests/utests/ut_scheduler.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Step AI 3 | */ 4 | 5 | #include "ut_common.h" 6 | 7 | int main(int argc, char *argv[]) { 8 | StartPS(0, Node::SCHEDULER, -1, true); 9 | Finalize(0, Node::SCHEDULER, true); 10 | return 0; 11 | } 12 | -------------------------------------------------------------------------------- /tools/check_diff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | bash tools/format_code.sh 5 | 6 | git diff 7 | diff_lines=$(git diff --stat|wc -l) 8 | if [[ $diff_lines -gt 0 ]]; then 9 | echo "Found diff: $diff_lines. run tools/format_code.sh to format your code first." 10 | exit 1 11 | else 12 | echo "No Diff" 13 | exit 0 14 | fi 15 | -------------------------------------------------------------------------------- /tools/format_code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | bin=clang-format 4 | 5 | set -ex 6 | 7 | find src -name '*.cc' -o -name '*.h' ! -name 'ucx_van.h' | xargs $bin -i -style=file 8 | find include -name '*.h' ! -name spsc_queue.h ! -name logging.h ! -name base.h ! 
-name parallel_*.h | xargs $bin -i -style=file 9 | 10 | echo "format code done" 11 | -------------------------------------------------------------------------------- /src/backend/backend.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 3 | */ 4 | 5 | #include "ps/internal/backend.h" 6 | 7 | #include <mutex> 8 | #include <unordered_map> 9 | 10 | namespace ps { 11 | 12 | std::mutex Backend::backends_mutex_; 13 | std::unordered_map Backend::backends_; 14 | 15 | } // namespace ps 16 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # PS-Lite Documents 2 | 3 | PS-Lite is a lightweight implementation of the parameter server. It provides 4 | asynchronous and zero-copy key-value pair communications between machines. 5 | 6 | 7 | ```eval_rst 8 | .. toctree:: 9 | :numbered: 10 | 11 | overview 12 | get_started 13 | tutorials 14 | how_to 15 | api 16 | ``` 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | .bundle 3 | bundle_lib 4 | deps 5 | *.pb.* 6 | .* 7 | *core* 8 | docs/html 9 | docs/xml 10 | docs/latex 11 | docs/_build 12 | *.pyc 13 | *.d 14 | 15 | cmake-build-debug/ 16 | CMakeFiles/ 17 | recommonmark/ 18 | 19 | test_benchmark 20 | stress_test_benchmark 21 | test 22 | .idea 23 | cmake_build 24 | 25 | *.so 26 | *egg-info 27 | dist/ 28 | dist/* -------------------------------------------------------------------------------- /fserver/csrc/ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2025, StepFun Authors. All rights reserved. */ 2 | 3 | 4 | #include "./util.hpp" 5 | 6 | #include "./public.hpp" 7 | #ifdef DMLC_USE_CUDA 8 | #include "./private.hpp" 9 | #include "./kernel.hpp" 10 | #endif 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | pybind_public(m); 14 | #ifdef DMLC_USE_CUDA 15 | pybind_private(m); 16 | pybind_kernel(m); 17 | #endif 18 | } 19 | -------------------------------------------------------------------------------- /include/ps/internal/multi_qp.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 
3 | */ 4 | #ifndef PS_INTERNAL_MULTI_QP_H_ 5 | #define PS_INTERNAL_MULTI_QP_H_ 6 | 7 | #include "ps/internal/utils.h" 8 | 9 | // Limited by RDMA Write Inline Size 10 | #define QP_MAX_NUM 2 11 | 12 | static int QP_NUM = ps::GetEnv("STEPAF_QP_NUM", 2);  // Number of QPs 13 | 14 | #define FOR_QPS for (int qpIndex = 0; qpIndex < QP_NUM; qpIndex++) 15 | 16 | #endif  // PS_INTERNAL_MULTI_QP_H_ 17 | -------------------------------------------------------------------------------- /make/ps.mk: -------------------------------------------------------------------------------- 1 | #--------------------------------------------------------------------------------------- 2 | # parameter server configuration script 3 | # 4 | # include ps.mk after the variables are set 5 | # 6 | #---------------------------------------------------------------------------------------- 7 | 8 | ifeq ($(USE_KEY32), 1) 9 | ADD_CFLAGS += -DUSE_KEY32=1 10 | endif 11 | 12 | PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lzmq 13 | PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libzmq.a) 14 | -------------------------------------------------------------------------------- /tests/travis/travis_setup_env.sh: -------------------------------------------------------------------------------- 1 | # script to be sourced in travis yml 2 | # set up all environment variables 3 | 4 | export CACHE_PREFIX=${HOME}/.cache/usr 5 | export PATH=${HOME}/.local/bin:${PATH} 6 | export PATH=${PATH}:${CACHE_PREFIX}/bin 7 | export CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:${CACHE_PREFIX}/include 8 | export C_INCLUDE_PATH=${C_INCLUDE_PATH}:${CACHE_PREFIX}/include 9 | export LIBRARY_PATH=${LIBRARY_PATH}:${CACHE_PREFIX}/lib 10 | export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CACHE_PREFIX}/lib 11 | 12 | # alias make="make -j4" 13 | --------------------------------------------------------------------------------
/include/ps/internal/trace.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 3 | */ 4 | #ifndef PS_INTERNAL_TRACE_H_ 5 | #define PS_INTERNAL_TRACE_H_ 6 | 7 | namespace ps { 8 | 9 | struct Trace { 10 | uint64_t pre_start = 0; 11 | // on start of a request/response 12 | uint64_t start = 0; 13 | // on rdma post send 14 | uint64_t postsend = 0; 15 | // on receiving a request/response 16 | uint64_t postrecv = 0; 17 | // on processed 18 | uint64_t process = 0; 19 | }; 20 | 21 | } // namespace ps 22 | 23 | #endif // PS_INTERNAL_TRACE_H_ 24 | -------------------------------------------------------------------------------- /include/ps/range.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef PS_RANGE_H_ 5 | #define PS_RANGE_H_ 6 | #include "ps/internal/utils.h" 7 | namespace ps { 8 | 9 | /** 10 | * \brief a range [begin, end) 11 | */ 12 | class Range { 13 | public: 14 | Range() : Range(0, 0) {} 15 | Range(uint64_t begin, uint64_t end) : begin_(begin), end_(end) {} 16 | 17 | uint64_t begin() const { return begin_; } 18 | uint64_t end() const { return end_; } 19 | uint64_t size() const { return end_ - begin_; } 20 | 21 | private: 22 | uint64_t begin_; 23 | uint64_t end_; 24 | }; 25 | 26 | } // namespace ps 27 | #endif // PS_RANGE_H_ 28 | -------------------------------------------------------------------------------- /docs/sphinx_util.py: -------------------------------------------------------------------------------- 1 | import sys, os, subprocess 2 | 3 | 4 | if not os.path.exists('../recommonmark'): 5 | subprocess.call('cd ..; git clone https://github.com/tqchen/recommonmark', shell = True) 6 | else: 7 | subprocess.call('cd ../recommonmark; git pull', shell=True) 8 | 9 | sys.path.insert(0, os.path.abspath('../recommonmark/')) 10 | 11 | from recommonmark import parser, transform 12 | MarkdownParser = parser.CommonMarkParser 13 | AutoStructify = transform.AutoStructify 14 | 15 | # MarkdownParser.github_doc_root = github_doc_root 16 | 17 | def generate_doxygen_xml(app): 18 | """Run the doxygen make commands""" 19 | subprocess.call('doxygen') 20 | -------------------------------------------------------------------------------- /cmake/Modules/FindZMQ.cmake: -------------------------------------------------------------------------------- 1 | # - Try to find ZMQ 2 | # Once done this will define 3 | # ZMQ_FOUND - System has ZMQ 4 | # ZMQ_INCLUDE_DIRS - The ZMQ include directories 5 | # ZMQ_LIBRARIES - The libraries needed to use ZMQ 6 | # ZMQ_DEFINITIONS - Compiler switches required for using ZMQ 7 | 8 | find_path ( ZMQ_INCLUDE_DIR zmq.h ) 9 | find_library ( ZMQ_LIBRARY NAMES zmq ) 10 | 11 | set ( ZMQ_LIBRARIES ${ZMQ_LIBRARY} ) 12 | set ( ZMQ_INCLUDE_DIRS ${ZMQ_INCLUDE_DIR} ) 13 | 14 | include ( FindPackageHandleStandardArgs ) 15 | # handle the QUIETLY and REQUIRED arguments and set ZMQ_FOUND to TRUE 16 | # if all listed variables are TRUE 17 | find_package_handle_standard_args ( ZMQ DEFAULT_MSG ZMQ_LIBRARY ZMQ_INCLUDE_DIR ) --------------------------------------------------------------------------------
/docs/efa.md: -------------------------------------------------------------------------------- 1 | ## Build with libfabric for Elastic Fabric Adapter (EFA) 2 | 3 | AMI: Base Deep Learning AMI (ubuntu/AML) 4 | 5 | 1. install gcc-4.9 first 6 | 7 | ``` 8 | set -e 9 | 10 | wget https://ftp.gnu.org/gnu/gcc/gcc-4.9.3/gcc-4.9.3.tar.gz 11 | tar xzf gcc-4.9.3.tar.gz 12 | cd gcc-4.9.3 13 | ./contrib/download_prerequisites 14 | ./configure --disable-multilib --enable-languages=c,c++ 15 | make -j$(nproc) 16 | sudo make install 17 | ``` 18 | 19 | 2. build ps-lite 20 | ``` 21 | make clean; USE_FABRIC=1 make -j; 22 | ``` 23 | 24 | 3. run tests 25 | ``` 26 | DMLC_INTERFACE=eth0 DMLC_PS_ROOT_URI=ROOT_IP DMLC_ENABLE_RDMA=fabric bash tests/local_multi_workers.sh 1 1 tests/test_benchmark 4096000 100 1 27 | ``` 28 | -------------------------------------------------------------------------------- /tools/install_deps.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get update 2 | sudo apt install -y build-essential libtool autoconf automake libnuma-dev unzip pkg-config librdmacm-dev rdma-core make cmake python3-pip 3 | 4 | THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )"/../ &> /dev/null && pwd )" 5 | 6 | rm -rf zeromq-4.1.4.tar.gz zeromq-4.1.4 7 | 8 | wget https://raw.githubusercontent.com/mli/deps/master/build/zeromq-4.1.4.tar.gz 9 | tar --no-same-owner -zxf zeromq-4.1.4.tar.gz 10 | pushd zeromq-4.1.4 || exit 11 | export CFLAGS=-fPIC 12 | export CXXFLAGS=-fPIC 13 | 14 | ./configure -prefix=${THIS_DIR}/deps/ --with-libsodium=no --with-libgssapi_krb5=no 15 | make -j 16 | make install 17 | popd || exit 18 | 19 | rm -rf zeromq-4.1.4.tar.gz zeromq-4.1.4 20 | -------------------------------------------------------------------------------- /tracker/README.md: -------------------------------------------------------------------------------- 1 | tracker 2 | ==== 3 | 4 | This folder contains tracker scripts that can be used to submit jobs to 5 | different distributed platforms. 6 | 7 | (Refactor of https://github.com/dmlc/dmlc-core/tree/master/tracker will be merged 8 | back when ready) 9 | 10 | ## How to use 11 | 12 | Assume `prog` is an executable program, and `[args...]` are the possible 13 | arguments that `prog` accepts. 14 | 15 | ### Local machine 16 | 17 | Run a job using 4 workers and 2 servers on the local machine: 18 | 19 | ```bash 20 | ./dmlc_local.py -s 2 -n 4 ./prog [args...] 21 | ``` 22 | 23 | ### Launch via `mpirun` 24 | 25 | If `mpirun` is available (shipped with `openmpi` or `mpich2`), we can launch a 26 | job using `dmlc_mpi.py`. 27 | --------------------------------------------------------------------------------
/fserver/csrc/util.hpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2025, StepFun Authors. All rights reserved. */ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "ps/ps.h" 20 | 21 | #ifndef UTIL_H_ 22 | #define UTIL_H_ 23 | 24 | typedef std::tuple, std::vector> 25 | ServerDataBatch; 26 | 27 | #endif  // UTIL_H_ 28 | -------------------------------------------------------------------------------- /fserver/csrc/kernel.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | 9 | torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index); 10 | void write_flag(torch::Tensor flag, torch::Tensor seq); 11 | void wait_flag(torch::Tensor flag, torch::Tensor seq); 12 | void seq_add_one(torch::Tensor seq); 13 | 14 | void pybind_kernel(py::module &m){ 15 | // StepMesh utils 16 | m.def("map_pinned_tensor", &map_pinned_tensor, py::arg("tensor"), py::arg("device_index")); 17 | m.def("write_flag", &write_flag, py::arg("flag"), py::arg("seq")); 18 | m.def("wait_flag", &wait_flag, py::arg("flag"), py::arg("seq")); 19 | m.def("seq_add_one", &seq_add_one, py::arg("seq")); 20 | } -------------------------------------------------------------------------------- /tests/travis/travis_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # main script of travis 3 | 4 | if [ ${TASK} == "lint" ]; then 5 | make lint || exit -1 6 | fi 7 | 8 | if [ ${TASK} == "build" ]; then 9 | make DEPS_PATH=${CACHE_PREFIX} CXX=${CXX} || exit -1 10 | fi 11 | 12 | if [ ${TASK} == "test" ]; then 13 | make test DEPS_PATH=${CACHE_PREFIX} CXX=${CXX} || exit -1 14 | cd tests 15 | # single-worker tests 16 | tests=( test_connection test_kv_app test_simple_app ) 17 | for test in "${tests[@]}" 18 | do 19 | find $test -type f -executable -exec ./repeat.sh 4 ./local.sh 2 2 ./{} \; 20 | done 21 | # multi-workers test 22 | multi_workers_tests=( test_kv_app_multi_workers ) 23 | for test in "${multi_workers_tests[@]}" 24 | do 25 | find $test -type f -executable -exec ./repeat.sh 4 ./local_multi_workers.sh 2 2 ./{} \; 26 | done 27 | fi 28 | -------------------------------------------------------------------------------- /include/ps/base.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef PS_BASE_H_ 5 | #define PS_BASE_H_ 6 | #include <limits> 7 | 8 | #include "ps/internal/utils.h" 9 | namespace ps { 10 | 11 | /*! \brief Use unsigned 64-bit int as the key type */ 12 | using Key = uint64_t; 
13 | /*! \brief The maximal allowed key value */ 14 | static const Key kMaxKey = std::numeric_limits<Key>::max(); 15 | /** \brief node ID for the scheduler */ 16 | static const int kScheduler = 1; 17 | /** 18 | * \brief the server node group ID 19 | * 20 | * group id can be combined: 21 | * - kServerGroup + kScheduler means all server nodes and the scheduler 22 | * - kServerGroup + kWorkerGroup means all server and worker nodes 23 | */ 24 | static const int kServerGroup = 2; 25 | /** \brief the worker node group ID */ 26 | static const int kWorkerGroup = 4; 27 | 28 | } // namespace ps 29 | #endif // PS_BASE_H_ 30 | -------------------------------------------------------------------------------- /include/ps/internal/cpu_backend.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 3 | */ 4 | 5 | #ifndef PS_INTERNAL_CPU_BACKEND_H_ 6 | #define PS_INTERNAL_CPU_BACKEND_H_ 7 | 8 | #include <ATen/ATen.h> 9 | 10 | #include "ps/internal/backend.h" 11 | 12 | namespace ps { 13 | 14 | class CpuBackend : public Backend { 15 | public: 16 | int SetDevice(int dev) override; 17 | int GetDeviceId() override; 18 | at::Device GetDevice() override; 19 | void* Alloc(uint64_t size) override; 20 | void Free(void* m) override; 21 | void* CreateEvent() override; 22 | int FreeEvent(void* event) override; 23 | int RecordEvent(void* event, void* stream) override; 24 | int SyncEvent(void* event) override; 25 | 26 | private: 27 | /** \brief for cpu backend, the device stands for numa id */ 28 | int numa_id_ = -1; 29 | }; 30 | 31 | } // namespace ps 32 | 33 | #endif // PS_INTERNAL_CPU_BACKEND_H_ 34 | -------------------------------------------------------------------------------- /tools/cpplint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cpplint --root=./src --filter=-whitespace/indent_namespace,-runtime/references \ 3 | --exclude=./src/zmq_van.h \ 4 | --exclude=./src/windows/* \ 5 | --exclude=./src/ucx_van.h \ 6 | --exclude=./src/tp_van.h \ 7 | --exclude=./src/resender.h \ 8 | --exclude=./src/multi_van.h \ 9 | --exclude=./src/fabric_van.h \ 10 | --exclude=./src/fabric_utils.h \ 11 | --exclude=./src/fabric_transport.h \ 12 | --recursive ./src 13 | cpplint --root=./include --filter=-whitespace/indent_namespace,-runtime/references \ 14 | --exclude=./include/ps/internal/spsc_queue.h \ 15 | --exclude=./include/ps/internal/parallel_sort.h \ 16 | --exclude=./include/ps/internal/parallel_kv_match.h \ 17 | --exclude=include/dmlc/* \ 18 | --recursive ./include 19 | cpplint --root=./fserver/csrc --filter=-whitespace/indent_namespace,-runtime/references \ 20 | --recursive ./fserver --------------------------------------------------------------------------------
3 | */ 4 | 5 | #include "ps/internal/cpu_backend.h" 6 | 7 | #include "ps/internal/backend.h" 8 | 9 | namespace ps { 10 | 11 | int CpuBackend::SetDevice(int dev) { 12 | PS_CHECK_GE(dev, 0) << "cannot set dev=" << dev << " for cpu backend"; 13 | numa_id_ = dev; 14 | return BACKEND_OK; 15 | } 16 | 17 | int CpuBackend::GetDeviceId() { return numa_id_; } 18 | 19 | at::Device CpuBackend::GetDevice() { return at::Device(at::kCPU); } 20 | 21 | void* CpuBackend::Alloc(uint64_t size) { return malloc(size); } 22 | 23 | void CpuBackend::Free(void* m) { 24 | PS_CHECK_NE(m, nullptr) << "cpu backend cannot free null memory"; 25 | free(m); 26 | } 27 | 28 | void* CpuBackend::CreateEvent() { return nullptr; } 29 | 30 | int CpuBackend::FreeEvent(void* event) { return BACKEND_OK; } 31 | 32 | int CpuBackend::RecordEvent(void* event, void* stream) { return BACKEND_OK; } 33 | 34 | int CpuBackend::SyncEvent(void* event) { return BACKEND_OK; } 35 | 36 | } // namespace ps 37 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # disable sudo to use container based build 2 | sudo: false 3 | 4 | # Use Build Matrix to do lint and build seperately 5 | env: 6 | matrix: 7 | - TASK=lint 8 | - TASK=test CXX=g++-4.8 9 | # - TASK=build CXX=g++-4.8 10 | # - TASK=build CXX=g++-5 11 | 12 | # dependent apt packages 13 | addons: 14 | apt: 15 | sources: 16 | - ubuntu-toolchain-r-test 17 | packages: 18 | - g++-4.8 19 | # - g++-5 20 | - wget 21 | - git 22 | # - unzip 23 | 24 | before_install: 25 | - export TRAVIS=tests/travis 26 | - source ${TRAVIS}/travis_setup_env.sh 27 | 28 | 29 | install: 30 | - pip install cpplint pylint --user `whoami` 31 | 32 | 33 | script: ${TRAVIS}/travis_script.sh 34 | 35 | 36 | before_cache: 37 | - ${TRAVIS}/travis_before_cache.sh 38 | 39 | cache: 40 | directories: 41 | - ${HOME}/.cache/usr 42 | 43 | 44 | notifications: 45 | # Emails are sent to the committer's git-configured email address by default, 46 | email: 47 | on_success: change 48 | on_failure: always 49 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-22.04 8 | # container: 9 | # image: pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel 10 | strategy: 11 | matrix: 12 | os: [ubuntu-22.04] 13 | python-version: ['3.10'] 14 | torch-version: [2.6.0] 15 | cuda-version: ['cu124'] 16 | 17 | steps: 18 | # Add a step to set up Python 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 # Use a more recent version of actions/setup-python 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Add dependencies 24 | run: sudo apt-get install -y clang-format 25 | - uses: actions/checkout@v4 # Use a more recent version of actions/checkout 26 | - name: Check Lint 27 | run: bash tools/check_diff.sh 28 | # - name: Install dependencies 29 | # run: bash tools/install_deps.sh 30 | # - name: Build af 31 | # run: make -j af 32 | # - name: Build pip 33 | # run: python3 setup.py build 34 | -------------------------------------------------------------------------------- /tests/vllm/cycle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | def get_cycles_per_ms() -> float: 4 | """Measure and 
return approximate number of cycles per millisecond for torch.cuda._sleep 5 | """ 6 | 7 | def measure() -> float: 8 | start = torch.cuda.Event(enable_timing=True) 9 | end = torch.cuda.Event(enable_timing=True) 10 | start.record() 11 | torch.cuda._sleep(1000000) 12 | end.record() 13 | end.synchronize() 14 | cycles_per_ms = 1000000 / start.elapsed_time(end) 15 | return cycles_per_ms 16 | 17 | # Get 10 values and remove the 2 max and 2 min and return the avg. 18 | # This is to avoid system disturbance that skew the results, e.g. 19 | # the very first cuda call likely does a bunch of init, which takes 20 | # much longer than subsequent calls. 21 | # 22 | # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs 23 | # and seems to return stable values. Therefore, we enable caching 24 | # using lru_cache decorator above. 25 | num = 10 26 | vals = [measure() for _ in range(num)] 27 | vals = sorted(vals) 28 | return np.mean(vals[2 : num - 2]) -------------------------------------------------------------------------------- /tests/utests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -Wall -ldl -lpthread -g") 2 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -std=c++17 -ldl -lpthread -g") 3 | 4 | add_executable(ut_scheduler 5 | ut_scheduler.cc 6 | ut_common.h) 7 | target_link_libraries(ut_scheduler 8 | af) 9 | 10 | add_executable(ut_server 11 | ut_server.cc 12 | ut_common.h) 13 | target_link_libraries(ut_server 14 | af) 15 | 16 | add_executable(ut_tensor_worker 17 | ut_tensor_worker.cc 18 | ut_common.h) 19 | target_link_libraries(ut_tensor_worker 20 | af) 21 | 22 | add_executable(stepmesh_echo_test stepmesh_echo_test.cc test_common.h) 23 | target_link_libraries(stepmesh_echo_test dl pthread m af) 24 | 25 | add_executable(stepmesh_push_test stepmesh_push_test.cc test_common.h) 26 | target_link_libraries(stepmesh_push_test dl pthread m af) 27 | 28 | add_executable(stepmesh_pull_test stepmesh_pull_test.cc test_common.h) 29 | target_link_libraries(stepmesh_pull_test dl pthread m af) 30 | 31 | add_executable(stepmesh_register_test stepmesh_register_test.cc test_common.h) 32 | target_link_libraries(stepmesh_register_test dl pthread m af) -------------------------------------------------------------------------------- /tests/fserver/test_fserver.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import time 3 | import fserver_lib as f 4 | is_worker = os.environ.get('DMLC_ROLE') == 'worker' 5 | is_server = os.environ.get('DMLC_ROLE') == 'server' 6 | 7 | f.init() 8 | 9 | if is_worker: 10 | gpu = os.environ.get('STEPMESH_GPU') 11 | push_tensors = [ 12 | torch.rand([1, 8192], dtype=torch.float32, device=f'cuda:{gpu}'), 13 | torch.rand([1, 8192], dtype=torch.float32, device=f'cuda:{gpu}'), 14 | torch.rand([1, 8192], dtype=torch.float32, device=f'cuda:{gpu}'), 15 | ] 16 | pull_tensors = [ 17 | torch.rand([1, 8192], dtype=torch.float32, device=f'cuda:{gpu}') 18 | ] 19 | handler = f.push_pull( 20 | push_tensors, 21 | [i for i in range(len(push_tensors))], 22 | pull_tensors, 23 | [i for i in range(len(pull_tensors))] 24 | ) 25 | f.wait(handler) 26 | assert torch.allclose(sum(push_tensors), pull_tensors[0]) 27 | print("worker test done") 28 | 29 | elif is_server: 30 | gpu = os.environ.get('STEPMESH_GPU') 31 | torch.set_default_device('cuda:{}'.format(gpu)) 32 | res = [] 33 | while len(res) == 0: 34 | time.sleep(1) 35 | res = f.get_batch() 36 | 
print(res) 37 | for r in res: 38 | comm_id, batch, _ = r 39 | f.respond([sum(batch)], comm_id) 40 | 41 | f.stop() 42 | -------------------------------------------------------------------------------- /tests/utests/ut_server.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Step AI 3 | */ 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include "ut_common.h" 12 | 13 | std::vector g_tensors; 14 | 15 | void SumHandler(const AFTensorMeta& req_meta, AFTensorServer* server) { 16 | auto sum = at::zeros_like(req_meta.push_tensors[0].val, 17 | req_meta.push_tensors[0].val.scalar_type()); 18 | for (auto t : req_meta.push_tensors) { 19 | sum += t.val; 20 | } 21 | g_tensors.push_back(sum); 22 | server->Response(req_meta, {{ req_meta.pull_tensors[0].key, sum }}); 23 | } 24 | 25 | void StartFFNServer() { 26 | PS_LOG(INFO) << "run server over gpu " << g_conf.gpu; 27 | StartPS(0, Node::SERVER, g_conf.gpu, true); 28 | AFTensorServer* server = new AFTensorServer(g_conf.gpu); 29 | server->SetRequestHandle(SumHandler); 30 | ps::Postoffice::GetServer(g_conf.gpu)->Barrier( 31 | 0, ps::kServerGroup + ps::kWorkerGroup); 32 | RegisterExitCallback([server]() { delete server; }); 33 | Finalize(0, Node::SERVER, true); 34 | PS_LOG(INFO) << "FFN server ends"; 35 | } 36 | 37 | int main(int argc, char *argv[]) { 38 | InitUtestConfig(); 39 | PS_LOG(INFO) << "AF-Communication utest server: gpu=" 40 | << g_conf.gpu; 41 | StartFFNServer(); 42 | return 0; 43 | } 44 | -------------------------------------------------------------------------------- /tests/utests/run_single_gpu_ut.sh: -------------------------------------------------------------------------------- 1 | export SCHEDULER_BIN=${SCHEDULER_BIN:-./cmake_build/tests/utests/ut_scheduler} 2 | export SERVER_BIN=${SERVER_BIN:-./cmake_build/tests/utests/ut_server} 3 | export WORKER_BIN=${WORKER_BIN:-./cmake_build/tests/utests/ut_tensor_worker} 4 | 5 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 6 | 7 | cmd=$1 8 | echo $cmd 9 | set -x 10 | 11 | function cleanup() { 12 | echo "kill all testing process of ps lite for user $USER" 13 | pkill -9 -f $SERVER_BIN 14 | pkill -9 -f $SCHEDULER_BIN 15 | pkill -9 -f $WORKER_BIN 16 | sleep 1 17 | } 18 | trap cleanup EXIT 19 | 20 | export DMLC_NUM_WORKER=${DMLC_NUM_WORKER:-1} 21 | export DMLC_NUM_SERVER=1 22 | export DMLC_INTERFACE=brainpf_bond0 # my RDMA interface 23 | export DMLC_PS_ROOT_URI=$(ip -o -4 addr | grep ${DMLC_INTERFACE} | awk '{print $4}' | cut -d'/' -f1) 24 | export DMLC_PS_ROOT_PORT=${DMLC_PS_ROOT_PORT:-12278} # scheduler's port (can random choose) 25 | export STEPMESH_SPLIT_QP_LAG=1 26 | export DMLC_NODE_RANK=0 27 | export DMLC_ENABLE_RDMA=ibverbs 28 | 29 | echo "SCHEDULER_IP is ${DMLC_PS_ROOT_URI}" 30 | 31 | # # launch scheduler 32 | export DMLC_NODE_HOST=${DMLC_PS_ROOT_URI} 33 | 34 | cleanup 35 | DMLC_ROLE=scheduler $SCHEDULER_BIN & 36 | 37 | export STEPMESH_GPU=0 38 | DMLC_ROLE=server $SERVER_BIN & 39 | 40 | sleep 1 41 | 42 | export STEPMESH_GPU=0 43 | export DMLC_INTERFACE=auto 44 | 45 | DMLC_ROLE=worker $WORKER_BIN 46 | -------------------------------------------------------------------------------- /tests/fserver/run_single_gpu.sh: -------------------------------------------------------------------------------- 1 | THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 2 | function cleanup() { 3 | echo "kill all testing process of ps lite for user $USER" 4 | # pkill -9 -f test_bench 5 | 
pkill -9 -f test_remote_moe 6 | pkill -9 -f test_fserver 7 | sleep 1 8 | } 9 | trap cleanup EXIT 10 | # cleanup 11 | 12 | # common setup 13 | export BIN=${BIN:-test_fserver} 14 | # export DMLC_INTERFACE=${RNIC:-brainpf_bond0} 15 | export SCHEDULER_IP=$(ip -o -4 addr | grep ${RNIC} | awk '{print $4}' | cut -d'/' -f1) 16 | export DMLC_NUM_WORKER=1 17 | export DMLC_NUM_SERVER=1 18 | export DMLC_PS_ROOT_URI=$SCHEDULER_IP # scheduler's RDMA interface IP 19 | export DMLC_PS_ROOT_PORT=8123 # scheduler's port (can random choose) 20 | export DMLC_ENABLE_RDMA=ibverbs 21 | export DMLC_INTERFACE=auto 22 | # export STEPMESH_BIND_CPU_CORE=1 23 | 24 | export DMLC_NODE_HOST=${SCHEDULER_IP} 25 | export DMLC_INTERFACE=auto 26 | export STEPMESH_SPLIT_QP_LAG=0 27 | export STEPMESH_BIND_CPU_CORE=1 28 | export STEPMESH_GPU=0 29 | export PS_VERBOSE=1 30 | 31 | DMLC_ROLE=scheduler numactl -m 0 python3 $THIS_DIR/$BIN.py & 32 | export STEPMESH_CPU_START_OFFSET=10 33 | DMLC_ROLE=server numactl -m 0 python3 $THIS_DIR/$BIN.py $@ & 34 | # DMLC_ROLE=worker python3 $THIS_DIR/$BIN.py $@ & 35 | # export STEPMESH_DROP_RATE=1 36 | export STEPMESH_CPU_START_OFFSET=15 37 | DMLC_ROLE=worker numactl -m 0 python3 $THIS_DIR/$BIN.py $@ 38 | 39 | wait 40 | -------------------------------------------------------------------------------- /tests/utests/run_multi_gpu_ut.sh: -------------------------------------------------------------------------------- 1 | export SCHEDULER_BIN=${SCHEDULER_BIN:-./cmake_build/tests/utests/ut_scheduler} 2 | export SERVER_BIN=${SERVER_BIN:-./cmake_build/tests/utests/ut_server} 3 | export WORKER_BIN=${WORKER_BIN:-./cmake_build/tests/utests/ut_tensor_worker} 4 | 5 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 6 | 7 | cmd=$1 8 | echo $cmd 9 | set -x 10 | 11 | function cleanup() { 12 | echo "kill all testing process of ps lite for user $USER" 13 | pkill -9 -f $SERVER_BIN 14 | pkill -9 -f $SCHEDULER_BIN 15 | pkill -9 -f $WORKER_BIN 16 | sleep 1 17 | } 18 | trap cleanup EXIT 19 | 20 | export DMLC_NUM_WORKER=${DMLC_NUM_WORKER:-1} 21 | export DMLC_NUM_SERVER=1 22 | export DMLC_INTERFACE=brainpf_bond0 # my RDMA interface 23 | export DMLC_PS_ROOT_URI=$(ip -o -4 addr | grep ${DMLC_INTERFACE} | awk '{print $4}' | cut -d'/' -f1) 24 | export DMLC_PS_ROOT_PORT=${DMLC_PS_ROOT_PORT:-12278} # scheduler's port (can random choose) 25 | export STEPMESH_SPLIT_QP_LAG=0 26 | export DMLC_ENABLE_RDMA=ibverbs 27 | export DMLC_GROUP_SIZE=8 28 | export DMLC_NODE_RANK=0 29 | 30 | echo "SCHEDULER_IP is ${DMLC_PS_ROOT_URI}" 31 | 32 | # # launch scheduler 33 | export DMLC_NODE_HOST=${DMLC_PS_ROOT_URI} 34 | 35 | cleanup 36 | 37 | DMLC_ROLE=scheduler $SCHEDULER_BIN & 38 | 39 | export DMLC_INTERFACE=auto 40 | for P in {0..7}; do 41 | DMLC_ROLE=server STEPMESH_GPU=${P} $SERVER_BIN & 42 | done 43 | 44 | sleep 1 45 | 46 | for P in {0..7}; do 47 | DMLC_ROLE=worker STEPMESH_GPU=${P} $WORKER_BIN & 48 | done 49 | 50 | wait 51 | -------------------------------------------------------------------------------- /tests/run.sh: -------------------------------------------------------------------------------- 1 | export BINARY=${BINARY:-./cmake_build/tests/stepmesh_echo_test} 2 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 3 | 4 | function cleanup() { 5 | echo "kill all testing process of ps lite for user $USER" 6 | pkill -9 -f $BINARY 7 | sleep 1 8 | } 9 | trap cleanup EXIT 10 | cleanup 11 | 12 | export DMLC_NUM_WORKER=1 13 | export DMLC_NUM_SERVER=1 14 | export STEPMESH_SPLIT_QP_LAG=1 15 | export 
DMLC_INTERFACE=${RNIC:-brainpf_bond0} 16 | export SCHEDULER_IP=`ip addr show $DMLC_INTERFACE | awk '/inet / {print $2}' | cut -d/ -f1` 17 | export DMLC_PS_ROOT_URI=${SCHEDULER_IP} 18 | export DMLC_PS_ROOT_PORT=${DMLC_PS_ROOT_PORT:-12278} 19 | export DMLC_ENABLE_RDMA=ibverbs 20 | #export PS_VERBOSE=3 21 | ROLE=${ROLE:-server} 22 | if [ $ROLE == "server" ]; then 23 | echo "Run server and scheduler, scheduler ip $SCHEDULER_IP " 24 | export DMLC_NODE_HOST=${SCHEDULER_IP} 25 | DMLC_ROLE=scheduler $BINARY & 26 | export DMLC_INTERFACE=auto 27 | DMLC_ROLE=server STEPMESH_GPU=0 $BINARY 28 | elif [ $ROLE == "worker" ]; then 29 | echo "Run worker with scheduler ip: $1" 30 | export DMLC_PS_ROOT_URI=$1 31 | export DMLC_INTERFACE=auto 32 | DMLC_ROLE=worker STEPMESH_GPU=0 $BINARY 33 | elif [ $ROLE == "joint" ]; then 34 | echo "Run scheduler, server, and worker jointly" 35 | export DMLC_NODE_HOST=${SCHEDULER_IP} 36 | export DMLC_PS_ROOT_URI=$SCHEDULER_IP 37 | DMLC_ROLE=scheduler $BINARY & 38 | export DMLC_INTERFACE=auto 39 | DMLC_ROLE=server STEPMESH_GPU=0 $BINARY & 40 | sleep 10 41 | export DMLC_INTERFACE=auto 42 | DMLC_ROLE=worker STEPMESH_GPU=0 $BINARY 43 | fi 44 | -------------------------------------------------------------------------------- /src/windows/unistd.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * @file unistd.h 4 | * @brief This file intended to serve as a drop-in replacement for 5 | * unistd.h on Windows. Please add functionality as needed 6 | */ 7 | #ifndef PS_WINDOWS_UNISTD_H_ 8 | #define PS_WINDOWS_UNISTD_H_ 9 | #include 10 | #include 11 | // #include "getopt.h" /* getopt at: https://gist.github.com/ashelly/7776712 */ 12 | #include <direct.h> /* for _getcwd() and _chdir() */ 13 | #include <process.h> /* for getpid() and the exec..() family */ 14 | 15 | #define srandom srand 16 | #define random rand 17 | 18 | /* Values for the second argument to access. 19 | These may be OR'd together. */ 20 | #define R_OK 4 /* Test for read permission. */ 21 | #define W_OK 2 /* Test for write permission. */ 22 | // #define X_OK 1 /* execute permission - unsupported in windows*/ 23 | #define F_OK 0 /* Test for existence. */ 24 | 25 | #define access _access 26 | #define dup2 _dup2 27 | #define execve _execve 28 | #define ftruncate _chsize 29 | #define unlink _unlink 30 | #define fileno _fileno 31 | #define getcwd _getcwd 32 | #define chdir _chdir 33 | #define isatty _isatty 34 | #define lseek _lseek 35 | 36 | // read, write, and close are NOT being #defined here, because while there are 37 | // file handle specific versions for Windows, they probably don't work for 38 | // sockets. You need to look at your app and consider whether to call 39 | // e.g. closesocket(). 40 | 41 | #define ssize_t int 42 | 43 | #define STDIN_FILENO 0 44 | #define STDOUT_FILENO 1 45 | #define STDERR_FILENO 2 46 | /* should be in some equivalent to <sys/types.h> */ 47 | 48 | #endif // PS_WINDOWS_UNISTD_H_ 49 | --------------------------------------------------------------------------------
/include/ps/internal/gpu_backend.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 3 | */ 4 | 5 | #ifndef PS_INTERNAL_GPU_BACKEND_H_ 6 | #define PS_INTERNAL_GPU_BACKEND_H_ 7 | 8 | #include <ATen/ATen.h> 9 | 10 | #include "ps/internal/backend.h" 11 | 12 | namespace ps { 13 | 14 | /** 15 | * \brief Nvidia GPU Backend Class 16 | * the device for GPU Backend stands for gpu index 17 | */ 18 | class GpuBackend : public Backend { 19 | public: 20 | GpuBackend(); 21 | int SetDevice(int dev) override; 22 | int GetDeviceId() override; 23 | at::Device GetDevice() override; 24 | void* Alloc(uint64_t size) override; 25 | void Free(void* m) override; 26 | void* CreateEvent() override; 27 | int FreeEvent(void* event) override; 28 | int RecordEvent(void* event, void* stream) override; 29 | int SyncEvent(void* event) override; 30 | 31 | private: 32 | void* CreateCudaEvent(); 33 | int FreeCudaEvent(void* event); 34 | int RecordCudaEvent(void* event, void* stream); 35 | int SyncCudaEvent(void* event); 36 | 37 | void* CreateMemEvent(); 38 | int FreeMemEvent(void* event); 39 | int RecordMemEvent(void* event, void* stream); 40 | int SyncMemEvent(void* event); 41 | 42 | private: 43 | inline void DoInitGpu() { 44 | static thread_local int gpu_idx = -1; 45 | if (gpu_idx == -1) { 46 | PS_CHECK_GE(gpu_idx_, 0) 47 | << "cannot set device " << gpu_idx_ << " for gpu backend"; 48 | SetDevice(gpu_idx_); 49 | gpu_idx = gpu_idx_; 50 | } 51 | } 52 | 53 | /** \brief for gpu backend, the device stands for gpu index */ 54 | int gpu_idx_ = -1; 55 | int mem_sync_ = 1; 56 | }; 57 | 58 | } // namespace ps 59 | 60 | #endif // PS_INTERNAL_GPU_BACKEND_H_ 61 | -------------------------------------------------------------------------------- /tests/fserver/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def gen_push_key(private_key, microbatch=0, worker_rank=-1): 4 | """ 5 | Generate a key for push tensors based on microbatch, worker rank, and private key. 6 | :param private_key: your own key, ranging from 0-255, can be used to identify different tensors 7 | :param microbatch: microbatch id 8 | :param worker_rank: current worker rank; if -1, it is retrieved from the environment 9 | :return: the key for fserver 10 | """ 11 | assert 0 <= private_key < 256, f"illegal private key: {private_key}" 12 | if worker_rank == -1: 13 | if "DMLC_NODE_RANK" in os.environ: 14 | worker_rank = int(os.environ["DMLC_NODE_RANK"]) 15 | else: 16 | worker_rank = 0 17 | return private_key + microbatch * (1 << 8) + worker_rank * (1 << 16) 18 | 19 | 20 | def gen_pull_key(private_key, microbatch=0, worker_rank=-1): 21 | """ 22 | Generate a key for pull tensors based on microbatch, worker rank, and private key. 23 | :param private_key: your own key, ranging from 0-255, can be used to identify different tensors 24 | :param microbatch: microbatch id 25 | :param worker_rank: current worker rank; if -1, it is retrieved from the environment 26 | :return: the key for fserver 27 | """ 28 | assert 0 <= private_key < 256, f"illegal private key: {private_key}" 29 | if worker_rank == -1: 30 | if "DMLC_NODE_RANK" in os.environ: 31 | worker_rank = int(os.environ["DMLC_NODE_RANK"]) 32 | else: 33 | worker_rank = 0 34 | return private_key + microbatch * (1 << 8) + worker_rank * (1 << 16) + (1 << 24) 35 | 36 | def get_worker_rank(key : int): 37 | return (key % (1 << 24)) // (1 << 16)
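The helpers above pack three fields into a single integer key: bits 0-7 carry the private key, bits 8-15 the microbatch id, bits 16-23 the worker rank, and bit 24 flags a pull key. A short worked example of this layout, assuming the module is importable as `test_utils` (a sketch, not part of the original file):

```python
# Worked example of the fserver key layout (illustrative values).
from test_utils import gen_push_key, gen_pull_key, get_worker_rank

push_key = gen_push_key(5, microbatch=2, worker_rank=3)
# 5 + 2 * (1 << 8) + 3 * (1 << 16) = 5 + 512 + 196608
assert push_key == 197125

# a pull key uses the same layout plus the pull flag in bit 24
pull_key = gen_pull_key(5, microbatch=2, worker_rank=3)
assert pull_key == 197125 + (1 << 24)

# the worker rank can be recovered from either kind of key
assert get_worker_rank(push_key) == 3
assert get_worker_rank(pull_key) == 3
```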
-------------------------------------------------------------------------------- /include/ps/internal/parallel_sort.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * @file parallel_sort.h 4 | * @brief Parallel sort 5 | */ 6 | #ifndef PS_INTERNAL_PARALLEL_SORT_H_ 7 | #define PS_INTERNAL_PARALLEL_SORT_H_ 8 | #include <algorithm> 9 | #include <thread> 10 | #include <vector> 11 | #include "ps/sarray.h" 12 | namespace ps { 13 | 14 | // NOLINT 15 | namespace { 16 | /** 17 | * \brief the thread function 18 | * 19 | * \param data start pointer of data 20 | * \param len length of data 21 | * \param grainsize max data length of one thread 22 | * \param cmp comparison function 23 | */ 24 | template <typename T, class Fn> 25 | void ParallelSort(T* data, size_t len, size_t grainsize, const Fn& cmp) { 26 | if (len <= grainsize) { 27 | std::sort(data, data + len, cmp); 28 | } else { 29 | std::thread thr(ParallelSort<T, Fn>, data, len/2, grainsize, cmp); 30 | ParallelSort(data + len/2, len - len/2, grainsize, cmp); 31 | thr.join(); 32 | 33 | std::inplace_merge(data, data + len/2, data + len, cmp); 34 | } 35 | } 36 | } // namespace 37 | 38 | /** 39 | * \brief Parallel Sort 40 | * 41 | * \param arr the array for sorting 42 | * \param num_threads number of thread 43 | * \param cmp the comparison function such as 44 | * [](const T& a, const T& b) { return a < b; } 45 | * or an even simpler version: 46 | * std::less<T>() 47 | */ 48 | template <typename T, class Fn> 49 | void ParallelSort(SArray<T>* arr, 50 | int num_threads = 2, 51 | const Fn& cmp = std::less<T>()) { 52 | PS_CHECK_GT(num_threads, 0); 53 | PS_CHECK(cmp); 54 | size_t grainsize = std::max<size_t>(arr->size() / num_threads + 5, 1024*16); 55 | ParallelSort(arr->data(), arr->size(), grainsize, cmp); 56 | } 57 | 58 | } // namespace ps 59 | #endif // PS_INTERNAL_PARALLEL_SORT_H_ 60 | -------------------------------------------------------------------------------- /docs/env.md: -------------------------------------------------------------------------------- 1 | # Environment Variables 2 | 3 | ## BytePS Environment Variables 4 | The following variables must be set at startup: 5 | 6 | - `DMLC_NUM_WORKER` : The number of workers 7 | - `DMLC_NUM_SERVER` : The number of servers 8 | - `DMLC_GROUP_SIZE` : The number of processes per worker or server 9 | - `DMLC_NODE_RANK` : The node rank for servers and workers 10 | - `DMLC_ROLE` : The role of the current node, can be `worker`, `server`, or `scheduler` 11 | - `DMLC_PS_ROOT_URI` : The IP or hostname of the scheduler node 12 | - `DMLC_PS_ROOT_PORT` : The port that the scheduler node is listening on 13 | - `DMLC_ENABLE_RDMA` : Set to enable the RDMA Van 14 | 15 | Additional variables: 16 | 17 | - `DMLC_INTERFACE` : The network interface a node should use. 18 | `auto` can be used for automatic detection of mappings between RDMA NIC ports and GPU cards. 
19 | - `DMLC_LOCAL` : Run on a local machine; no network is needed 20 | 21 | ## StepMesh Environment Variables 22 | 23 | Besides the above, StepMesh introduces some independent environment variables. 24 | 25 | - `STEPMESH_BACKEND` : The backend to be used; 26 | currently only the `CPU` backend and the `GPU` (default) backend are supported. 27 | We are working on supporting more backends. 28 | - `STEPMESH_GPU` : Set the device id used by the GPU backend 29 | - `STEPMESH_SPLIT_QP_LAG` : Enable QP traffic balancing over bonded RDMA NICs; STEPMESH_SPLIT_QP_LAG=0 by default 30 | - `STEPMESH_MEM_SYNC` : Synchronize GPU kernels via CPU memory instead of cudaEvent; 31 | STEPMESH_MEM_SYNC=1 by default 32 | - `STEPMESH_BIND_CPU_CORE` : Enable CPU core binding; STEPMESH_BIND_CPU_CORE=1 by default 33 | - `STEPMESH_CPU_CORES_PER_SOCKET` : The number of CPU cores per CPU socket; default is 48. 34 | - `STEPMESH_CPU_CORES_PER_GPU` : The number of CPU cores used by each GPU; default is 5. 35 | - `STEPMESH_CPU_START_OFFSET` : The index of the first CPU core used by StepMesh. 36 |
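To make these variables concrete, here is a minimal sketch of configuring a single-machine, one-worker/one-server run from Python before initializing the library; the values mirror `tests/fserver/run_single_gpu.sh` and are illustrative, not required:

```python
# Minimal sketch: wire up the DMLC_*/STEPMESH_* variables for a local run
# (values mirror tests/fserver/run_single_gpu.sh and are illustrative).
import os

os.environ.update({
    "DMLC_ROLE": "worker",            # or "server" / "scheduler"
    "DMLC_NUM_WORKER": "1",
    "DMLC_NUM_SERVER": "1",
    "DMLC_PS_ROOT_URI": "127.0.0.1",  # scheduler's RDMA interface IP
    "DMLC_PS_ROOT_PORT": "8123",      # any free port works
    "DMLC_ENABLE_RDMA": "ibverbs",
    "DMLC_INTERFACE": "auto",         # auto-detect the NIC/GPU mapping
    "STEPMESH_GPU": "0",              # device id for the GPU backend
})

import fserver_lib as f  # the environment must be in place before init()
f.init()
```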
-------------------------------------------------------------------------------- /include/ps/internal/assign_op.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * \file assign_op.h 4 | * \brief assignment operator 5 | * http://en.cppreference.com/w/cpp/language/operator_assignment 6 | */ 7 | #ifndef PS_INTERNAL_ASSIGN_OP_H_ 8 | #define PS_INTERNAL_ASSIGN_OP_H_ 9 | #include "ps/internal/utils.h" 10 | namespace ps { 11 | 12 | enum AssignOp { 13 | ASSIGN, // a = b 14 | PLUS, // a += b 15 | MINUS, // a -= b 16 | TIMES, // a *= b 17 | DIVIDE, // a /= b 18 | AND, // a &= b 19 | OR, // a |= b 20 | XOR // a ^= b 21 | }; 22 | 23 | /** 24 | * \brief return an assignment function: right op= left 25 | */ 26 | template <typename T> 27 | inline void AssignFunc(const T& lhs, AssignOp op, T* rhs) { 28 | switch (op) { 29 | case ASSIGN: 30 | *rhs = lhs; 31 | break; 32 | case PLUS: 33 | *rhs += lhs; 34 | break; 35 | case MINUS: 36 | *rhs -= lhs; 37 | break; 38 | case TIMES: 39 | *rhs *= lhs; 40 | break; 41 | case DIVIDE: 42 | *rhs /= lhs; 43 | break; 44 | default: 45 | PS_LOG(FATAL) << "use AssignOpInt.."; 46 | } 47 | } 48 | 49 | /** 50 | * \brief return an assignment function including bit operations, only 51 | * works for integers 52 | */ 53 | template <typename T> 54 | inline void AssignFuncInt(const T& lhs, AssignOp op, T* rhs) { 55 | switch (op) { 56 | case ASSIGN: 57 | *rhs = lhs; 58 | break; 59 | case PLUS: 60 | *rhs += lhs; 61 | break; 62 | case MINUS: 63 | *rhs -= lhs; 64 | break; 65 | case TIMES: 66 | *rhs *= lhs; 67 | break; 68 | case DIVIDE: 69 | *rhs /= lhs; 70 | break; 71 | case AND: 72 | *rhs &= lhs; 73 | break; 74 | case OR: 75 | *rhs |= lhs; 76 | break; 77 | case XOR: 78 | *rhs ^= lhs; 79 | break; 80 | } 81 | } 82 | 83 | } // namespace ps 84 | #endif // PS_INTERNAL_ASSIGN_OP_H_ 85 | -------------------------------------------------------------------------------- /make/deps.mk: -------------------------------------------------------------------------------- 1 | # Install dependencies 2 | 3 | URL1=https://raw.githubusercontent.com/mli/deps/master/build 4 | ifndef WGET 5 | WGET = wget 6 | endif 7 | 8 | # zmq 9 | ZMQ = ${DEPS_PATH}/include/zmq.h 10 | 11 | ifndef ZMQ_URL 12 | ZMQ_URL = $(URL1)/$(FILE) 13 | endif 14 | 15 | ${ZMQ}: 16 | $(eval FILE=zeromq-4.1.4.tar.gz) 17 | $(eval DIR=zeromq-4.1.4) 18 | rm -rf $(FILE) $(DIR) 19 | $(WGET) $(ZMQ_URL) && tar --no-same-owner -zxf $(FILE) 20 | cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install 21 | rm -rf $(FILE) $(DIR) 22 | 23 | # lz4 24 | LZ4 = ${DEPS_PATH}/include/lz4.h 25 | ${LZ4}: 26 | $(eval FILE=lz4-r129.tar.gz) 27 | $(eval DIR=lz4-r129) 28 | rm -rf $(FILE) $(DIR) 29 | wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE) 30 | cd $(DIR) && $(MAKE) && PREFIX=$(DEPS_PATH) $(MAKE) install 31 | rm -rf $(FILE) $(DIR) 32 | 33 | # cityhash 34 | CITYHASH = ${DEPS_PATH}/include/city.h 35 | ${CITYHASH}: 36 | $(eval FILE=cityhash-1.1.1.tar.gz) 37 | $(eval DIR=cityhash-1.1.1) 38 | rm -rf $(FILE) $(DIR) 39 | wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE) 40 | cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --enable-sse4.2 && $(MAKE) CXXFLAGS="-g -O3 -msse4.2" && $(MAKE) install 41 | rm -rf $(FILE) $(DIR) 42 | 43 | 44 | # # gflags 45 | # ${DEPS_PATH}/include/google/gflags.h: 46 | # $(eval FILE=gflags-2.0-no-svn-files.tar.gz) 47 | # $(eval DIR=gflags-2.0) 48 | # rm -rf $(FILE) $(DIR) 49 | # wget $(URL)/$(FILE) && tar -zxf $(FILE) 50 | # cd $(DIR) && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install 51 | # rm -rf $(FILE) $(DIR) 52 | # gflags: | ${DEPS_PATH}/include/google/gflags.h 53 | 54 | # # glog 55 | # ${DEPS_PATH}/include/glog/logging.h: | ${DEPS_PATH}/include/google/gflags.h 56 | # $(eval FILE=v0.3.4.tar.gz) 57 | # $(eval DIR=glog-0.3.4) 58 | # rm -rf $(FILE) $(DIR) 59 | # wget https://github.com/google/glog/archive/$(FILE) && tar -zxf $(FILE) 60 | # cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --with-gflags=$(DEPS_PATH) && $(MAKE) && $(MAKE) install 61 | # rm -rf $(FILE) $(DIR) 62 | # glog: | ${DEPS_PATH}/include/glog/logging.h 63 | -------------------------------------------------------------------------------- /tests/utests/ut_common.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2024, STEP AI. All rights reserved. 
3 | ************************************************************************/ 4 | #ifndef PS_TEST_COMMON_H_ 5 | #define PS_TEST_COMMON_H_ 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | #ifdef DMLC_USE_CUDA 18 | #include 19 | #include 20 | #endif 21 | 22 | #include "ps/ps.h" 23 | 24 | using namespace ps; 25 | 26 | static struct { 27 | int gpu = 0; 28 | } g_conf; 29 | 30 | static struct { 31 | int batch_max = 128; 32 | int tensor_num = 1; 33 | } g_worker_conf; 34 | 35 | #define CUDA_CALL(func) \ 36 | { \ 37 | cudaError_t e = (func); \ 38 | PS_CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ 39 | << "CUDA: " << cudaGetErrorString(e); \ 40 | } 41 | 42 | static inline void InitUtestConfig() { 43 | Environment::Get()->find("UTEST_WORKER_BATCH_MAX", 44 | &g_worker_conf.batch_max, 45 | g_worker_conf.batch_max); 46 | Environment::Get()->find("UTEST_WORKER_TENSOR_NUM", 47 | &g_worker_conf.tensor_num, 48 | g_worker_conf.tensor_num); 49 | 50 | Environment::Get()->find("STEPMESH_GPU", &g_conf.gpu, g_conf.gpu); 51 | } 52 | 53 | static inline at::Tensor CreateTensor( 54 | std::vector shape, 55 | at::ScalarType dtype, 56 | int gpu, 57 | bool random = false) { 58 | auto options = torch::TensorOptions() 59 | .dtype(dtype) 60 | .memory_format(at::MemoryFormat::Contiguous) 61 | .device(at::Device(at::kCUDA, gpu)); 62 | if (random) { 63 | return torch::rand(shape, options); 64 | } else { 65 | return torch::zeros(shape, options); 66 | } 67 | } 68 | 69 | #endif // PS_TEST_COMMON_H_ 70 | -------------------------------------------------------------------------------- /cmake/External/zmq.cmake: -------------------------------------------------------------------------------- 1 | if (NOT __ZMQ_INCLUDED) # guard against multiple includes 2 | set(__ZMQ_INCLUDED TRUE) 3 | 4 | # use the system-wide ZMQ if present 5 | find_package(ZMQ) 6 | if (ZMQ_FOUND) 7 | set(ZMQ_EXTERNAL FALSE) 8 | else() 9 | # ZMQ will use pthreads if it's available in the system, so we must link with it 10 | find_package(Threads) 11 | 12 | # build directory 13 | set(ZMQ_PREFIX ${CMAKE_BINARY_DIR}/external/ZMQ-prefix) 14 | # install directory 15 | set(ZMQ_INSTALL ${CMAKE_BINARY_DIR}/external/ZMQ-install) 16 | 17 | # we build ZMQ statically, but want to link it into the caffe shared library 18 | # this requires position-independent code 19 | if (UNIX) 20 | set(ZMQ_EXTRA_COMPILER_FLAGS "-fPIC") 21 | endif() 22 | 23 | set(ZMQ_CXX_FLAGS ${CMAKE_CXX_FLAGS} ${ZMQ_EXTRA_COMPILER_FLAGS}) 24 | set(ZMQ_C_FLAGS ${CMAKE_C_FLAGS} ${ZMQ_EXTRA_COMPILER_FLAGS}) 25 | 26 | ExternalProject_Add(ZMQ 27 | PREFIX ${ZMQ_PREFIX} 28 | GIT_REPOSITORY "https://github.com/zeromq/libZMQ.git" 29 | UPDATE_COMMAND "" 30 | INSTALL_DIR ${ZMQ_INSTALL} 31 | CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} 32 | -DCMAKE_INSTALL_PREFIX=${ZMQ_INSTALL} 33 | -DBUILD_SHARED_LIBS=OFF 34 | -DBUILD_STATIC_LIBS=ON 35 | -DBUILD_PACKAGING=OFF 36 | -DBUILD_TESTING=OFF 37 | -DBUILD_NC_TESTS=OFF 38 | -BUILD_CONFIG_TESTS=OFF 39 | -DINSTALL_HEADERS=ON 40 | -DCMAKE_C_FLAGS=${ZMQ_C_FLAGS} 41 | -DCMAKE_CXX_FLAGS=${ZMQ_CXX_FLAGS} 42 | LOG_DOWNLOAD 1 43 | LOG_INSTALL 1 44 | ) 45 | 46 | set(ZMQ_FOUND TRUE) 47 | set(ZMQ_INCLUDE_DIRS ${ZMQ_INSTALL}/include) 48 | 49 | if(MSVC) 50 | FILE(GLOB_RECURSE ZMQ_LIBRARIES "${ZMQ_INSTALL}/lib/libzmq-${CMAKE_VS_PLATFORM_TOOLSET}*.lib") 51 | #set(ZMQ_LIBRARIES ${ZMQ_INSTALL}/lib/ZMQ.lib ${CMAKE_THREAD_LIBS_INIT}) 52 | else() 53 | FILE(GLOB_RECURSE ZMQ_LIBRARIES 
"${ZMQ_INSTALL}/lib/libzmq-*.a") 54 | #set(ZMQ_LIBRARIES ${ZMQ_INSTALL}/lib/libZMQ.a ${CMAKE_THREAD_LIBS_INIT}) 55 | endif() 56 | set(ZMQ_LIBRARY_DIRS ${ZMQ_INSTALL}/lib) 57 | set(ZMQ_EXTERNAL TRUE) 58 | 59 | list(APPEND external_project_dependencies ZMQ) 60 | endif() 61 | 62 | endif() 63 | -------------------------------------------------------------------------------- /tests/fserver/run_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 2 | function cleanup() { 3 | echo "kill all testing process of ps lite for user $USER" 4 | # pkill -9 -f test_bench 5 | pkill -9 -f python3 6 | sleep 1 7 | } 8 | trap cleanup EXIT 9 | # cleanup 10 | 11 | export BIN=${BIN:-test_fserver} 12 | # common setup 13 | export DMLC_INTERFACE=${RNIC:-brainpf_bond0} 14 | export SCHEDULER_IP=$(ip -o -4 addr | grep ${DMLC_INTERFACE} | awk '{print $4}' | cut -d'/' -f1) 15 | export DMLC_NUM_WORKER=${NUM_WORKER:-1} 16 | export DMLC_NUM_SERVER=${NUM_SERVER:-1} 17 | export DMLC_GROUP_SIZE=8 18 | export DMLC_NODE_RANK=${NODE_RANK:-0} 19 | export DMLC_PS_ROOT_PORT=8123 20 | export DMLC_PS_ROOT_URI=$SCHEDULER_IP # scheduler's RDMA interface IP 21 | export DMLC_ENABLE_RDMA=ibverbs 22 | export NCCL_DEBUG=warning 23 | export STEPMESH_SPLIT_QP_LAG=1 24 | export STEPMESH_BIND_CPU_CORE=1 25 | 26 | # export PS_VERBOSE=2 27 | 28 | ROLE=${ROLE:-server} 29 | if [ $ROLE == "server" ]; then 30 | echo "Run server and scheduler, scheduler ip $SCHEDULER_IP " 31 | export DMLC_NODE_HOST=${SCHEDULER_IP} 32 | DMLC_ROLE=scheduler python3 $THIS_DIR/${BIN}.py & 33 | 34 | sleep 1 # wait scheduler 35 | 36 | export DMLC_INTERFACE=auto 37 | for P in {0..7}; do 38 | DMLC_ROLE=server STEPMESH_GPU=${P} python3 $THIS_DIR/${BIN}.py $@ & 39 | done 40 | elif [ $ROLE == "worker" ]; then 41 | echo "Run worker with scheduler ip: $1" 42 | export DMLC_PS_ROOT_URI=$1 43 | export DMLC_INTERFACE=auto 44 | export DMLC_NODE_HOST=${SCHEDULER_IP} 45 | for P in {0..7}; do 46 | DMLC_ROLE=worker STEPMESH_GPU=${P} python3 $THIS_DIR/${BIN}.py "${@:2}" & 47 | done 48 | elif [ $ROLE == "server-slave" ]; then 49 | echo "Run server with scheduler ip: $1" 50 | export DMLC_PS_ROOT_URI=$1 51 | export DMLC_INTERFACE=auto 52 | export DMLC_NODE_HOST=${SCHEDULER_IP} 53 | for P in {0..7}; do 54 | DMLC_ROLE=server STEPMESH_GPU=${P} python3 $THIS_DIR/${BIN}.py "${@:2}" & 55 | done 56 | elif [ $ROLE == "joint" ]; then 57 | echo "Run scheduler, server, and worker jointly" 58 | export DMLC_NODE_HOST=${SCHEDULER_IP} 59 | export DMLC_PS_ROOT_URI=$SCHEDULER_IP 60 | DMLC_ROLE=scheduler python3 $THIS_DIR/${BIN}.py & 61 | 62 | export DMLC_INTERFACE=auto 63 | for P in {0..7}; do 64 | DMLC_ROLE=server STEPMESH_GPU=${P} python3 $THIS_DIR/${BIN}.py & 65 | done 66 | 67 | for P in {0..7}; do 68 | DMLC_ROLE=worker STEPMESH_GPU=${P} python3 $THIS_DIR/${BIN}.py & 69 | done 70 | fi 71 | 72 | wait -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # APIs 2 | 3 | ## 1. StepMesh APIs 4 | 5 | StepMesh provides two levels of API, both of which use `torch.Tensor` as the core data struct. 6 | The Python Fserver API is designed for quick and easy integration into the system, 7 | making it ideal for users who need to get up and running rapidly. 
8 | On the other hand, the C++ API is tailored for experienced developers
9 | who require deep-level performance tuning to optimize the system for specific use cases.
10 | 
11 | ### 1.1 Python API
12 | 
13 | Developers can import the Python API after installing StepMesh into their environment.
14 | 
15 | For details, please go to [Python API](./python.md)
16 | 
17 | ### 1.2 C++ API
18 | 
19 | Developers can link `libaf.a` into their programs to use the C++ APIs.
20 | 
21 | For details, please go to [C++ API](./c++.md)
22 | 
23 | ## 2. BytePS APIs
24 | 
25 | The data communicated are presented as key-value
26 | pairs, where the key might be the `uint64_t` (defined by `ps::Key`) feature
27 | index and the value might be the corresponding `float` gradient.
28 | 1. Basic synchronization functions: \ref ps::KVWorker::Push, \ref
29 | ps::KVWorker::Pull, and \ref ps::KVWorker::Wait
30 | 2. Zero-copy versions: \ref ps::KVWorker::ZPush, \ref
31 | ps::KVWorker::ZPull
32 | 
33 | To support dynamic lengths, the pull operations (`Pull` and `ZPull`) do not require the buffer (`vals`) to be the same size as the total amount of data pulled down. A larger buffer is allowed, and `lens` records the actual size of each key, so the reliable way to read a valid message is to read `lens` bytes. If you ensure that the data size of a key does not change across pushes and pulls, you can verify it by checking whether the key's entry in `lens` equals the fixed size.
34 | 
35 | Often, server *i* handles the keys (feature indices) within the *i*-th
36 | segment of [0, uint64_max]. The server node allows user-defined handles to
37 | process the `push` and `pull` requests from the workers.
38 | 1. Online key-value store \ref ps::OnlineServer
39 | 2. Example user-defined value: \ref ps::IVal
40 | 3. Example user-defined handle: \ref ps::IOnlineHandle
41 | 
42 | 
43 | 
44 | We can also implement a simple app (\ref ps::SimpleApp), which is often
45 | used to monitor and control the progress of the machine learning application.
46 | It can also be used to deal with node failures.
47 | See an example in
48 | [asynchronous SGD](https://github.com/dmlc/wormhole/blob/master/learn/solver/async_sgd.h#L27).
49 | 
50 | 
51 | ```eval_rst
52 | .. automodule:: ps::KVWorker
53 |    :members:
54 | ```
55 | 
56 | ```eval_rst
57 | .. doxygenstruct:: ps::KVPairs
58 |    :members:
59 | ```
60 | 
--------------------------------------------------------------------------------
/fserver/csrc/private.hpp:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2025, StepFun Authors. All rights reserved.
*/
2 | 
3 | #include "./util.hpp"
4 | #include "./public.hpp"
5 | #include 
6 | #ifdef DMLC_USE_CUDA
7 | #include 
8 | #include 
9 | #endif
10 | 
11 | #ifndef PRIVATE_OPS_
12 | #define PRIVATE_OPS_
13 | 
14 | using namespace ps;
15 | #ifdef DMLC_USE_CUDA
16 | class SimpleNotify {
17 |  private:
18 |   int notify_cnt = 1;
19 |   CUdeviceptr dflag;
20 |   uint32_t* hflag;
21 |   std::thread th_;
22 |   std::future> fut;
23 |  public:
24 |   void init() {
25 |     cudaHostAlloc(&hflag, sizeof(uint32_t), cudaHostAllocMapped);
26 |     cudaHostGetDevicePointer((void**)&dflag, (void*)hflag, 0);
27 |   }
28 | 
29 |   // for worker
30 |   void wait_event_done() {
31 |     if (th_.joinable()) {
32 |       th_.join();
33 |     }
34 |   }
35 | 
36 |   // for worker
37 |   void stream_wait_event(int handler) {
38 |     auto stream = at::cuda::getCurrentCUDAStream();
39 |     cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
40 |     th_ = std::thread([handler, this]{
41 |       fworker_->Wait(handler);
42 |       *(this->hflag) = this->notify_cnt;
43 |       ++(this->notify_cnt);
44 |     });
45 |   }
46 | 
47 |   void block_now_stream() {
48 |     auto stream = at::cuda::getCurrentCUDAStream();
49 |     cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
50 |   }
51 | 
52 |   // for server
53 |   void block_now_stream_and_get_batch() {
54 |     auto stream = at::cuda::getCurrentCUDAStream();
55 |     cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
56 |     fut = std::async(std::launch::async, [this]{
57 |       auto ret = get_batch();
58 |       *(this->hflag) = this->notify_cnt;
59 |       ++(this->notify_cnt);
60 |       return ret;
61 |     });
62 |   }
63 | 
64 |   // for server
65 |   std::vector get_future_batch_data() {
66 |     return fut.get();
67 |   }
68 | };
69 | 
70 | void pybind_private(py::module &m) {
71 |   py::class_(m, "SimpleNotify")
72 |     .def(py::init<>())
73 |     .def("init", &SimpleNotify::init)
74 |     .def("block_now_stream_and_get_batch", &SimpleNotify::block_now_stream_and_get_batch)
75 |     .def("get_future_batch_data", &SimpleNotify::get_future_batch_data)
76 |     .def("block_now_stream", &SimpleNotify::block_now_stream)
77 |     .def("wait_event_done", &SimpleNotify::wait_event_done)
78 |     .def("stream_wait_event", &SimpleNotify::stream_wait_event);
79 | }
80 | #else
81 | void pybind_private(py::module &m) {}
82 | #endif  // DMLC_USE_CUDA
83 | 
84 | #endif  // PRIVATE_OPS_
85 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | ifdef config
2 | include $(config)
3 | endif
4 | 
5 | include make/ps.mk
6 | 
7 | ifndef CXX
8 | CXX = g++
9 | endif
10 | 
11 | ifndef DEPS_PATH
12 | DEPS_PATH = $(shell pwd)/deps
13 | endif
14 | 
15 | ifndef FABRIC_PATH
16 | FABRIC_PATH = /opt/amazon/efa
17 | endif
18 | 
19 | ifndef PROTOC
20 | PROTOC = ${DEPS_PATH}/bin/protoc
21 | endif
22 | 
23 | CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc
24 | CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda
25 | ifdef CUDA_HOME
26 | CMAKE_CUDA_COMPILER=$(CUDA_HOME)/bin/nvcc
27 | CUDA_TOOLKIT_ROOT_DIR=$(CUDA_HOME)
28 | endif
29 | 
30 | INCPATH = -I./src -I./include -I$(DEPS_PATH)/include
31 | CFLAGS = -std=c++14 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS)
32 | LIBS = -pthread -lrt
33 | 
34 | ifeq ($(USE_CUDA), 1)
35 | LIBS += -lcudart -L$(CUDA_HOME)/lib64
36 | CFLAGS += -DDMLC_USE_CUDA
37 | INCPATH += -I$(CUDA_HOME)/include
38 | endif
39 | 
40 | ifeq ($(USE_RDMA), 1)
41 | LIBS += -lrdmacm -libverbs
42 | CFLAGS += -DDMLC_USE_RDMA
43 | endif
44 | 
45 | ifeq ($(USE_GDR),
1)
46 | CFLAGS += -DSTEPMESH_USE_GDR
47 | endif
48 | 
49 | ifeq ($(ENABLE_TRACE), 1)
50 | CFLAGS += -DSTEPMESH_ENABLE_TRACE
51 | endif
52 | 
53 | ifeq ($(USE_FABRIC), 1)
54 | LIBS += -lfabric -L$(FABRIC_PATH)/lib64 -L$(FABRIC_PATH)/lib
55 | CFLAGS += -DDMLC_USE_FABRIC
56 | INCPATH += -I$(FABRIC_PATH)/include
57 | endif
58 | 
59 | ifeq ($(USE_UCX), 1)
60 | LIBS += -lucp -luct -lucs -lucm
61 | CFLAGS += -DDMLC_USE_UCX
62 | ifdef UCX_PATH
63 | LIBS += -L$(UCX_PATH)/lib
64 | INCPATH += -I$(UCX_PATH)/include
65 | endif
66 | endif
67 | 
68 | ifeq ($(USE_TP), 1)
69 | # Make sure the build of TP is compliant with ps-lite (e.g., -fPIC, C++ ABI)
70 | INCPATH += -I$(TP_INSTALL_PATH)/include
71 | CFLAGS += -DDMLC_USE_TP
72 | endif
73 | 
74 | ifdef ASAN
75 | CFLAGS += -fsanitize=address -fno-omit-frame-pointer -fno-optimize-sibling-calls
76 | endif
77 | 
78 | 
79 | all: ps test
80 | 
81 | include make/deps.mk
82 | 
83 | clean:
84 | 	rm -rf build $(TEST) tests/*.d tests/*.dSYM
85 | 
86 | lint:
87 | 	python tests/lint.py ps all include/ps src
88 | 
89 | ps: build/libps.a
90 | 
91 | OBJS = $(addprefix build/, customer.o postoffice.o van.o network_utils.o)
92 | build/libps.a: $(OBJS)
93 | 	ar crv $@ $(filter %.o, $?)
94 | 
95 | build/%.o: src/%.cc ${ZMQ}
96 | 	@mkdir -p $(@D)
97 | 	$(CXX) $(CFLAGS) $(INCPATH) -MM -MT build/$*.o $< >build/$*.d
98 | 	$(CXX) $(CFLAGS) $(LIBS) -c $< -o $@
99 | 
100 | -include build/*.d
101 | -include build/*/*.d
102 | 
103 | test: $(TEST)
104 | 
105 | af:
106 | 	@mkdir -p cmake_build
107 | 	@cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
108 | 	@mkdir -p build
109 | 
110 | 
--------------------------------------------------------------------------------
/fserver/csrc/wait_kernel.cu:
--------------------------------------------------------------------------------
1 | #include 
2 | #include 
3 | #include 
4 | 
5 | #include 
6 | #include 
7 | 
8 | __global__ void write_flag_kernel(int64_t* flag, int64_t* seq) {
9 |   int64_t seq_value = seq[0];
10 |   if (threadIdx.x == 0) {
11 |     flag[0] = seq_value;
12 |     // Issue a system-wide fence after the write so it is visible to all threads and to the CPU.
13 |   }
14 |   __threadfence_system();
15 | }
16 | 
17 | __global__ void wait_flag_kernel(int64_t* flag, int64_t* seq) {
18 |   if (threadIdx.x == 0) {
19 |     // Mark pointer volatile so we reload host-written values each iteration.
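    // The flag lives in pinned host memory that is also mapped into the
    // device address space (see map_pinned_tensor below), so a host thread
    // can advance it while this kernel spins on the device-side view.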
20 | volatile int64_t* flag_ptr = flag, *seq_ptr = seq; 21 | int64_t flag_value = flag_ptr[0]; 22 | int64_t seq_value = seq_ptr[0]; 23 | while (flag_value < seq_value) { 24 | __nanosleep(128); 25 | flag_value = flag_ptr[0]; 26 | } 27 | } 28 | } 29 | 30 | __global__ void seq_add_one_kernel(int64_t* seq) { 31 | if (threadIdx.x == 0) { 32 | seq[0]++; 33 | } 34 | __threadfence_system(); 35 | } 36 | 37 | static void check_cuda(cudaError_t err, const char* msg) { 38 | TORCH_CHECK(err == cudaSuccess, msg, ": ", cudaGetErrorString(err)); 39 | } 40 | 41 | torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index) { 42 | TORCH_CHECK(tensor.is_pinned(), "tensor must be pinned"); 43 | void* host_ptr = tensor.data_ptr(); 44 | void* device_ptr = nullptr; 45 | check_cuda(cudaHostGetDevicePointer(&device_ptr, host_ptr, 0), 46 | "cudaHostGetDevicePointer failed"); 47 | auto options = tensor.options().device(torch::kCUDA, device_index); 48 | auto sizes = tensor.sizes(); 49 | auto strides = tensor.strides(); 50 | return torch::from_blob(device_ptr, sizes, strides, [](void*){}, options); 51 | } 52 | 53 | void write_flag(torch::Tensor flag, torch::Tensor seq) { 54 | TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor"); 55 | auto stream = at::cuda::getCurrentCUDAStream(flag.device().index()); 56 | write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr(), seq.data_ptr()); 57 | check_cuda(cudaGetLastError(), "write_flag_kernel launch failed"); 58 | } 59 | 60 | void wait_flag(torch::Tensor flag, torch::Tensor seq) { 61 | TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor"); 62 | auto stream = at::cuda::getCurrentCUDAStream(flag.device().index()); 63 | wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr(), seq.data_ptr()); 64 | check_cuda(cudaGetLastError(), "wait_flag_kernel launch failed"); 65 | } 66 | 67 | void seq_add_one(torch::Tensor seq) { 68 | TORCH_CHECK(seq.is_cuda(), "seq must be a CUDA tensor"); 69 | auto stream = at::cuda::getCurrentCUDAStream(seq.device().index()); 70 | seq_add_one_kernel<<<1, 1, 0, stream>>>(seq.data_ptr()); 71 | check_cuda(cudaGetLastError(), "seq_add_one_kernel launch failed"); 72 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22 FATAL_ERROR) 2 | 3 | project(af LANGUAGES C CXX) 4 | 5 | set(CMAKE_CXX_STANDARD 17) 6 | execute_process(COMMAND ${Python_EXECUTABLE} 7 | -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" 8 | OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE) 9 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ") 10 | 11 | 12 | # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ") 13 | 14 | # import pytorch library 15 | find_package (Python COMPONENTS Interpreter Development) 16 | execute_process(COMMAND ${Python_EXECUTABLE} 17 | -c "import torch; print(torch.utils.cmake_prefix_path)" 18 | OUTPUT_VARIABLE PYTORCH_CMAKE_PREFIX_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) 19 | message("MY PYTHON_EXECUTABLE ${Python_EXECUTABLE}") 20 | message("MY PYTORCH_CMAKE_PREFIX_PATH ${PYTORCH_CMAKE_PREFIX_PATH}") 21 | 22 | list(APPEND CMAKE_PREFIX_PATH "${PYTORCH_CMAKE_PREFIX_PATH}/Torch") 23 | 24 | find_package(Torch REQUIRED CONFIG) 25 | message("MY TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}") 26 | message("MY CUDA_INCLUDE_DIRS 
${CUDA_INCLUDE_DIRS}") 27 | include_directories(${TORCH_INCLUDE_DIRS}) 28 | # Save ABI setting before adding TORCH_CXX_FLAGS (which might override it) 29 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 30 | # Ensure ABI setting is preserved after TORCH_CXX_FLAGS 31 | 32 | list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") 33 | list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) 34 | 35 | 36 | if("$ENV{USE_CUDA}" STREQUAL "0") 37 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_RDMA -DSTEPMESH_USE_GDR") 38 | else() 39 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_CUDA -DSTEPMESH_USE_GDR -DDMLC_USE_RDMA -DSTEPMESH_ENABLE_TRACE") 40 | endif() 41 | 42 | link_directories("${PROJECT_SOURCE_DIR}/deps/lib") 43 | 44 | list(APPEND STEPAF_LIBS pthread dl zmq ibverbs rdmacm rt) 45 | 46 | include_directories("${PROJECT_SOURCE_DIR}/include/") 47 | include_directories("${PROJECT_SOURCE_DIR}/deps/include/") 48 | include_directories("${PROJECT_BINARY_DIR}/include/") 49 | include_directories("${PROJECT_BINARY_DIR}/src/") 50 | 51 | if("${CUDA_INCLUDE_DIRS}" STREQUAL "") 52 | message(WARNING "CUDA_INCLUDE_DIRS is empty") 53 | else() 54 | include_directories("${CUDA_INCLUDE_DIRS}") 55 | endif() 56 | 57 | FILE(GLOB SOURCE "src/*.cc") 58 | FILE(GLOB BACKEND_SOURCE "src/backend/*.cc") 59 | message("MY SOURCES ${SOURCE}") 60 | message("MY BACKEND_SOURCE ${BACKEND_SOURCE}") 61 | 62 | add_library(af ${SOURCE} ${BACKEND_SOURCE}) 63 | 64 | target_link_libraries(af 65 | ${STEPAF_LIBS} 66 | ${TORCH_LIBRARIES} 67 | ${TORCH_PYTHON_LIBRARY}) 68 | 69 | add_subdirectory(tests) 70 | -------------------------------------------------------------------------------- /tests/utests/stepmesh_push_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) 2025 by StepAI Contributors. 
3 |  */
4 | #include 
5 | #include 
6 | #include 
7 | #include 
8 | 
9 | #include 
10 | 
11 | #include "test_common.h"
12 | 
13 | void PushHandler(const AFTensorMeta& req_meta, AFTensorServer* server) {
14 |   server->Response(req_meta, {});
15 | }
16 | 
17 | void RunWorker(AFTensorWorker* kv) {
18 |   auto push_tensor = CreateTensor({g_conf.size},
19 |                                   at::kBFloat16, g_conf.gpu, true);
20 |   auto Push = [kv, push_tensor] () {
21 |     auto start = std::chrono::high_resolution_clock::now();
22 |     std::vector timestamps;
23 |     auto push_batch = KeyTensorBatch();
24 |     push_batch.push_back(
25 |         KeyTensor{
26 |             0, push_tensor,
27 |         });
28 |     auto pull_batch = KeyTensorBatch();
29 | 
30 |     kv->Wait(kv->ZBatchPushPull(push_batch, pull_batch));
31 |     auto end = std::chrono::high_resolution_clock::now();
32 |     return (end - start).count();
33 |   };
34 | 
35 |   PS_LOG(INFO) << "warmup starts";
36 |   std::vector timestamps;
37 | 
38 |   for (int iter = 0; iter < g_conf.iter; iter++) {
39 |     auto pushpull_ts = Push();
40 |     timestamps.emplace_back(pushpull_ts);
41 |     if ((iter % 1000 == 999)) {
42 |       DumpLatency("push batch latency: ", timestamps);
43 |       timestamps.clear();
44 |     }
45 |   }
46 | }
47 | 
48 | void StartPushServer() {
49 |   PS_LOG(INFO) << "run server: gpu=" << g_conf.gpu
50 |                << ", node rank=" << g_conf.node_rank
51 |                << ", group size=" << g_conf.group_size;
52 |   StartPS(0, Node::SERVER,
53 |           g_conf.node_rank * g_conf.group_size + g_conf.gpu, true);
54 |   Backend::Get()->SetDevice(g_conf.gpu);
55 |   StartServer(PushHandler);
56 |   Finalize(0, Node::SERVER, true);
57 |   PS_LOG(INFO) << "FFN server ends";
58 | }
59 | 
60 | void StartWorkers() {
61 |   PS_LOG(INFO) << "run worker: gpu=" << g_conf.gpu
62 |                << ", node rank=" << g_conf.node_rank
63 |                << ", group size=" << g_conf.group_size;
64 |   StartPS(0, Node::WORKER,
65 |           g_conf.node_rank * g_conf.group_size + g_conf.gpu, true);
66 |   Backend::Get()->SetDevice(g_conf.gpu);
67 |   AFTensorWorker af_worker(g_conf.gpu);
68 |   InitWorker(&af_worker);
69 |   RunWorker(&af_worker);
70 |   Finalize(0, Node::WORKER, true);
71 |   PS_LOG(INFO) << "Simulated worker is DONE";
72 | }
73 | 
74 | int main(int argc, char *argv[]) {
75 |   InitConfig();
76 |   PS_LOG(INFO) << "StepMesh Push Tests: gpu_num="
77 |                << g_conf.gpu_num << ", role=" << g_conf.role_str;
78 |   if (g_conf.role == Node::SCHEDULER) {
79 |     StartScheduler();
80 |   } else if (g_conf.role == Node::SERVER) {
81 |     StartPushServer();
82 |   } else if (g_conf.role == Node::WORKER) {
83 |     StartWorkers();
84 |   }
85 |   return 0;
86 | }
87 | 
--------------------------------------------------------------------------------
/tests/utests/ut_tensor_worker.cc:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2015 by Step AI
3 |  */
4 | #include 
5 | #include 
6 | 
7 | #include 
8 | 
9 | #include "ut_common.h"
10 | 
11 | int ServerCount = ps::GetEnv("DMLC_NUM_SERVER", 1);
12 | 
13 | void StartWorkers() {
14 |   PS_LOG(INFO) << "run worker over gpu " << g_conf.gpu;
15 |   StartPS(0, Node::WORKER, g_conf.gpu, true);
16 |   Backend::Get()->SetDevice(g_conf.gpu);
17 |   AFTensorWorker af_worker = AFTensorWorker(g_conf.gpu);
18 |   ps::Postoffice::GetWorker(g_conf.gpu)->Barrier(
19 |       0, ps::kServerGroup + ps::kWorkerGroup);
20 | 
21 |   std::vector tensors;
22 |   for (int b = 1; b < g_worker_conf.batch_max; b++) {
23 |     // g_worker_conf.batch_max; b++) {
24 |     auto start = std::chrono::high_resolution_clock::now();
25 |     auto push_batch = KeyTensorBatch();
26 |     int failed_count = 0;
27 |     for (int t = 0; t < g_worker_conf.tensor_num; t++)
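    // Key layout below: the batch size b fills the high 16 bits and the
    // tensor index t the low bits, so each (b, t) pair maps to a unique key.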
{
28 |       auto push = CreateTensor({b, 7168}, at::kBFloat16, g_conf.gpu, true);
29 |       push_batch.push_back(KeyTensor{
30 |           uint64_t((b << 16) + t),
31 |           push,
32 |       });
33 |       tensors.push_back(push);
34 |     }
35 |     auto pull_batch = KeyTensorBatch();
36 |     for (int i = 0; i < ServerCount; i++) {
37 |       pull_batch.push_back(KeyTensor{
38 |           uint64_t((b << 16) + g_worker_conf.tensor_num),
39 |           CreateTensor({b, 7168}, at::kBFloat16, g_conf.gpu),
40 |       });
41 |     }
42 |     for (int i = 0; i < 32; i++) {
43 |       for (int j = 0; j < ServerCount; j++) {
44 |         pull_batch[j].val.zero_();
45 |       }
46 | 
47 |       af_worker.Wait(af_worker.ZBatchPushPull(push_batch, pull_batch));
48 |       auto sum = CreateTensor({b, 7168}, at::kBFloat16, g_conf.gpu);
49 |       for (auto t : push_batch) {
50 |         sum += t.val;
51 |       }
52 | 
53 | 
54 |       bool fail_flag = false;
55 |       for (int j = 0; j < ServerCount; j++) {
56 |         if (!sum.allclose(pull_batch[j].val)) {
57 |           LOG(WARNING) << "check failed: batch=" << b << ", iter=" << i;
58 |           fail_flag = true;
59 |           break;
60 |         }
61 |       }
62 |       if (fail_flag) failed_count += 1;
63 |     }
64 |     auto end = std::chrono::high_resolution_clock::now();
65 |     if (failed_count == 0) {
66 |       std::cout << "GPU " << g_conf.gpu << " Batch " << b << ": ALL PASS"
67 |                 << " duration=" << (end - start).count() << "ns" << std::endl;
68 |     } else {
69 |       std::cout << "GPU " << g_conf.gpu << " Batch " << b << " FAILED "
70 |                 << failed_count << "/" << 32
71 |                 << " duration=" << (end - start).count() << "ns" << std::endl;
72 |     }
73 |   }
74 | 
75 |   // stop worker
76 |   Finalize(0, Node::WORKER, true);
77 |   PS_LOG(INFO) << "Simulated attention worker is DONE";
78 | }
79 | 
80 | int main(int argc, char *argv[]) {
81 |   InitUtestConfig();
82 |   StartWorkers();
83 |   return 0;
84 | }
85 | 
--------------------------------------------------------------------------------
/tests/fserver/test_fserver_diff_stage.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | import time
3 | import fserver_lib as f
4 | import random
5 | 
6 | is_worker = os.environ.get('DMLC_ROLE') == 'worker'
7 | is_server = os.environ.get('DMLC_ROLE') == 'server'
8 | server_count = int(os.environ.get('DMLC_NUM_SERVER', '1'))
9 | worker_count = int(os.environ.get('DMLC_NUM_WORKER', '1'))
10 | 
11 | iter_count = 100
12 | 
13 | 
14 | if is_worker:
15 |     f.init()
16 |     f.barrier(False, True)
17 |     gpu = os.environ.get('STEPMESH_GPU')
18 |     print(f"worker gpu: {gpu}")
19 |     max_num_tokens = 128
20 |     recv_buffer: list[list[torch.Tensor]] = [
21 |         [
22 |             torch.empty((max_num_tokens, 4), dtype=torch.bfloat16, device=torch.device('cuda')) for _ in range(1)
23 |         ] for _ in range(3)
24 |     ]
25 |     send_buffer: list[torch.Tensor] = [
26 |         torch.empty((max_num_tokens, 4), dtype=torch.bfloat16, device=torch.device('cuda'))
27 |         for _ in range(3)
28 |     ]
29 | 
30 |     for i in range(iter_count):
31 |         for stage in range(3):
32 |             send_buffer[stage].random_()
33 |             if i == 0:
34 |                 rand_width = 128
35 |             else:
36 |                 rand_width = 128  # random.randint(1, 128)
37 |             push_tensors = [send_buffer[stage][:]]
38 |             print(f"{send_buffer[stage].data_ptr()}, {push_tensors[0].data_ptr()}")
39 |             pull_tensors = [recv_buffer[stage][0]]
40 |             handler = f.push_pull(
41 |                 push_tensors,
42 |                 [stage],
43 |                 pull_tensors,
44 |                 [stage + i for i in range(server_count)],
45 |             )
46 |             print(f"iter: {i}, stage: {stage}, handler: {handler}")
47 |             f.wait(handler)
48 |             print(f"match: {pull_tensors}, {push_tensors}")
49 | 
50 |         print(f"iter: {i}")
51 |     print("worker test done")
52 | 
53 | elif is_server:
54 |     f.init()
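    # Server role: read the GPU id, allocate a response buffer, synchronize
    # with the workers via barrier, then answer each incoming batch with
    # get_batch()/respond() below.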
55 |     gpu = os.environ.get('STEPMESH_GPU')
56 |     ret_buffer = torch.rand([1, 2048], dtype=torch.float32, device=f'cuda:{gpu}')
57 |     f.barrier(True, False)
58 |     batches = []
59 |     for i in range(iter_count * 3):
60 |         batches = f.get_batch()
61 |         if len(batches) != 0:
62 |             # recv_tensor_list = batches[0][1]
63 |             # comm_id_list = [batches[0][0]]
64 |             # f.respond_vec(ret_buffer, recv_tensor_list, comm_id_list)
65 |             # buff = ret_buffer[:, :batches[0][1][0].size(1)]
66 |             print(batches)
67 |             # time.sleep(1)
68 |             # recv_tensor_list = [batches[w][1][0] for w in range(worker_count)]
69 |             # comm_id_list = [batches[w][0] for w in range(worker_count)]
70 | 
71 |             # f.respond_vec(ret_buffer, recv_tensor_list, comm_id_list)
72 |             for w in range(worker_count):
73 |                 f.respond(batches[w][1], batches[w][0], True)
74 |             print(f"iter: {i//3}, stage: {i%3}, server")
75 | 
76 | else:
77 |     f.init()
78 |     f.stop()
79 | 
--------------------------------------------------------------------------------
/src/meta.h:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2018-2019 Bytedance Inc.
3 |  * Author: zhuyibo@bytedance.com (Yibo Zhu)
4 |  * Modifications Copyright (C) by StepAI Contributors. 2025.
5 |  */
6 | #ifndef META_H_
7 | #define META_H_
8 | 
9 | #include 
10 | 
11 | #include "ps/internal/trace.h"
12 | 
13 | namespace ps {
14 | 
15 | struct RawNode {
16 |   // the node role
17 |   int role;
18 |   // node id
19 |   int id;
20 |   // hostname or ip
21 |   char hostname[64];
22 |   // number of ports
23 |   int num_ports;
24 |   // all the ports this node is binding
25 |   int ports[32];
26 |   // the port this node is binding (ports[0])
27 |   int port;
28 |   // the type of devices
29 |   int dev_types[32];
30 |   // the id of devices
31 |   int dev_ids[32];
32 |   // whether this node is created by failover
33 |   bool is_recovery;
34 |   // the locally unique id of a customer
35 |   int customer_id;
36 |   // endpoint name
37 |   char endpoint_name[64];
38 |   // endpoint name len
39 |   size_t endpoint_name_len;
40 |   // auxiliary id
41 |   int aux_id;
42 | };
43 | 
44 | // system control info
45 | struct RawControl {
46 |   int cmd;
47 |   int node_size;
48 |   int barrier_group;
49 |   uint64_t msg_sig;
50 | };
51 | 
52 | // meta information about a message
53 | struct RawMeta {
54 |   // message.head
55 |   int head;
56 |   // message.body
57 |   int body_size;
58 |   // if set, then it is a system control task;
otherwise, it is for the app
59 |   RawControl control;
60 |   // true: a request task
61 |   // false: the response task to the request task with the same *time*
62 |   bool request;
63 |   // the unique id of an application
64 |   int app_id;
65 |   // the timestamp of this message
66 |   int timestamp;
67 |   // data type of message.data[i]
68 |   int data_type_size;
69 |   /** \brief src device type of message.data[i] */
70 |   int src_dev_type;
71 |   /** \brief src device id of message.data[i] */
72 |   int src_dev_id;
73 |   /** \brief dst device type of message.data[i] */
74 |   int dst_dev_type;
75 |   /** \brief dst device id of message.data[i] */
76 |   int dst_dev_id;
77 |   // the locally unique id of a customer
78 |   int customer_id;
79 |   // whether or not a push message
80 |   bool push;
81 |   // whether or not it's for SimpleApp
82 |   bool simple_app;
83 |   // message.data_size
84 |   int data_size;
85 |   // message.key
86 |   uint64_t key;
87 |   // message.addr
88 |   uint64_t addr;
89 |   // the length of the message's value
90 |   int val_len;
91 |   // the option field
92 |   int option;
93 |   // the sequence id
94 |   int sid;
95 |   // is a tensor
96 |   int is_tensor;
97 |   // tensor dtype
98 |   int dtype;
99 |   // tensor dimension
100 |   int dim;
101 |   // tensor shape
102 |   int64_t shape[8];
103 |   // #ifdef STEPMESH_ENABLE_TRACE
104 |   // timestamp traces for the request message
105 |   struct Trace request_trace;
106 |   // timestamp traces for the response message
107 |   struct Trace response_trace;
108 |   // #endif
109 |   // counter for each qp
110 |   uint64_t slave_qp_counter[QP_MAX_NUM];
111 |   // the number of slave qps
112 |   int slave_qp_num;
113 |   // body
114 |   // data_type
115 |   // node
116 | };
117 | 
118 | }  // namespace ps
119 | 
120 | #endif  // META_H_
121 | 
--------------------------------------------------------------------------------
/tests/utests/stepmesh_pull_test.cc:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (C) 2025 by StepAI Contributors.
3 |  */
4 | #include 
5 | #include 
6 | #include 
7 | #include 
8 | 
9 | #include 
10 | 
11 | #include "test_common.h"
12 | 
13 | std::unordered_map g_mem;
14 | 
15 | void PullHandler(const AFTensorMeta& req_meta, AFTensorServer* server) {
16 |   auto key = req_meta.pull_tensors[0].key;
17 |   KeyTensor key_tensor;
18 |   key_tensor.key = key;
19 |   auto iter = g_mem.find(key);
20 |   if (iter != g_mem.end()) {
21 |     key_tensor.val = iter->second;
22 |   } else {
23 |     key_tensor.val = CreateTensor({g_conf.size},
24 |                                   at::kBFloat16, g_conf.gpu, true);
25 |     g_mem[key] = key_tensor.val;
26 |   }
27 | 
28 |   server->Response(req_meta, { key_tensor });
29 | }
30 | 
31 | void RunWorker(AFTensorWorker* kv) {
32 |   auto pull_tensor = CreateTensor({g_conf.size},
33 |                                   at::kBFloat16, g_conf.gpu, false);
34 |   auto PushPull = [kv, pull_tensor] () {
35 |     auto start = std::chrono::high_resolution_clock::now();
36 |     std::vector timestamps;
37 |     auto push_batch = KeyTensorBatch();
38 |     auto pull_batch = KeyTensorBatch();
39 |     pull_batch.push_back(KeyTensor{
40 |         1, pull_tensor,
41 |     });
42 | 
43 |     kv->Wait(kv->ZBatchPushPull(push_batch, pull_batch));
44 |     auto end = std::chrono::high_resolution_clock::now();
45 |     return (end - start).count();
46 |   };
47 | 
48 |   PS_LOG(INFO) << "warmup starts";
49 |   std::vector timestamps;
50 | 
51 |   for (int iter = 0; iter < g_conf.iter; iter++) {
52 |     auto pushpull_ts = PushPull();
53 |     timestamps.emplace_back(pushpull_ts);
54 |     if ((iter % 1000 == 999)) {
55 |       DumpLatency("pull batch latency: ", timestamps);
56 |       timestamps.clear();
57 |     }
58 |   }
59 | }
60 | 
61 | void StartFFNServer() {
62 |   PS_LOG(INFO) << "Run server: gpu=" << g_conf.gpu
63 |                << ", node rank=" << g_conf.node_rank
64 |                << ", group size=" << g_conf.group_size;
65 |   StartPS(0, Node::SERVER,
66 |           g_conf.node_rank * g_conf.group_size + g_conf.gpu, true);
67 |   Backend::Get()->SetDevice(g_conf.gpu);
68 |   StartServer(PullHandler);
69 |   Finalize(0, Node::SERVER, true);
70 |   PS_LOG(INFO) << "Server ends";
71 | }
72 | 
73 | void StartWorkers() {
74 |   PS_LOG(INFO) << "run worker: gpu=" << g_conf.gpu
75 |                << ", node rank=" << g_conf.node_rank
76 |                << ", group size=" << g_conf.group_size;
77 |   StartPS(0, Node::WORKER,
78 |           g_conf.node_rank * g_conf.group_size + g_conf.gpu, true);
79 |   Backend::Get()->SetDevice(g_conf.gpu);
80 |   AFTensorWorker af_worker(g_conf.gpu);
81 |   InitWorker(&af_worker);
82 |   RunWorker(&af_worker);
83 |   Finalize(0, Node::WORKER, true);
84 |   PS_LOG(INFO) << "Simulated worker is DONE";
85 | }
86 | 
87 | int main(int argc, char *argv[]) {
88 |   InitConfig();
89 |   PS_LOG(INFO) << "StepMesh Pull Tests: gpu_num="
90 |                << g_conf.gpu_num << ", role=" << g_conf.role_str;
91 |   if (g_conf.role == Node::SCHEDULER) {
92 |     StartScheduler();
93 |   } else if (g_conf.role == Node::SERVER) {
94 |     StartFFNServer();
95 |   } else if (g_conf.role == Node::WORKER) {
96 |     StartWorkers();
97 |   }
98 |   return 0;
99 | }
100 | 
--------------------------------------------------------------------------------
/docs/how_to.md:
--------------------------------------------------------------------------------
1 | # How To
2 | 
3 | ## Debug PS-Lite
4 | 
5 | One way to debug is to log all communications. We can do this by setting
6 | the environment variable `PS_VERBOSE`:
7 | - `PS_VERBOSE=1`: logging connection information
8 | - `PS_VERBOSE=2`: logging all data communication information
9 | 
10 | For example, first run `make test; cd tests` in the root directory.
Then
11 | ```bash
12 | export PS_VERBOSE=1; ./local.sh 1 1 ./test_connection
13 | ```
14 | Possible outputs are
15 | ```bash
16 | [19:57:18] src/van.cc:72: Node Info: role=scheduler, id=1, ip=127.0.0.1, port=8000
17 | [19:57:18] src/van.cc:72: Node Info: role=worker, ip=128.2.211.110, port=58442
18 | [19:57:18] src/van.cc:72: Node Info: role=server, ip=128.2.211.110, port=40112
19 | [19:57:18] src/van.cc:336: assign rank=8 to node role=server, ip=128.2.211.110, port=40112
20 | [19:57:18] src/van.cc:336: assign rank=9 to node role=worker, ip=128.2.211.110, port=58442
21 | [19:57:18] src/van.cc:347: the scheduler is connected to 1 workers and 1 servers
22 | [19:57:18] src/van.cc:354: S[8] is connected to others
23 | [19:57:18] src/van.cc:354: W[9] is connected to others
24 | [19:57:18] src/van.cc:296: H[1] is stopped
25 | [19:57:18] src/van.cc:296: S[8] is stopped
26 | [19:57:18] src/van.cc:296: W[9] is stopped
27 | ```
28 | where `H`, `S` and `W` stand for scheduler, server, and worker respectively.
29 | 
30 | ## Use a Particular Network Interface
31 | 
32 | By default, PS-Lite automatically chooses an available network interface. But for
33 | machines that have multiple interfaces, we can specify the network interface to use
34 | via the environment variable `DMLC_INTERFACE`. For example, to use the
35 | InfiniBand interface `ib0`, we can
36 | ```bash
37 | export DMLC_INTERFACE=ib0; commands_to_run
38 | ```
39 | 
40 | If all PS-Lite nodes run on the same machine, we can set `DMLC_LOCAL` to use
41 | memory copy rather than the local network interface, which may improve
42 | performance:
43 | ```bash
44 | export DMLC_LOCAL=1; commands_to_run
45 | ```
46 | 
47 | ## Environment Variables to Start PS-Lite
48 | 
49 | This section is useful if we want to port PS-Lite to cluster resource
50 | managers besides the provided ones such as `ssh`, `mpirun`, `yarn` and `sge`.
51 | 
52 | To start a PS-Lite node, we need to give proper values to the following
53 | environment variables.
54 | - `DMLC_NUM_WORKER` : the number of workers
55 | - `DMLC_NUM_SERVER` : the number of servers
56 | - `DMLC_ROLE` : the role of the current node, can be `worker`, `server`, or `scheduler`
57 | - `DMLC_PS_ROOT_URI` : the ip or hostname of the scheduler node
58 | - `DMLC_PS_ROOT_PORT` : the port that the scheduler node is listening on
59 | 
60 | ## Retransmission for Unreliable Network
61 | 
62 | It's not uncommon that a message disappears when sent from one node to another
63 | node. The program hangs when a critical message is not delivered
64 | successfully. In that case, we can let PS-Lite send an additional ACK for each
65 | message, and resend that message if the ACK is not received within a given
66 | time. To enable this feature, we can set the environment variables
67 | 
68 | - `PS_RESEND` : whether to enable retransmission. Default is 0.
69 | - `PS_RESEND_TIMEOUT` : timeout in milliseconds for receiving an ACK. If the
70 |   ACK is not received in time, PS-Lite resends the message. Default is 1000.
71 | 
72 | We can set `PS_DROP_MSG`, the probability (in percent) of dropping a received
73 | message, for testing. For example, `PS_DROP_MSG=10` will make a node drop each
74 | received message with 10% probability.
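As a minimal sketch, the same knobs can also be set from Python before a node starts; the snippet below assumes the fserver bindings are used to bring the node up, and shows the documented defaults plus a 10% test drop rate:

```python
import os

# Enable ACK-based retransmission (defaults documented above).
os.environ["PS_RESEND"] = "1"
os.environ["PS_RESEND_TIMEOUT"] = "1000"  # milliseconds before a resend
os.environ["PS_DROP_MSG"] = "10"          # testing only: drop 10% of received messages

import fserver_lib as f
f.init()  # the library reads these variables when the node starts
```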
75 | 
--------------------------------------------------------------------------------
/tracker/dmlc_mpi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | DMLC submission script, MPI version
4 | """
5 | import argparse
6 | import sys
7 | import os
8 | import subprocess
9 | import tracker
10 | from threading import Thread
11 | 
12 | parser = argparse.ArgumentParser(description='DMLC script to submit dmlc job using MPI')
13 | parser.add_argument('-n', '--nworker', required=True, type=int,
14 |                     help='number of worker processes to be launched')
15 | parser.add_argument('-s', '--server-nodes', default=0, type=int,
16 |                     help='number of server nodes to be launched')
17 | parser.add_argument('--log-level', default='INFO', type=str,
18 |                     choices=['INFO', 'DEBUG'],
19 |                     help='logging level')
20 | parser.add_argument('--log-file', type=str,
21 |                     help='output log to the specific log file')
22 | parser.add_argument('-H', '--hostfile', type=str,
23 |                     help='the hostfile of the mpi server')
24 | parser.add_argument('command', nargs='+',
25 |                     help='command for dmlc program')
26 | parser.add_argument('--host-ip', type=str,
27 |                     help='the scheduler ip', default='ip')
28 | args, unknown = parser.parse_known_args()
29 | #
30 | # submission script using MPI
31 | #
32 | 
33 | def get_mpi_env(envs):
34 |     """get the mpirun command for setting the environment
35 | 
36 |     supports both openmpi and mpich2
37 |     """
38 |     outfile = "/tmp/mpiver"
39 |     os.system("mpirun 1>/tmp/mpiver 2>/tmp/mpiver")
40 |     with open(outfile, "r") as infile:
41 |         mpi_ver = infile.read()
42 |     cmd = ''
43 |     if 'Open MPI' in mpi_ver:
44 |         for k, v in envs.items():
45 |             cmd += ' -x %s=%s' % (k, str(v))
46 |     elif 'mpich' in mpi_ver:
47 |         for k, v in envs.items():
48 |             cmd += ' -env %s %s' % (k, str(v))
49 |     else:
50 |         raise Exception('unknown mpi version %s' % (mpi_ver))
51 | 
52 |     return cmd
53 | 
54 | def mpi_submit(nworker, nserver, pass_envs):
55 |     """
56 |     customized submit script that submits nslave jobs; each must take args as a parameter.
57 |     note this can be a lambda function containing additional parameters in input
58 |     Parameters
59 |       nworker    number of slave processes to start up
60 |       nserver    number of server nodes to start up
61 |       pass_envs  environment variables to be added to the starting programs
62 |     """
63 |     def run(prog):
64 |         """run the program in a subprocess"""
65 |         subprocess.check_call(prog, shell=True)
66 | 
67 |     cmd = ''
68 |     if args.hostfile is not None:
69 |         cmd = '--hostfile %s' % (args.hostfile)
70 |     cmd += ' ' + ' '.join(args.command) + ' ' + ' '.join(unknown)
71 | 
72 |     # start servers
73 |     if nserver > 0:
74 |         pass_envs['DMLC_ROLE'] = 'server'
75 |         prog = 'mpirun -n %d %s %s' % (nserver, get_mpi_env(pass_envs), cmd)
76 |         thread = Thread(target=run, args=(prog,))
77 |         thread.setDaemon(True)
78 |         thread.start()
79 | 
80 |     if nworker > 0:
81 |         pass_envs['DMLC_ROLE'] = 'worker'
82 |         prog = 'mpirun -n %d %s %s' % (nworker, get_mpi_env(pass_envs), cmd)
83 |         thread = Thread(target=run, args=(prog,))
84 |         thread.setDaemon(True)
85 |         thread.start()
86 | 
87 | tracker.config_logger(args)
88 | 
89 | tracker.submit(args.nworker, args.server_nodes, fun_submit=mpi_submit,
90 |                hostIP=args.host_ip,
91 |                pscmd=(' '.join(args.command) + ' ' + ' '.join(unknown)))
92 | 
--------------------------------------------------------------------------------
/tracker/dmlc_local.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | 
DMLC submission script, local machine version
4 | """
5 | 
6 | import argparse
7 | import sys
8 | import os
9 | import subprocess
10 | from threading import Thread
11 | import tracker
12 | import signal
13 | import logging
14 | 
15 | keepalive = """
16 | nrep=0
17 | rc=254
18 | while [ $rc -eq 254 ];
19 | do
20 |     export DMLC_NUM_ATTEMPT=$nrep
21 |     %s
22 |     rc=$?;
23 |     nrep=$((nrep+1));
24 | done
25 | """
26 | 
27 | class LocalLauncher(object):
28 |     def __init__(self, args, unknown):
29 |         self.args = args
30 |         self.cmd = ' '.join(args.command) + ' ' + ' '.join(unknown)
31 | 
32 |     def exec_cmd(self, cmd, role, pass_env):
33 |         env = os.environ.copy()
34 |         for k, v in pass_env.items():
35 |             env[k] = str(v)
36 | 
37 |         env['DMLC_ROLE'] = role
38 | 
39 |         ntrial = 0
40 |         while True:
41 |             if os.name == 'nt':
42 |                 env['DMLC_NUM_ATTEMPT'] = str(ntrial)
43 |                 ret = subprocess.call(cmd, shell=True, env=env)
44 |                 if ret == 254:
45 |                     ntrial += 1
46 |                     continue
47 |             else:
48 |                 bash = keepalive % (cmd)
49 |                 ret = subprocess.call(bash, shell=True, executable='bash', env=env)
50 |             if ret == 0:
51 |                 logging.debug('Thread exited with code 0')
52 |                 return
53 |             else:
54 |                 if os.name == 'nt':
55 |                     sys.exit(-1)
56 |                 else:
57 |                     raise Exception('Get nonzero return code=%d' % ret)
58 | 
59 |     def submit(self):
60 |         def mthread_submit(nworker, nserver, envs):
61 |             """
62 |             customized submit script
63 |             """
64 |             procs = {}
65 |             for i in range(nworker + nserver):
66 |                 role = 'worker' if i < nworker else 'server'
67 |                 procs[i] = Thread(target=self.exec_cmd, args=(self.cmd, role, envs))
68 |                 procs[i].setDaemon(True)
69 |                 procs[i].start()
70 |         return mthread_submit
71 | 
72 |     def run(self):
73 |         tracker.config_logger(self.args)
74 |         tracker.submit(self.args.num_workers,
75 |                        self.args.num_servers,
76 |                        fun_submit=self.submit(),
77 |                        pscmd=self.cmd)
78 | 
79 | def main():
80 |     parser = argparse.ArgumentParser(
81 |         description='DMLC script to submit dmlc jobs as local processes')
82 | 
83 |     parser.add_argument('-n', '--num-workers', required=True, type=int,
84 |                         help='number of worker nodes to be launched')
85 |     parser.add_argument('-s', '--num-servers', type=int,
86 |                         help='number of server nodes to be launched')
87 |     parser.add_argument('--log-level', default='INFO', type=str,
88 |                         choices=['INFO', 'DEBUG'],
89 |                         help='logging level')
90 |     parser.add_argument('--log-file', type=str,
91 |                         help='output log to the specific log file')
92 |     parser.add_argument('command', nargs='+',
93 |                         help='command for launching the program')
94 |     args, unknown = parser.parse_known_args()
95 | 
96 |     launcher = LocalLauncher(args, unknown)
97 |     launcher.run()
98 | 
99 | if __name__ == '__main__':
100 |     main()
101 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from setuptools import setup
3 | from torch.utils import cpp_extension
4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
5 | import subprocess
6 | import pathlib
7 | import os
8 | import re
9 | import sys
10 | from pathlib import Path
11 | 
12 | def get_version():
13 |     version = '0.0.5.post1'
14 |     # with open('stepkv/version.py', 'r') as fd:
15 |     #     version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
16 |     #                         fd.read(), re.MULTILINE).group(1)
17 |     if len(sys.argv) >= 2:
18 |         if sys.argv[1] == 'bdist_wheel':
19 |             import torch
20 |             torch_version = torch.__version__.replace("+", "")
21 |             version = 
f"{version}+torch{torch_version}" 22 | assert version, 'Cannot find version information' 23 | print(version) 24 | return version 25 | 26 | def _get_cuda_bare_metal_version(cuda_dir): 27 | assert cuda_dir is not None, "Please ensure cuda is installed" 28 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], 29 | universal_newlines=True) 30 | output = raw_output.split() 31 | release_idx = output.index("release") + 1 32 | release = output[release_idx].split(".") 33 | bare_metal_major = release[0] 34 | bare_metal_minor = release[1][0] 35 | 36 | return bare_metal_major, bare_metal_minor 37 | 38 | 39 | __SRC_PATH__ = 'fserver/csrc/' 40 | __PS_PATH__ = f'{Path.cwd()}' 41 | 42 | if __name__ == "__main__": 43 | cc_flag = [] 44 | 45 | torch_cxx11_abi = torch.compiled_with_cxx11_abi() 46 | use_cuda = os.environ.get("USE_CUDA",'1')=='1' 47 | extra_link = ['-lrdmacm', '-libverbs'] 48 | extra_compile_args={ 49 | 'cxx': [ 50 | '-O3', '-fPIC', 51 | f'-I{__PS_PATH__}/include', 52 | f'-D_GLIBCXX_USE_CXX11_ABI={str(int(torch_cxx11_abi))}', 53 | '-DDMLC_USE_ZMQ', 54 | '-DSTEPMESH_USE_GDR', 55 | '-DDMLC_USE_RDMA', 56 | '-DSTEPMESH_USE_TORCH', 57 | '-DSTEPMESH_ENABLE_TRACE', 58 | '-fvisibility=hidden', 59 | ], 60 | 'nvcc': [], 61 | } 62 | if use_cuda: 63 | extra_link += ['-lcuda', '-lcudart'] 64 | extra_compile_args['cxx'] += ['-DDMLC_USE_CUDA',] 65 | extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_90,code=sm_90', '-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_89,code=sm_89','-gencode', 'arch=compute_90a,code=sm_90a', 66 | '--use_fast_math', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(torch_cxx11_abi))}'] + cc_flag 67 | bare_metal_major, bare_metal_minor = \ 68 | _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) 69 | 70 | setup( 71 | name='FServer', 72 | description='A Remote FFN Server Implementation for AF Disaggregation', 73 | author='StepFun', 74 | version=get_version(), 75 | packages=['fserver'], 76 | url='', 77 | ext_modules=[ 78 | CUDAExtension( 79 | 'fserver_lib', 80 | [ 81 | __SRC_PATH__ + 'ops.cc', 82 | __SRC_PATH__ + 'wait_kernel.cu', 83 | ], 84 | extra_compile_args=extra_compile_args, 85 | extra_link_args=extra_link, 86 | extra_objects=[f"{__PS_PATH__}/cmake_build/libaf.a", f"{__PS_PATH__}/deps/lib/libzmq.a"], 87 | ) 88 | ], 89 | cmdclass={ 90 | 'build_ext': BuildExtension 91 | } 92 | ) 93 | -------------------------------------------------------------------------------- /include/ps/internal/env.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 by Contributors 3 | * Modifications Copyright (C) by StepAI Contributors. 2025. 
4 |  */
5 | #ifndef PS_INTERNAL_ENV_H_
6 | #define PS_INTERNAL_ENV_H_
7 | #include 
8 | #include 
9 | #include 
10 | #include 
11 | #include 
12 | namespace ps {
13 | 
14 | /**
15 |  * \brief Environment configurations
16 |  */
17 | class Environment {
18 |  public:
19 |   /**
20 |    * \brief return the singleton instance
21 |    */
22 |   static inline Environment* Get() { return _GetSharedRef(nullptr).get(); }
23 |   /**
24 |    * \brief return a shared ptr of the singleton instance
25 |    */
26 |   static inline std::shared_ptr _GetSharedRef() {
27 |     return _GetSharedRef(nullptr);
28 |   }
29 |   /**
30 |    * \brief initialize the environment
31 |    * \param envs key-value environment variables
32 |    * \return the initialized singleton instance
33 |    */
34 |   static inline Environment* Init(
35 |       const std::unordered_map& envs) {
36 |     Environment* env = _GetSharedRef(&envs).get();
37 |     env->kvs = envs;
38 |     return env;
39 |   }
40 | 
41 |   /**
42 |    * \brief find the env value.
43 |    * User-defined env vars first. If not found, check the system's environment
44 |    * \param k the environment key
45 |    * \return the related environment value, nullptr when not found
46 |    */
47 |   const char* find(const char* k) {
48 |     std::string key(k);
49 |     return kvs.find(key) == kvs.end() ? getenv(k) : kvs[key].c_str();
50 |   }
51 | 
52 |   void set(const char* k, const std::string& rewrite_val) {
53 |     std::string key(k);
54 |     kvs[key] = rewrite_val;
55 |   }
56 | 
57 |   /**
58 |    * \brief find the env value with the integer type with a default value.
59 |    *
60 |    * \param k the environment key
61 |    * \param val the pointer to the environment result with the integer type
62 |    * \param default_val the default environment result when the environment
63 |    *        key is not found
64 |    * \return 0 if the key was found, -1 if the default value was used
65 |    */
66 |   int find(const char* k, int* val, int default_val = -1) {
67 |     std::string key(k);
68 |     auto val_str = kvs.find(key) == kvs.end() ? getenv(k) : kvs[key].c_str();
69 |     if (val_str == nullptr) {
70 |       *val = default_val;
71 |       return -1;
72 |     } else {
73 |       *val = atoi(val_str);
74 |       return 0;
75 |     }
76 |   }
77 | 
78 |   /**
79 |    * \brief find the env value with the string type with a default value.
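   * Illustrative usage (assuming the variable may be unset):
   *   std::string iface = "auto";
   *   iface = Environment::Get()->find("DMLC_INTERFACE", iface);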
80 |    *
81 |    * \param k the environment key
82 |    * \param default_value the value to return when the environment key is
83 |    *        not found
84 |    * \return the string value, or the default value when the key is not found
85 |    */
86 |   std::string find(const char* k, std::string& default_value) {
87 |     std::string key(k);
88 |     if (kvs.find(key) != kvs.end()) {
89 |       return kvs[key];
90 |     }
91 | 
92 |     if (getenv(k) == nullptr) {
93 |       return default_value;
94 |     }
95 | 
96 |     return {getenv(k)};
97 |   }
98 | 
99 |  private:
100 |   explicit Environment(
101 |       const std::unordered_map* envs) {
102 |     if (envs) kvs = *envs;
103 |   }
104 | 
105 |   static std::shared_ptr _GetSharedRef(
106 |       const std::unordered_map* envs) {
107 |     static std::shared_ptr inst_ptr(new Environment(envs));
108 |     return inst_ptr;
109 |   }
110 | 
111 |   std::unordered_map kvs;
112 | };
113 | 
114 | }  // namespace ps
115 | #endif  // PS_INTERNAL_ENV_H_
116 | 
--------------------------------------------------------------------------------
/cmake/ProtoBuf.cmake:
--------------------------------------------------------------------------------
1 | # Finds Google Protocol Buffers library and compilers and extends
2 | # the standard cmake script with version and python generation support
3 | 
4 | find_package( Protobuf REQUIRED )
5 | include_directories(SYSTEM ${PROTOBUF_INCLUDE_DIR})
6 | 
7 | 
8 | # As of Ubuntu 14.04 protoc is no longer a part of libprotobuf-dev package
9 | # and should be installed separately as in: sudo apt-get install protobuf-compiler
10 | if(EXISTS ${PROTOBUF_PROTOC_EXECUTABLE})
11 |   message(STATUS "Found PROTOBUF Compiler: ${PROTOBUF_PROTOC_EXECUTABLE}")
12 | else()
13 |   message(FATAL_ERROR "Could not find PROTOBUF Compiler")
14 | endif()
15 | 
16 | 
17 | # place where to generate protobuf sources
18 | set(proto_gen_folder "${PROJECT_BINARY_DIR}/include/pslite/proto")
19 | include_directories(SYSTEM "${PROJECT_BINARY_DIR}/include")
20 | 
21 | set(PROTOBUF_GENERATE_CPP_APPEND_PATH TRUE)
22 | 
23 | ################################################################################################
24 | # Modification of standard 'protobuf_generate_cpp()' with output dir parameter and python support
25 | # Usage:
26 | #   pslite_protobuf_generate_cpp_py(   )
27 | function(pslite_protobuf_generate_cpp_py output_dir srcs_var hdrs_var python_var work_path proto_path)
28 |   if(NOT ARGN)
29 |     message(SEND_ERROR "Error: pslite_protobuf_generate_cpp_py() called without any proto files")
30 |     return()
31 |   endif()
32 | 
33 |   if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
34 |     # Create an include path for each file specified
35 |     foreach(fil ${ARGN})
36 |       get_filename_component(abs_fil ${fil} ABSOLUTE)
37 |       get_filename_component(abs_path ${abs_fil} PATH)
38 |       list(FIND _protoc_include ${abs_path} _contains_already)
39 |       if(${_contains_already} EQUAL -1)
40 |         list(APPEND _protoc_include -I ${abs_path})
41 |       endif()
42 |     endforeach()
43 |   else()
44 |     set(_protoc_include -I ${CMAKE_CURRENT_SOURCE_DIR})
45 |   endif()
46 | 
47 |   if(DEFINED PROTOBUF_IMPORT_DIRS)
48 |     foreach(dir ${PROTOBUF_IMPORT_DIRS})
49 |       get_filename_component(abs_path ${dir} ABSOLUTE)
50 |       list(FIND _protoc_include ${abs_path} _contains_already)
51 |       if(${_contains_already} EQUAL -1)
52 |         list(APPEND _protoc_include -I ${abs_path})
53 |       endif()
54 |     endforeach()
55 |   endif()
56 | 
57 |   set(${srcs_var})
58 |   set(${hdrs_var})
59 |   set(${python_var})
60 |   foreach(fil ${ARGN})
61 |     get_filename_component(abs_fil ${fil} ABSOLUTE)
62 |     get_filename_component(fil_we
${fil} NAME_WE) 63 | string(REPLACE ${work_path}/ "" o_fil ${abs_fil}) 64 | string(REPLACE "${fil_we}.proto" "" o_fil_path ${o_fil}) 65 | 66 | list(APPEND ${srcs_var} "${o_fil_path}/${fil_we}.pb.cc") 67 | list(APPEND ${hdrs_var} "${o_fil_path}/${fil_we}.pb.h") 68 | list(APPEND ${python_var} "${o_fil_path}/${fil_we}_pb2.py") 69 | 70 | add_custom_command( 71 | OUTPUT "${o_fil_path}/${fil_we}.pb.cc" 72 | "${o_fil_path}/${fil_we}.pb.h" 73 | "${o_fil_path}/${fil_we}_pb2.py" 74 | COMMAND ${CMAKE_COMMAND} -E make_directory "${output_dir}" 75 | COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --cpp_out ${output_dir} ${o_fil} --proto_path ${proto_path} 76 | COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${output_dir} ${o_fil} --proto_path ${proto_path} 77 | DEPENDS ${abs_fil} 78 | WORKING_DIRECTORY ${work_path} 79 | COMMENT "Running C++/Python protocol buffer compiler on ${o_fil}" VERBATIM ) 80 | endforeach() 81 | 82 | set_source_files_properties(${${srcs_var}} ${${hdrs_var}} ${${python_var}} PROPERTIES GENERATED TRUE) 83 | set(${srcs_var} ${${srcs_var}} PARENT_SCOPE) 84 | set(${hdrs_var} ${${hdrs_var}} PARENT_SCOPE) 85 | set(${python_var} ${${python_var}} PARENT_SCOPE) 86 | endfunction() 87 | -------------------------------------------------------------------------------- /tests/utests/stepmesh_echo_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Step AI 3 | */ 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include "test_common.h" 12 | 13 | std::unordered_map g_mem; 14 | 15 | void EchoHandler(const AFTensorMeta& req_meta, AFTensorServer* server) { 16 | auto key = req_meta.pull_tensors[0].key; 17 | KeyTensor key_tensor; 18 | key_tensor.key = key; 19 | auto iter = g_mem.find(key); 20 | if (iter != g_mem.end()) { 21 | key_tensor.val = iter->second; 22 | } else { 23 | key_tensor.val = CreateTensor({g_conf.size}, 24 | at::kBFloat16, g_conf.gpu, true); 25 | g_mem[key] = key_tensor.val; 26 | } 27 | 28 | server->Response(req_meta, { key_tensor }); 29 | } 30 | 31 | void RunWorker(AFTensorWorker* kv) { 32 | auto push_tensor = CreateTensor({g_conf.size}, 33 | at::kBFloat16, g_conf.gpu, true); 34 | auto pull_tensor = CreateTensor({g_conf.size}, 35 | at::kBFloat16, g_conf.gpu, false); 36 | auto PushPull = [kv, push_tensor, pull_tensor] () { 37 | auto start = std::chrono::high_resolution_clock::now(); 38 | std::vector timestamps; 39 | auto push_batch = KeyTensorBatch(); 40 | push_batch.push_back( 41 | KeyTensor{ 42 | 0, push_tensor, 43 | }); 44 | auto pull_batch = KeyTensorBatch(); 45 | pull_batch.push_back(KeyTensor{ 46 | 1, pull_tensor, 47 | }); 48 | 49 | kv->Wait(kv->ZBatchPushPull(push_batch, pull_batch)); 50 | auto end = std::chrono::high_resolution_clock::now(); 51 | return (end - start).count(); 52 | }; 53 | 54 | PS_LOG(INFO) << "warmup starts"; 55 | std::vector overall_timestamps; 56 | std::vector timestamps; 57 | 58 | for (int iter = 0; iter < g_conf.iter; iter++) { 59 | auto pushpull_ts = PushPull(); 60 | overall_timestamps.emplace_back(pushpull_ts); 61 | timestamps.emplace_back(pushpull_ts); 62 | if ((iter % 1000 == 999)) { 63 | DumpLatency("pushpull batch latency: ", timestamps); 64 | timestamps.clear(); 65 | } 66 | } 67 | 68 | DumpLatency("pushpull overall latency: ", overall_timestamps); 69 | } 70 | 71 | void StartEchoServer() { 72 | PS_LOG(INFO) << "run server: gpu=" << g_conf.gpu 73 | << ", node rank=" << g_conf.node_rank 74 | << ", group size=" << g_conf.group_size; 75 | StartPS(0, 
Node::SERVER, 76 | g_conf.node_rank * g_conf.group_size + g_conf.gpu, true); 77 | Backend::Get()->SetDevice(g_conf.gpu); 78 | StartServer(EchoHandler); 79 | Finalize(0, Node::SERVER, true); 80 | PS_LOG(INFO) << "FFN server ends"; 81 | } 82 | 83 | void StartWorkers() { 84 | PS_LOG(INFO) << "run worker: gpu=" << g_conf.gpu 85 | << ", node rank=" << g_conf.node_rank 86 | << ", group size=" << g_conf.group_size; 87 | StartPS(0, Node::WORKER, 88 | g_conf.node_rank * g_conf.group_size + g_conf.gpu, true); 89 | Backend::Get()->SetDevice(g_conf.gpu); 90 | AFTensorWorker af_worker(g_conf.gpu); 91 | InitWorker(&af_worker); 92 | RunWorker(&af_worker); 93 | Finalize(0, Node::WORKER, true); 94 | PS_LOG(INFO) << "Simulated worker is DONE"; 95 | } 96 | 97 | int main(int argc, char *argv[]) { 98 | InitConfig(); 99 | PS_LOG(INFO) << "StepMesh Echo Tests: gpu_num=" 100 | << g_conf.gpu_num << ", role=" << g_conf.role_str; 101 | if (g_conf.role == Node::SCHEDULER) { 102 | StartScheduler(); 103 | } else if (g_conf.role == Node::SERVER) { 104 | StartEchoServer(); 105 | } else if (g_conf.role == Node::WORKER) { 106 | StartWorkers(); 107 | } 108 | return 0; 109 | } 110 | -------------------------------------------------------------------------------- /src/network_utils.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) 2025 by StepAI Contributors 3 | */ 4 | 5 | #include "network_utils.h" // NOLINT 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | namespace ps { 12 | 13 | #ifdef DMLC_USE_RDMA 14 | NetDev::NetDev(struct ibv_context* context, int dev_id, struct ibv_device* dev, 15 | int port_id, struct ibv_port_attr port) 16 | : dev_name_(dev->name), 17 | context_(context), 18 | dev_id_(dev_id), 19 | port_(port_id), 20 | link_(port.link_layer), 21 | pci_path_(nullptr), 22 | gid_tbl_len_(port.gid_tbl_len) { 23 | read_pci_path(); 24 | get_best_gid_index(); 25 | } 26 | 27 | NetDev::~NetDev() { 28 | if (pci_path_ != nullptr) { 29 | free(pci_path_); 30 | pci_path_ = nullptr; 31 | } 32 | 33 | if (context_ != nullptr) { 34 | ibv_close_device(context_); 35 | context_ = nullptr; 36 | } 37 | } 38 | 39 | int NetDev::get_port() { return port_; } 40 | 41 | int NetDev::get_link() { return link_; } 42 | 43 | std::string NetDev::get_name() { return dev_name_; } 44 | 45 | char* NetDev::get_pci_path() { return pci_path_; } 46 | 47 | int NetDev::read_pci_path() { 48 | char device_path[1024]; 49 | snprintf(device_path, sizeof(device_path), "/sys/class/infiniband/%s/device", 50 | dev_name_.c_str()); 51 | char* p = realpath(device_path, nullptr); 52 | pci_path_ = p; 53 | if (p == nullptr) { 54 | return -1; 55 | } 56 | return 0; 57 | } 58 | 59 | int NetDev::get_best_gid_index() { 60 | if (gid_idx_ > 0) { 61 | return gid_idx_; 62 | } 63 | int gid_idx = -1; 64 | ibv_gid gid = {}; 65 | for (int i = 1; i < gid_tbl_len_; i++) { 66 | bzero(&gid, sizeof(gid)); 67 | ibv_query_gid(context_, port_, i, &gid); 68 | if ((gid.raw[10] == 0xFF) && (gid.raw[11] == 0xFF)) { 69 | gid_idx = i; 70 | gid_idx_ = i; 71 | memcpy(&gid_, &gid, sizeof(gid)); 72 | } 73 | } 74 | return gid_idx; 75 | } 76 | 77 | int NetDev::get_best_gid(ibv_gid* gid, int* gid_idx) { 78 | if (gid_idx_ > 0) { 79 | *gid_idx = gid_idx_; 80 | memcpy(gid, &gid_, sizeof(gid_)); 81 | return 0; 82 | } 83 | 84 | ibv_gid tmp_gid = {}; 85 | 86 | for (int i = 1; i < gid_tbl_len_; i++) { 87 | bzero(&tmp_gid, sizeof(tmp_gid)); 88 | ibv_query_gid(context_, port_, i, &tmp_gid); 89 | if ((tmp_gid.raw[10] == 0xFF) && (tmp_gid.raw[11] == 
0xFF)) { 90 | *gid_idx = i; 91 | gid_idx_ = i; 92 | memcpy(&gid_, &tmp_gid, sizeof(tmp_gid)); 93 | memcpy(gid, &tmp_gid, sizeof(tmp_gid)); 94 | } 95 | } 96 | if (*gid_idx == -1) { 97 | return -1; 98 | } 99 | return 0; 100 | } 101 | 102 | std::string NetDev::get_ip() { 103 | std::stringstream ip; 104 | ip << static_cast(gid_.raw[12]) << "." << static_cast(gid_.raw[13]) 105 | << "." << static_cast(gid_.raw[14]) << "." 106 | << static_cast(gid_.raw[15]); 107 | return ip.str(); 108 | } 109 | 110 | std::string NetDev::get_interface_name() { 111 | int found = 0; 112 | struct ifaddrs *interfaces, *intf; 113 | getifaddrs(&interfaces); 114 | 115 | auto ip = get_ip(); 116 | std::string name = ""; 117 | for (intf = interfaces; intf && found < 32; intf = intf->ifa_next) { 118 | if (intf->ifa_addr == NULL) continue; 119 | 120 | /* We only support IPv4 & IPv6 */ 121 | int family = intf->ifa_addr->sa_family; 122 | if (family != AF_INET) { 123 | continue; 124 | } 125 | 126 | std::string intf_ip = inet_ntoa( 127 | reinterpret_cast(intf->ifa_addr)->sin_addr); 128 | if (intf_ip == ip) { 129 | name = intf->ifa_name; 130 | break; 131 | } 132 | } 133 | 134 | freeifaddrs(interfaces); 135 | return name; 136 | } 137 | 138 | #endif // DMLC_USE_RDMA 139 | } // namespace ps 140 | -------------------------------------------------------------------------------- /tests/fserver/test_fserver_dynamic.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import time 3 | import fserver_lib as f 4 | import random 5 | 6 | is_worker = os.environ.get('DMLC_ROLE') == 'worker' 7 | is_server = os.environ.get('DMLC_ROLE') == 'server' 8 | server_count = int(os.environ.get('DMLC_NUM_SERVER','1')) 9 | worker_count = int(os.environ.get('DMLC_NUM_WORKER','1')) 10 | 11 | iter_count = 2000 12 | 13 | max_num_tokens = 128 14 | 15 | if is_worker: 16 | f.init() 17 | f.barrier(False, True) 18 | gpu = os.environ.get('STEPMESH_GPU') 19 | rank = int(os.environ.get('DMLC_NODE_RANK', '0')) 20 | print(f"worker gpu: {gpu}") 21 | 22 | recv_buffer: list[list[torch.Tensor]] = [ 23 | [ 24 | torch.empty((max_num_tokens ,7196), dtype=torch.bfloat16, device=torch.device('cuda')).contiguous() for _ in range(server_count) 25 | ] for _ in range(3) 26 | ] 27 | send_buffer: list[torch.Tensor] = [ 28 | torch.empty((max_num_tokens ,7196), dtype=torch.bfloat16, device=torch.device('cuda')).contiguous() 29 | for _ in range(3) 30 | ] 31 | 32 | for i in range(iter_count): 33 | h = [] 34 | pslist_cost = [] 35 | recv = [] 36 | cnt = i%1000 37 | for stage in range(3): 38 | send_buffer[stage].random_() 39 | if i == 0: 40 | rand_width = max_num_tokens 41 | else: 42 | rand_width = random.randint(1, max_num_tokens) 43 | data = torch.rand([rand_width, 7196], dtype=torch.bfloat16, device=torch.device('cuda')) 44 | 45 | push_tensors = [send_buffer[stage][:rand_width,:]] 46 | 47 | push_tensors[0].copy_(data) 48 | 49 | pull_tensors = [recv_buffer[stage][i][:rand_width,:] for i in range(server_count)] 50 | print(f"buffer 0x{send_buffer[stage].data_ptr():x}, tensor 0x{push_tensors[0].data_ptr():x}") 51 | key = stage + int(1e6) 52 | handler =f.push_pull( 53 | push_tensors, 54 | [key], 55 | pull_tensors, 56 | [1000 + stage + i for i in range(server_count)], 57 | ) 58 | 59 | beg = time.perf_counter_ns() 60 | f.wait(handler, 1000 * 500) 61 | end = time.perf_counter_ns() 62 | costs = f.fetch_trace(handler) 63 | pslist_cost.append(costs) 64 | recv.append((end-beg)) 65 | 66 | # print(f"wait time: {wait_time}") 67 | # print(f"Client, 
iter: {i}, stage: {stage}, handler: {handler}; shape {push_tensors[0].shape}, buffer 0x{send_buffer[stage].data_ptr():x}, tensor 0x{push_tensors[0].data_ptr():x}, Match : {torch.allclose(push_tensors[0], pull_tensors[0])}") 68 | print("\n") 69 | print("worker test done") 70 | f.stop() 71 | 72 | elif is_server: 73 | f.init() 74 | gpu = os.environ.get('STEPMESH_GPU') 75 | ret_buffer = torch.rand([8 * max_num_tokens,7196], dtype=torch.bfloat16, device=f'cuda:{gpu}') 76 | f.barrier(True, False) 77 | batches = [] 78 | for i in range(iter_count * 3): 79 | batches = f.get_batch() 80 | if len(batches) != 0: 81 | # recv_tensor_list = batches[0][1] 82 | # comm_id_list = [batches[0][0]] 83 | # f.respond_vec(ret_buffer, recv_tensor_list, comm_id_list) 84 | # buff = ret_buffer[:batches[0][1][0].size(0),:].contiguous() 85 | # buff.copy_(batches[0][1][0]) 86 | # print(batches) 87 | # print(f"Server iter: {i//3}, stage: {i%3}; buff 0x{ret_buffer.data_ptr():x}, tensor 0x{buff.data_ptr():x},") 88 | for j in range(worker_count): 89 | f.respond(batches[j][1], batches[j][0], need_event=True) 90 | f.stop() 91 | 92 | else: 93 | f.init() 94 | f.stop() 95 | 96 | 97 | -------------------------------------------------------------------------------- /include/ps/internal/backend.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 3 | */ 4 | #ifndef PS_INTERNAL_BACKEND_H_ 5 | #define PS_INTERNAL_BACKEND_H_ 6 | 7 | #include 8 | #ifdef DMLC_USE_CUDA 9 | #include 10 | #include 11 | #endif 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "dmlc/logging.h" 20 | #include "ps/internal/env.h" 21 | 22 | namespace ps { 23 | 24 | enum { BACKEND_OK = 0, BACKEND_FAILED = -1 }; 25 | 26 | /** 27 | * \brief Abstract Backend Class 28 | */ 29 | class Backend { 30 | public: 31 | /** 32 | * \brief Set device index for current thread 33 | * @param dev device index 34 | * @return BACKEND_OK if success to set device, otherwise BACKEND_FAILED 35 | */ 36 | virtual int SetDevice(int dev) = 0; 37 | 38 | /** 39 | * \brief Get device index for current thread 40 | * @return device index 41 | */ 42 | virtual int GetDeviceId() = 0; 43 | 44 | /** 45 | * \brief Get the torch device of current device 46 | * @return torch device 47 | */ 48 | virtual at::Device GetDevice() = 0; 49 | 50 | /** 51 | * \brief Alloc memory over the device 52 | * @param size size to alloc 53 | * @return nullptr if failed to alloc, otherwise the memory pointer 54 | */ 55 | virtual void* Alloc(uint64_t size) = 0; 56 | 57 | /** 58 | * \brief Free memory allocated over device 59 | * @param m 60 | */ 61 | virtual void Free(void* m) = 0; 62 | 63 | /** 64 | * \brief Create stream event 65 | * @return nullptr or the event pointer 66 | */ 67 | virtual void* CreateEvent() = 0; 68 | 69 | /** 70 | * \brief Free the event 71 | * @return BACKEND_OK if succeed to free the event, otherwise BACKEND_FAILED 72 | */ 73 | virtual int FreeEvent(void* event) = 0; 74 | 75 | /** 76 | *\brief Record the event over the stream 77 | * @param event the created event 78 | * @param stream user designated stream, can be nullptr (using default stream) 79 | * @return BACKEND_OK if succeed to record the event, otherwise BACKEND_FAILED 80 | */ 81 | virtual int RecordEvent(void* event, void* stream) = 0; 82 | 83 | /** 84 | *\brief Sync and wait the event 85 | * @param event the created event 86 | * @return BACKEND_OK if succeed to synchronize the event, 87 | * otherwise 
BACKEND_FAILED 88 | */ 89 | virtual int SyncEvent(void* event) = 0; 90 | 91 | /** 92 | * \brief Get the backend implementation 93 | * @return the backend implementation 94 | */ 95 | static inline Backend* Get() { return GetImpl(); } 96 | 97 | static void Register(const std::string& name, Backend* backend) { 98 | RegisterImpl(name, backend); 99 | } 100 | 101 | protected: 102 | Backend() = default; 103 | 104 | private: 105 | static std::mutex backends_mutex_; 106 | static std::unordered_map<std::string, Backend*> backends_; 107 | 108 | static Backend* GetImpl() { 109 | static Backend* backend_impl = nullptr; 110 | if (backend_impl == nullptr) { 111 | std::unique_lock<std::mutex> lock(backends_mutex_); 112 | std::string backend_type = "GPU"; 113 | backend_type = Environment::Get()->find("STEPMESH_BAKCEND", backend_type); 114 | PS_CHECK_NE(backends_.find(backend_type), backends_.end()) 115 | << "failed to get backend impl: " << backend_type; 116 | backend_impl = backends_[backend_type]; 117 | } 118 | return backend_impl; 119 | } 120 | 121 | static void RegisterImpl(const std::string& name, Backend* backend) { 122 | std::unique_lock<std::mutex> lock(backends_mutex_); 123 | backends_[name] = backend; 124 | } 125 | }; 126 | 127 | } // namespace ps 128 | 129 | #endif // PS_INTERNAL_BACKEND_H_ 130 | -------------------------------------------------------------------------------- /include/ps/internal/customer.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef PS_INTERNAL_CUSTOMER_H_ 5 | #define PS_INTERNAL_CUSTOMER_H_ 6 | #include <atomic> 7 | #include <condition_variable> 8 | #include <functional> 9 | #include <memory> 10 | #include <mutex> 11 | #include <string> 12 | #include <thread> 13 | #include <utility> 14 | #include <vector> 15 | 16 | #include "ps/hash_table8.hpp" 17 | #include "ps/internal/message.h" 18 | #include "ps/internal/threadsafe_queue.h" 19 | namespace ps { 20 | /** 21 | * \brief The object for communication. 22 | * 23 | * As a sender, a customer tracks the responses for each request sent. 24 | * 25 | * It has its own receiving thread which is able to process any message received 26 | * from a remote node with `msg.meta.customer_id` equal to this customer's id 27 | */ 28 | 29 | class Postoffice; 30 | 31 | struct CustomerTracker { 32 | int count; 33 | std::atomic<int> response_count; 34 | int response_count_cache; 35 | struct Trace request; 36 | struct Trace response; 37 | uint64_t start_time; 38 | bool done = false; 39 | }; 40 | 41 | class Customer { 42 | public: 43 | /** 44 | * \brief the handle for a received message 45 | * \param recved the received message 46 | */ 47 | using RecvHandle = std::function<void(const Message& recved)>; 48 | 49 | /** 50 | * \brief constructor 51 | * \param app_id the globally unique id indicating the application the 52 | * postoffice serves 53 | * \param customer_id the locally unique id indicating 54 | * the customer of a postoffice 55 | * \param recv_handle the function for processing 56 | * a received message 57 | */ 58 | Customer(int app_id, int customer_id, const RecvHandle& recv_handle, 59 | Postoffice* postoffice); 60 | 61 | /** 62 | * \brief destructor 63 | */ 64 | ~Customer(); 65 | 66 | /** 67 | * \brief return the globally unique application id 68 | */ 69 | inline int app_id() { return app_id_; } 70 | 71 | /** 72 | * \brief return the locally unique customer id 73 | */ 74 | inline int customer_id() { return customer_id_; } 75 | 76 | /** 77 | * \brief get a timestamp for a new request. 
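The Backend registry above resolves the implementation once per process: the first Backend::Get() reads the backend type from the environment and caches the pointer, so Register() must run before that first lookup. Below is a minimal sketch of a host-memory backend plugging into this interface; only the Backend base class comes from backend.h, while the HostBackend name, its no-op event handling, and the static-initializer registration are illustrative assumptions rather than part of StepMesh.

#include <cstdlib>
#include "ps/internal/backend.h"

namespace ps {
// Hypothetical host-memory backend: malloc/free for buffers, and since the
// host path has no streams, events degenerate to no-ops.
class HostBackend : public Backend {
 public:
  int SetDevice(int dev) override { dev_ = dev; return BACKEND_OK; }
  int GetDeviceId() override { return dev_; }
  at::Device GetDevice() override { return at::Device(at::kCPU); }
  void* Alloc(uint64_t size) override { return std::malloc(size); }
  void Free(void* m) override { std::free(m); }
  void* CreateEvent() override { return this; }  // dummy non-null handle
  int FreeEvent(void* event) override { return BACKEND_OK; }
  int RecordEvent(void* event, void* stream) override { return BACKEND_OK; }
  int SyncEvent(void* event) override { return BACKEND_OK; }

 private:
  int dev_ = 0;
};

// Register during static initialization, before the first Backend::Get().
static HostBackend g_host_backend;
static const bool g_host_backend_registered =
    (Backend::Register("HOST", &g_host_backend), true);
}  // namespace ps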
threadsafe 78 | * \param recver the receive node id of this request 79 | * \return the timestamp of this request 80 | */ 81 | int NewRequest(int recver); 82 | 83 | /** 84 | * \brief wait until the request is finished. threadsafe 85 | * \param timestamp the timestamp of the request 86 | */ 87 | void WaitRequest(int timestamp, uint64_t timeout_ms = 10000); 88 | 89 | /** 90 | * \brief return the number of responses received for the request. threadsafe 91 | * \param timestamp the timestamp of the request 92 | */ 93 | int NumResponse(int timestamp); 94 | 95 | /** 96 | * \brief add a number of responses to timestamp 97 | */ 98 | void AddResponse(int timestamp, int num = 1); 99 | 100 | /** 101 | * \brief accept a received message from \ref Van. threadsafe 102 | * \param recved the received the message 103 | */ 104 | inline void Accept(const Message& recved) { recv_queue_.Push(recved); } 105 | 106 | void DirectProcess(Message& recv); 107 | 108 | std::pair FetchTrace(int timestamp); 109 | 110 | private: 111 | /** 112 | * \brief the thread function 113 | */ 114 | void Receiving(); 115 | 116 | int app_id_; 117 | 118 | int customer_id_; 119 | 120 | RecvHandle recv_handle_; 121 | Postoffice* postoffice_; 122 | 123 | ThreadsafeQueue recv_queue_; 124 | std::unique_ptr recv_thread_; 125 | 126 | std::mutex tracker_mu_; 127 | std::condition_variable tracker_cond_; 128 | std::vector tracker_; 129 | DISALLOW_COPY_AND_ASSIGN(Customer); 130 | }; 131 | 132 | } // namespace ps 133 | #endif // PS_INTERNAL_CUSTOMER_H_ 134 | -------------------------------------------------------------------------------- /src/ibvwarp.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) by StepAI Contributors. 2025. 3 | */ 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #ifndef IBVWARP_H_ 15 | #define IBVWARP_H_ 16 | namespace ps { 17 | 18 | // Attempt to load a specific symbol version - fail silently 19 | #define LOAD_SYM_VERSION(handle, symbol, funcptr, version) \ 20 | do { \ 21 | cast = reinterpret_cast(&funcptr); \ 22 | *cast = dlvsym(handle, symbol, version); \ 23 | } while (0) 24 | 25 | #define IBV_INT_PS_CHECK_RET_ERRNO(container, internal_name, call, \ 26 | success_retval) \ 27 | PS_CHECK_NOT_NULL(container, internal_name); \ 28 | int ret = container.call; \ 29 | if (ret != success_retval) { \ 30 | PS_LOG(WARNING) << "call to " << #internal_name << " failed with error (" \ 31 | << strerror(errno) << ")"; \ 32 | return -1; \ 33 | } \ 34 | return 1; 35 | 36 | /* PS_CHECK_NOT_NULL: helper macro to check for NULL symbol */ 37 | #define PS_CHECK_NOT_NULL(container, internal_name) \ 38 | if (container.internal_name == NULL) { \ 39 | PS_LOG(WARNING) << "lib wrapper not initialized."; \ 40 | return -1; \ 41 | } 42 | 43 | struct dmlcMlx5dvSymbols { 44 | int (*mlx5dv_internal_query_qp_lag_port)(struct ibv_qp* qp, uint8_t* port_num, 45 | uint8_t* active_port_num); 46 | int (*mlx5dv_internal_modify_qp_lag_port)(struct ibv_qp* qp, 47 | uint8_t port_num); 48 | } mlx5dvSymbols; 49 | 50 | int buildMlx5dvSymbols(struct dmlcMlx5dvSymbols* mlx5dvSymbols) { 51 | static void* mlx5dvhandle = NULL; 52 | void** cast; 53 | 54 | mlx5dvhandle = dlopen("libmlx5.so", RTLD_NOW); 55 | if (!mlx5dvhandle) { 56 | mlx5dvhandle = dlopen("libmlx5.so.1", RTLD_NOW); 57 | if (!mlx5dvhandle) { 58 | printf("Failed to open libmlx5.so[.1]"); 59 | goto teardown; 60 | } 61 | } 62 | 63 | LOAD_SYM_VERSION(mlx5dvhandle, 
"mlx5dv_query_qp_lag_port", 64 | mlx5dvSymbols->mlx5dv_internal_query_qp_lag_port, 65 | "MLX5_1.14"); 66 | LOAD_SYM_VERSION(mlx5dvhandle, "mlx5dv_modify_qp_lag_port", 67 | mlx5dvSymbols->mlx5dv_internal_modify_qp_lag_port, 68 | "MLX5_1.14"); 69 | return 1; 70 | 71 | teardown: 72 | mlx5dvSymbols->mlx5dv_internal_query_qp_lag_port = NULL; 73 | mlx5dvSymbols->mlx5dv_internal_modify_qp_lag_port = NULL; 74 | if (mlx5dvhandle != NULL) dlclose(mlx5dvhandle); 75 | return -1; 76 | } 77 | 78 | int wrap_mlx5dv_query_qp_lag_port(struct ibv_qp* qp, uint8_t* port_num, 79 | uint8_t* active_port_num) { 80 | IBV_INT_PS_CHECK_RET_ERRNO( 81 | mlx5dvSymbols, mlx5dv_internal_query_qp_lag_port, 82 | mlx5dv_internal_query_qp_lag_port(qp, port_num, active_port_num), 0); 83 | } 84 | 85 | int wrap_mlx5dv_modify_qp_lag_port(struct ibv_qp* qp, uint8_t port_num) { 86 | IBV_INT_PS_CHECK_RET_ERRNO(mlx5dvSymbols, mlx5dv_internal_modify_qp_lag_port, 87 | mlx5dv_internal_modify_qp_lag_port(qp, port_num), 88 | 0); 89 | } 90 | 91 | static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; 92 | int initResult = -1; 93 | int wrap_ibv_symbols(void) { 94 | pthread_once(&initOnceControl, 95 | []() { initResult = buildMlx5dvSymbols(&mlx5dvSymbols); }); 96 | return initResult; 97 | } 98 | 99 | } // namespace ps 100 | 101 | #endif // IBVWARP_H_ 102 | -------------------------------------------------------------------------------- /src/van_common.h: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Amazon Web Services Inc. or its affiliates. All Rights 2 | // Reserved. 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // ============================================================================= 16 | #ifndef VAN_COMMON_H_ 17 | #define VAN_COMMON_H_ 18 | 19 | #include 20 | 21 | #if defined(DMLC_USE_RDMA) || defined(DMLC_USE_FABRIC) 22 | #include 23 | 24 | #define DIVUP(x, y) (((x) + (y)-1) / (y)) 25 | #define ROUNDUP(x, y) (DIVUP((x), (y)) * (y)) 26 | 27 | namespace ps { 28 | 29 | static const int kMaxDataFields = 4; 30 | static const int kMaxAddressEntries = 10240; 31 | 32 | template 33 | static inline T align_floor(T v, T align) { 34 | return v - (v % align); 35 | } 36 | 37 | template 38 | static inline T align_ceil(T v, T align) { 39 | return align_floor(v + align - 1, align); 40 | } 41 | 42 | enum MessageTypes : uint32_t { 43 | kRendezvousStart, 44 | kRendezvousReply, 45 | }; 46 | 47 | static inline void aligned_malloc(void **ptr, size_t size) { 48 | size_t page_size = sysconf(_SC_PAGESIZE); 49 | void *p; 50 | int size_aligned = ROUNDUP(size, page_size); 51 | int ret = posix_memalign(&p, page_size, size_aligned); 52 | PS_CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); 53 | PS_CHECK(p); 54 | memset(p, 0, size); 55 | *ptr = p; 56 | } 57 | 58 | bool IsValidPushpull(const Message &msg) { 59 | if (!msg.meta.control.empty()) return false; 60 | if (msg.meta.simple_app) return false; 61 | return true; 62 | } 63 | 64 | // just a translation, the decoded key might not be 65 | // readable when we have multiple servers 66 | uint64_t DecodeKey(SArray keys) { 67 | ps::Key key = 0; 68 | uint64_t coef = 1; 69 | for (unsigned int i = 0; i < keys.size(); ++i) { 70 | key += coef * (uint8_t)keys.data()[i]; 71 | coef *= 256; // 256=2^8 (uint8_t) 72 | } 73 | return key; 74 | } 75 | 76 | template 77 | class AddressPool { 78 | public: 79 | AddressPool() { 80 | auto addrpool_size = Environment::Get()->find("BYTEPS_ADDRESS_POOL_SIZE"); 81 | kMaxEntries = addrpool_size ? 
atoi(addrpool_size) : kMaxEntries; 82 | std::lock_guard lk(mu_); 83 | table_ = new T *[kMaxEntries]; 84 | // init the queue 85 | for (int i = 0; i < kMaxEntries; i++) { 86 | indices_.push(i); 87 | table_[i] = nullptr; 88 | } 89 | } 90 | 91 | T *GetAddressAndRelease(uint32_t index) { 92 | std::lock_guard lk(mu_); 93 | T *ptr = table_[index]; 94 | PS_CHECK(ptr); 95 | indices_.push(index); 96 | table_[index] = nullptr; 97 | return ptr; 98 | } 99 | 100 | // TODO(none): make the address pool size dynamic 101 | T *GetAddress(uint32_t index) { 102 | std::lock_guard lk(mu_); 103 | return PS_CHECK_NOTNULL(table_[index]); 104 | } 105 | 106 | uint32_t StoreAddress(T *ptr) { 107 | std::lock_guard lk(mu_); 108 | PS_CHECK(ptr); 109 | PS_CHECK(!indices_.empty()) 110 | << "Address pool size is too small, " 111 | << "current size is " << kMaxEntries 112 | << ", consider increasing BYTEPS_ADDRESS_POOL_SIZE"; 113 | uint32_t idx = indices_.front(); 114 | indices_.pop(); 115 | PS_CHECK_EQ(table_[idx], nullptr) << idx; 116 | table_[idx] = ptr; 117 | return idx; 118 | } 119 | 120 | private: 121 | int kMaxEntries = kMaxAddressEntries; 122 | 123 | std::mutex mu_; 124 | std::queue indices_; 125 | T **table_; 126 | }; 127 | 128 | }; // namespace ps 129 | 130 | #endif // DMLC_USE_RDMA || DMLC_USE_FABRIC 131 | #endif // VAN_COMMON_H_ 132 | -------------------------------------------------------------------------------- /tests/vllm/ffn.py: -------------------------------------------------------------------------------- 1 | import os 2 | from vllm.config import AFDConfig, VllmConfig,ParallelConfig 3 | from stepmesh_connector import StepMeshAFDConnector 4 | import torch 5 | import torch.profiler 6 | import time 7 | from bind_pid import set_numa_affinity, bind_pid 8 | from cycle import get_cycles_per_ms 9 | from stepmesh_connector import StepMeshAFDConnector,AFDConnectorMetadata 10 | 11 | import numpy as np 12 | 13 | os.environ['STEPMESH_BIND_CPU_CORE']='1' 14 | os.environ['STEPMESH_CONNECTOR_DEBUG']='true' 15 | os.environ['STEPMESH_SPLIT_QP_LAG']='1' 16 | 17 | ''' 18 | export STEPMESH_BIND_CPU_CORE=1 19 | export STEPMESH_CONNECTOR_DEBUG=true 20 | export STEPMESH_SPLIT_QP_LAG=1 21 | export VLLM_TORCH_PROFILER_DIR=prof 22 | ''' 23 | 24 | ip="10.203.8.15" 25 | 26 | cycle_per_ms = get_cycles_per_ms() 27 | 28 | rank = int(os.environ.get("RANK", 0)) 29 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 30 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 31 | local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) 32 | node_rank = rank // local_world_size 33 | 34 | # bind_pid(os.getpid(), local_rank) 35 | set_numa_affinity(local_rank) 36 | afd_config = AFDConfig( 37 | afd_connector="stepmesh", 38 | afd_role="ffn", 39 | afd_port=1239, 40 | afd_host=f"{ip}", 41 | num_afd_stages=3, 42 | num_attention_servers=1, 43 | num_ffn_servers=1, 44 | afd_server_rank=node_rank, 45 | ) 46 | parallel_config = ParallelConfig( 47 | tensor_parallel_size=8, 48 | pipeline_parallel_size=1, 49 | data_parallel_size=1, 50 | ) 51 | vllm_config = VllmConfig( 52 | afd_config=afd_config, 53 | parallel_config=parallel_config, 54 | ) 55 | connector = StepMeshAFDConnector( 56 | rank=rank, 57 | local_rank=local_rank, 58 | config=vllm_config 59 | ) 60 | torch.cuda.set_device(local_rank) 61 | time.sleep(5) 62 | connector.init_afd_connector() 63 | # set_numa_affinity(local_rank) 64 | import fserver_lib as ps 65 | ret_buffer = torch.rand([65535, 7168], dtype=torch.bfloat16, device='cuda') 66 | 67 | 68 | s = torch.cuda.Stream() 69 | 70 | if __name__ == 
"__main__": 71 | counter = 0 72 | profiler = None 73 | while True: 74 | counter += 1 75 | if counter % 1000 == 0: 76 | print(f"Respond {rank} counter {counter}") 77 | 78 | # 在counter为10000~11100时启用torch profiler,包含100轮warmup + 1000轮active记录 79 | if counter == 20000: 80 | profiler = torch.profiler.profile( 81 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 82 | schedule=torch.profiler.schedule(wait=0, warmup=100, active=1000), 83 | on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./profiler_logs/rank_{rank}', use_gzip=True), 84 | record_shapes=True, 85 | with_stack=True, 86 | experimental_config=torch.profiler._ExperimentalConfig( 87 | verbose=True, # 启用详细日志 88 | enable_cuda_sync_events=True # 启用CUDA同步事件跟踪 89 | ) 90 | ) 91 | profiler.start() 92 | print(f"Rank {rank}: Started profiler at counter {counter}, will warmup 100 steps then record 1000 steps with gzip compression") 93 | 94 | if counter >= 20000 and counter <= 21099: 95 | profiler.step() 96 | 97 | if counter == 21099: 98 | # profiler会在active阶段结束时自动停止并保存,无需手动stop() 99 | print(f"Rank {rank}: Profiler completed at counter {counter}, recorded 100 warmup + 1000 active steps") 100 | profiler = None 101 | 102 | with torch.cuda.stream(s): 103 | batches = ps.get_batch() 104 | if len(batches) != 0: 105 | recv_tensor_list = [batches[i][1][0] for i in range(1)] 106 | comm_id_list = [batches[i][0] for i in range(1)] 107 | # comm.all_gather(allgather_input_buffer) 108 | torch.cuda._sleep(int(cycle_per_ms * 0.26)) 109 | ps.respond_vec(ret_buffer, recv_tensor_list, comm_id_list) 110 | # if counter % (1830*5) == 0: 111 | # connector.print_trace() -------------------------------------------------------------------------------- /tracker/dmlc_ssh.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | DMLC submission script by ssh 4 | 5 | One need to make sure all slaves machines are ssh-able. 
6 | """ 7 | 8 | import argparse 9 | import sys 10 | import os 11 | import subprocess 12 | import tracker 13 | import logging 14 | from threading import Thread 15 | 16 | class SSHLauncher(object): 17 | def __init__(self, args, unknown): 18 | self.args = args 19 | self.cmd = (' '.join(args.command) + ' ' + ' '.join(unknown)) 20 | 21 | assert args.hostfile is not None 22 | with open(args.hostfile) as f: 23 | hosts = f.readlines() 24 | assert len(hosts) > 0 25 | self.hosts=[] 26 | for h in hosts: 27 | if len(h.strip()) > 0: 28 | self.hosts.append(h.strip()) 29 | 30 | def sync_dir(self, local_dir, slave_node, slave_dir): 31 | """ 32 | sync the working directory from root node into slave node 33 | """ 34 | remote = slave_node + ':' + slave_dir 35 | logging.info('rsync %s -> %s', local_dir, remote) 36 | 37 | # TODO uses multithread 38 | prog = 'rsync -az --rsh="ssh -o StrictHostKeyChecking=no" %s %s' % ( 39 | local_dir, remote) 40 | subprocess.check_call([prog], shell = True) 41 | 42 | 43 | def get_env(self, pass_envs): 44 | envs = [] 45 | # get system envs 46 | keys = ['LD_LIBRARY_PATH', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'] 47 | for k in keys: 48 | v = os.getenv(k) 49 | if v is not None: 50 | envs.append('export ' + k + '=' + v + ';') 51 | # get ass_envs 52 | for k, v in pass_envs.items(): 53 | envs.append('export ' + str(k) + '=' + str(v) + ';') 54 | return (' '.join(envs)) 55 | 56 | def submit(self): 57 | def ssh_submit(nworker, nserver, pass_envs): 58 | """ 59 | customized submit script 60 | """ 61 | # thread func to run the job 62 | def run(prog): 63 | subprocess.check_call(prog, shell = True) 64 | 65 | # sync programs if necessary 66 | local_dir = os.getcwd()+'/' 67 | working_dir = local_dir 68 | if self.args.sync_dir is not None: 69 | working_dir = self.args.sync_dir 70 | for h in self.hosts: 71 | self.sync_dir(local_dir, h, working_dir) 72 | 73 | # launch jobs 74 | for i in range(nworker + nserver): 75 | pass_envs['DMLC_ROLE'] = 'server' if i < nserver else 'worker' 76 | node = self.hosts[i % len(self.hosts)] 77 | prog = self.get_env(pass_envs) + ' cd ' + working_dir + '; ' + self.cmd 78 | prog = 'ssh -o StrictHostKeyChecking=no ' + node + ' \'' + prog + '\'' 79 | 80 | thread = Thread(target = run, args=(prog,)) 81 | thread.setDaemon(True) 82 | thread.start() 83 | 84 | return ssh_submit 85 | 86 | def run(self): 87 | tracker.config_logger(self.args) 88 | tracker.submit(self.args.num_workers, 89 | self.args.num_servers, 90 | fun_submit = self.submit(), 91 | pscmd = self.cmd) 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser(description='DMLC script to submit dmlc job using ssh') 95 | parser.add_argument('-n', '--num-workers', required=True, type=int, 96 | help = 'number of worker nodes to be launched') 97 | parser.add_argument('-s', '--num-servers', default = 0, type=int, 98 | help = 'number of server nodes to be launched') 99 | parser.add_argument('-H', '--hostfile', type=str, 100 | help = 'the hostfile of all slave nodes') 101 | parser.add_argument('command', nargs='+', 102 | help = 'command for dmlc program') 103 | parser.add_argument('--sync-dir', type=str, 104 | help = 'if specificed, it will sync the current \ 105 | directory into slave machines\'s SYNC_DIR') 106 | 107 | args, unknown = parser.parse_known_args() 108 | 109 | launcher = SSHLauncher(args, unknown) 110 | launcher.run() 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /tests/vllm/attn.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from vllm.config import AFDConfig, VllmConfig, ParallelConfig 3 | from stepmesh_connector import StepMeshAFDConnector, AFDConnectorMetadata 4 | import torch 5 | import time 6 | import numpy as np 7 | import fserver_lib as ps 8 | from cycle import get_cycles_per_ms 9 | from bind_pid import set_numa_affinity 10 | import torch.profiler 11 | 12 | 13 | os.environ['STEPMESH_BIND_CPU_CORE']='1' 14 | os.environ['STEPMESH_CONNECTOR_DEBUG']='true' 15 | os.environ['STEPMESH_SPLIT_QP_LAG']='1' 16 | 17 | ip = "10.203.8.15" 18 | 19 | cycle_per_ms = get_cycles_per_ms() 20 | 21 | rank = int(os.environ.get("RANK", 0)) 22 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 23 | afd_config = AFDConfig( 24 | afd_connector="stepmesh", 25 | afd_role="attention", 26 | afd_port=1239, 27 | afd_host=f"{ip}", 28 | num_afd_stages=3, 29 | num_attention_servers=1, 30 | num_ffn_servers=1, 31 | ) 32 | parallel_config = ParallelConfig( 33 | tensor_parallel_size=1, 34 | pipeline_parallel_size=1, 35 | data_parallel_size=8, 36 | ) 37 | vllm_config = VllmConfig( 38 | afd_config=afd_config, 39 | parallel_config=parallel_config, 40 | ) 41 | connector = StepMeshAFDConnector( 42 | rank=rank, 43 | local_rank=local_rank, 44 | config=vllm_config 45 | ) 46 | torch.cuda.set_device(local_rank) 47 | set_numa_affinity(local_rank) 48 | time.sleep(5) 49 | connector.init_afd_connector() 50 | print(f"--------- rank {rank} local_rank {local_rank} init ---------") 51 | 52 | if __name__ == "__main__": 53 | counter = 0 54 | s = torch.cuda.Stream() 55 | torch.cuda.set_stream(s) 56 | profiler = None 57 | hidden_states = [torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda") for i in range(afd_config.num_afd_stages)] 58 | while True: 59 | if counter % (1830*2) == 0: 60 | connector.print_trace() 61 | 62 | for layer_idx in range(61): 63 | for stage_idx in range(afd_config.num_afd_stages): 64 | counter += 1 65 | if layer_idx > 0: 66 | connector.recv_ffn_output() 67 | 68 | # torch.cuda._sleep(int(cycle_per_ms * 0.1)) 69 | time.sleep(0.0002) 70 | # cpu sleep ~200us 71 | connector.send_attn_output( 72 | hidden_states[stage_idx], 73 | AFDConnectorMetadata.create_attention_metadata( 74 | layer_idx, 75 | stage_idx, 76 | hidden_states[stage_idx].shape[0], 77 | hidden_states[stage_idx].dtype, 78 | hidden_states[stage_idx].device, 79 | ) 80 | ) 81 | 82 | 83 | if counter == 14000: 84 | profiler = torch.profiler.profile( 85 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 86 | schedule=torch.profiler.schedule(wait=0, warmup=100, active=1000), 87 | on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./profiler_logs/rank_{rank}', use_gzip=True), 88 | record_shapes=True, 89 | with_stack=True, 90 | # experimental_config=torch.profiler._ExperimentalConfig( 91 | # verbose=True, # enable verbose logging 92 | # enable_cuda_sync_events=True # enable CUDA synchronization event tracing 93 | # ) 94 | ) 95 | profiler.start() 96 | print(f"Rank {rank}: Started profiler at counter {counter}, will warmup 100 steps then record 1000 steps with gzip compression") 97 | 98 | if counter >= 14000 and counter <= 15099: 99 | profiler.step() 100 | 101 | if counter == 15099: 102 | # The profiler stops and saves automatically at the end of the active phase; no manual stop() is needed 103 | print(f"Rank {rank}: Profiler completed at counter {counter}, recorded 100 warmup + 1000 active steps") 104 | profiler = None 105 | 106 | for i in range(afd_config.num_afd_stages): 107 | connector.recv_ffn_output() 108 | time.sleep(0.01) 109 | torch.cuda.synchronize() 
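The layer/stage loop above keeps exactly num_afd_stages requests in flight: each stage slot posts one send per layer, and from the second layer on the worker first drains the reply that slot issued one layer earlier. The same schedule in isolation, as a sketch assuming a connector with the send/recv methods used above and a make_meta callback standing in for AFDConnectorMetadata.create_attention_metadata:

def run_layers(connector, hidden_states, num_layers, num_stages, make_meta):
    # Steady state: drain one FFN reply, then post one attention send, per
    # stage slot, so at most num_stages requests are outstanding.
    for layer in range(num_layers):
        for stage in range(num_stages):
            if layer > 0:
                connector.recv_ffn_output()
            connector.send_attn_output(hidden_states[stage],
                                       make_meta(layer, stage))
    for _ in range(num_stages):
        connector.recv_ffn_output()  # drain the pipeline tail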
110 | time.sleep(0.01) -------------------------------------------------------------------------------- /include/ps/internal/parallel_kv_match.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * \file parallel_kv_match.h 4 | * \brief paralle key-value pairs matching 5 | */ 6 | #ifndef PS_INTERNAL_PARALLEL_KV_MATCH_H_ 7 | #define PS_INTERNAL_PARALLEL_KV_MATCH_H_ 8 | #include 9 | #include 10 | #include "ps/sarray.h" 11 | #include "ps/internal/assign_op.h" 12 | 13 | namespace ps { 14 | namespace { 15 | /** 16 | * \brief thread function, internal use 17 | * 18 | * \param src_key start of source key 19 | * \param src_key_end end of source key 20 | * \param src_val start of source val 21 | * \param dst_key start of destination key 22 | * \param dst_key_end end of denstination key 23 | * \param dst_val start of destination val 24 | * \param k length of a single value 25 | * \param op assignment operator 26 | * \param grainsize thread grainsize size 27 | * \param n number of matched kv pairs 28 | */ 29 | template 30 | void ParallelOrderedMatch( 31 | const K* src_key, const K* src_key_end, const V* src_val, 32 | const K* dst_key, const K* dst_key_end, V* dst_val, 33 | int k, AsOp op, size_t grainsize, size_t* n) { 34 | size_t src_len = std::distance(src_key, src_key_end); 35 | size_t dst_len = std::distance(dst_key, dst_key_end); 36 | if (dst_len == 0 || src_len == 0) return; 37 | 38 | // drop the unmatched tail of src 39 | src_key = std::lower_bound(src_key, src_key_end, *dst_key); 40 | src_val += (src_key - (src_key_end - src_len)) * k; 41 | 42 | if (dst_len <= grainsize) { 43 | while (dst_key != dst_key_end && src_key != src_key_end) { 44 | if (*src_key < *dst_key) { 45 | ++src_key; src_val += k; 46 | } else { 47 | if (!(*dst_key < *src_key)) { 48 | for (int i = 0; i < k; ++i) { 49 | AssignOp(dst_val[i], src_val[i], op); 50 | } 51 | ++src_key; src_val += k; 52 | *n += k; 53 | } 54 | ++dst_key; dst_val += k; 55 | } 56 | } 57 | } else { 58 | std::thread thr( 59 | ParallelOrderedMatch, src_key, src_key_end, src_val, 60 | dst_key, dst_key + dst_len / 2, dst_val, 61 | k, op, grainsize, n); 62 | size_t m = 0; 63 | ParallelOrderedMatch( 64 | src_key, src_key_end, src_val, 65 | dst_key + dst_len / 2, dst_key_end, dst_val + (dst_len / 2) * k, 66 | k, op, grainsize, &m); 67 | thr.join(); 68 | *n += m; 69 | } 70 | } 71 | } // namespace 72 | 73 | 74 | /** 75 | * \brief Merge \a src_val into \a dst_val by matching keys. Keys must be unique 76 | * and sorted. 77 | * 78 | * \code 79 | * if (dst_key[i] == src_key[j]) { 80 | * dst_val[i] op= src_val[j] 81 | * } 82 | * \endcode 83 | * 84 | * When finished, \a dst_val will have length `k * dst_key.size()` and filled 85 | * with matched value. Umatched value will be untouched if exists or filled with 0. 86 | * 87 | * \tparam K type of key 88 | * \tparam V type of value 89 | * \tparam C type of the container such as \ref SArray or \ref std::vector 90 | * \param src_key the source keys 91 | * \param src_val the source values 92 | * \param dst_key the destination keys 93 | * \param dst_val the destination values. 
94 | * \param k the length of a single value (default is 1) 95 | * \param op the assignment operator (default is ASSIGN) 96 | * \param num_threads number of thread (default is 1) 97 | * \return the number of matched kv pairs 98 | */ 99 | template 100 | size_t ParallelOrderedMatch( 101 | const SArray& src_key, const SArray& src_val, 102 | const SArray& dst_key, C* dst_val, 103 | int k = 1, AssignOp op = ASSIGN, int num_threads = 1) { 104 | // do check 105 | PS_CHECK_GT(num_threads, 0); 106 | PS_CHECK_EQ(src_key.size() * k, src_val.size()); 107 | PS_CHECK_NOTNULL(dst_val->resize(dst_key.size() * k)); 108 | if (dst_key.empty()) return 0; 109 | 110 | // shorten the matching range 111 | Range range = FindRange(dst_key, src_key.begin(), src_key.end()); 112 | size_t grainsize = std::max(range.size() * k / num_threads + 5, 113 | static_cast(1024*1024)); 114 | size_t n = 0; 115 | ParallelOrderedMatch( 116 | src_key.begin(), src_key.end(), src_val.begin(), 117 | dst_key.begin() + range.begin(), dst_key.begin() + range.end(), 118 | dst_val->begin() + range.begin()*k, k, op, grainsize, &n); 119 | return n; 120 | } 121 | 122 | } // namespace ps 123 | #endif // PS_INTERNAL_PARALLEL_KV_MATCH_H_ 124 | -------------------------------------------------------------------------------- /tests/utests/test_common.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) 2025 by StepAI Contributors. 3 | */ 4 | #ifndef PS_TEST_COMMON_H_ 5 | #define PS_TEST_COMMON_H_ 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | #ifdef DMLC_USE_CUDA 18 | #include 19 | #include 20 | #endif 21 | 22 | #include "ps/af_tensor_app.h" 23 | 24 | #include "ps/ps.h" 25 | 26 | using namespace ps; 27 | 28 | static struct { 29 | int warmup_iter = 10; 30 | int iter = INT_MAX; 31 | int size = 7168; 32 | int debug = false; 33 | int gpu_num = 1; 34 | int gpu = 0; 35 | int mb_num = 3; 36 | int group_size; 37 | int node_rank ; 38 | std::string role_str; 39 | ps::Node::Role role; 40 | } g_conf; 41 | 42 | #define DIVUP(x, y) (((x)+(y)-1)/(y)) 43 | #define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) 44 | 45 | void InitConfig() { 46 | Environment::Get()->find("BENCHMARK_WARMUP", 47 | &g_conf.warmup_iter, g_conf.warmup_iter); 48 | Environment::Get()->find("BENCHMARK_ITER", &g_conf.iter, g_conf.iter); 49 | Environment::Get()->find("BENCHMARK_SIZE", &g_conf.size, g_conf.size); 50 | Environment::Get()->find("STEPMESH_GPU", &g_conf.gpu, g_conf.gpu); 51 | 52 | g_conf.gpu = GetGpuId(); 53 | g_conf.group_size = GetGroupSize(); 54 | g_conf.node_rank = GetNodeRank(); 55 | const char* val = PS_CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE")); 56 | g_conf.role_str = std::string(val); 57 | g_conf.role = GetRole(g_conf.role_str); 58 | g_conf.debug = Environment::Get()->find("DEBUG_MODE") != nullptr; 59 | } 60 | 61 | static inline void AlignedMalloc(void** ptr, size_t size) { 62 | size_t page_size = sysconf(_SC_PAGESIZE); 63 | void* p; 64 | int size_aligned = ROUNDUP(size, page_size); 65 | int ret = posix_memalign(&p, page_size, size_aligned); 66 | PS_CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); 67 | PS_CHECK(p); 68 | memset(p, 1, size); 69 | *ptr = p; 70 | } 71 | 72 | static inline std::vector GetPercentile( 73 | std::vector& vec, std::vector percentiles) { 74 | std::vector result; 75 | result.reserve(percentiles.size()); 76 | std::sort(vec.begin(), vec.end()); 77 | for (auto percentile : percentiles) { 78 | PS_CHECK(percentile 
>= 0 && percentile < 100); 79 | auto percentile_idx = int(vec.size() * percentile / 100) - 1; 80 | result.emplace_back(vec[percentile_idx]); 81 | } 82 | return result; 83 | } 84 | 85 | static inline int64_t GetMean(std::vector& vec) { 86 | PS_CHECK(!vec.empty()); 87 | int64_t result = 0; 88 | for (auto data : vec) { 89 | result += data; 90 | } 91 | return result / vec.size(); 92 | } 93 | 94 | static inline void DumpLatency(const std::string& head, std::vector& vec) { 95 | auto pull_mean = GetMean(vec); 96 | auto pull_percentile = GetPercentile(vec, {50, 90, 99}); 97 | LL << head << ": mean=" << pull_mean / 1000.0 98 | << "us, min=" << vec[0] / 1000.0 99 | << "us, p50=" << pull_percentile[0] / 1000.0 100 | << "us, p90=" << pull_percentile[1] / 1000.0 101 | << "us, p99=" << pull_percentile[2] / 1000.0 102 | << "us, max=" << vec[vec.size() - 1] / 1000.0 << "us"; 103 | } 104 | 105 | static inline at::Tensor CreateTensor( 106 | std::vector shape, 107 | at::ScalarType dtype, 108 | int gpu, 109 | bool random = false) { 110 | auto options = torch::TensorOptions() 111 | .dtype(dtype) 112 | .memory_format(at::MemoryFormat::Contiguous) 113 | .device(at::Device(at::kCUDA, gpu)); 114 | if (random) { 115 | return torch::rand(shape, options); 116 | } else { 117 | return torch::zeros(shape, options); 118 | } 119 | } 120 | 121 | static inline void StartServer( 122 | std::function 123 | func) { 124 | AFTensorServer* server = new AFTensorServer(g_conf.gpu); 125 | server->SetRequestHandle(func); 126 | RegisterExitCallback([server]() { delete server; }); 127 | } 128 | 129 | static inline void InitWorker(AFTensorWorker* kv) { 130 | Postoffice::GetWorker(g_conf.gpu)->Barrier(0, ps::kWorkerGroup); 131 | PS_LOG(INFO) << "finish worker init."; 132 | } 133 | 134 | static inline void StartScheduler() { 135 | PS_LOG(INFO) << "Scheduler starts"; 136 | StartPS(0, Node::SCHEDULER, -1, true); 137 | Finalize(0, Node::SCHEDULER, true); 138 | PS_LOG(INFO) << "Scheduler ends"; 139 | } 140 | 141 | #endif // PS_TEST_COMMON_H_ 142 | -------------------------------------------------------------------------------- /src/customer.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * Modifications Copyright (C) by StepAI Contributors. 2025. 
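The helpers above take latency samples in nanoseconds and report microseconds; GetPercentile sorts the vector in place, so DumpLatency both sorts and prints. A minimal sketch of the call pattern, assuming the vectors hold int64_t nanoseconds (consistent with GetMean returning int64_t) and with synthetic samples standing in for measured push/pull timings:

#include <cstdint>
#include <vector>

// Feed DumpLatency nanosecond samples; it prints mean/min/p50/p90/p99/max in us.
void ReportSyntheticLatencies() {
  std::vector<int64_t> lat_ns;
  for (int i = 0; i < 1000; ++i) {
    lat_ns.push_back(50 * 1000 + i * 10);  // ~50us, in nanoseconds
  }
  DumpLatency("pushpull", lat_ns);  // sorts lat_ns as a side effect
}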
4 | */ 5 | 6 | #include "ps/internal/customer.h" 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "ps/internal/postoffice.h" 18 | #include "ps/internal/threadsafe_queue.h" 19 | 20 | namespace ps { 21 | 22 | const int Node::kEmpty = std::numeric_limits::max(); 23 | const int Meta::kEmpty = std::numeric_limits::max(); 24 | const int kMaxSpinCount = 10000; 25 | 26 | Customer::Customer(int app_id, int customer_id, 27 | const Customer::RecvHandle& recv_handle, 28 | Postoffice* postoffice) 29 | : app_id_(app_id), 30 | customer_id_(customer_id), 31 | recv_handle_(recv_handle), 32 | postoffice_(postoffice) { 33 | postoffice_->AddCustomer(this); 34 | } 35 | 36 | Customer::~Customer() { postoffice_->RemoveCustomer(this); } 37 | 38 | int Customer::NewRequest(int recver) { 39 | PS_CHECK(recver == kServerGroup) << recver; 40 | std::lock_guard lk(tracker_mu_); 41 | // for push/pull requests, the worker only communication with one instance 42 | // from each server instance group 43 | int num = postoffice_->GetNodeIDs(recver).size() / postoffice_->group_size(); 44 | auto* t = new CustomerTracker(); 45 | t->count = num; 46 | t->response_count.store(0); 47 | t->response_count_cache = 0; 48 | t->start_time = GetNanosecond(false); 49 | tracker_.push_back(t); 50 | return tracker_.size() - 1; 51 | } 52 | 53 | void Customer::WaitRequest(int timestamp, uint64_t timeout_ms) { 54 | uint64_t timeout_ns = timeout_ms * 1000000; 55 | auto* req = tracker_[timestamp]; 56 | int spin_count = 0; 57 | while (req->count != req->response_count.load(std::memory_order_acquire)) { 58 | if (spin_count < kMaxSpinCount) { 59 | spin_count++; 60 | } else { 61 | _mm_pause(); 62 | uint64_t now = GetNanosecond(false); 63 | // 1s for timeout 64 | if (now - req->start_time > timeout_ns) { 65 | PS_LOG(FATAL) << "request timeout " << timeout_ms << "ms, handler " 66 | << timestamp << " " << (now - req->start_time) / 1000 67 | << "us" 68 | << " " 69 | << req->response_count.load(std::memory_order_acquire) 70 | << " " << req->count; 71 | } 72 | } 73 | } 74 | } 75 | 76 | int Customer::NumResponse(int timestamp) { 77 | // std::unique_lock lk(tracker_mu_); 78 | return tracker_[timestamp]->count; 79 | } 80 | 81 | void Customer::AddResponse(int timestamp, int num) { 82 | // std::unique_lock lk(tracker_mu_); 83 | tracker_[timestamp]->response_count.fetch_add(num, std::memory_order_release); 84 | } 85 | 86 | void Customer::Receiving() { 87 | while (true) { 88 | Message recv; 89 | recv_queue_.WaitAndPop(&recv); 90 | if (!recv.meta.control.empty() && 91 | recv.meta.control.cmd == Control::TERMINATE) { 92 | break; 93 | } 94 | recv_handle_(recv); 95 | if (!recv.meta.request) { 96 | auto t = tracker_[recv.meta.timestamp]; 97 | PS_CHECK_NE(t, nullptr) << "could not find tracker"; 98 | #ifdef STEPMESH_ENABLE_TRACE 99 | t->request = recv.meta.request_trace; 100 | t->response = recv.meta.response_trace; 101 | #endif // STEPMESH_ENABLE_TRACE 102 | t->response_count.fetch_add(1, std::memory_order_release); 103 | } 104 | } 105 | } 106 | 107 | void Customer::DirectProcess(Message& recv) { 108 | if (!recv.meta.control.empty() && 109 | recv.meta.control.cmd == Control::TERMINATE) { 110 | return; 111 | } 112 | recv_handle_(recv); 113 | 114 | if (!recv.meta.request) { 115 | auto t = tracker_[recv.meta.timestamp]; 116 | PS_CHECK_NE(t, nullptr) << "could not find tracker"; 117 | #ifdef STEPMESH_ENABLE_TRACE 118 | t->request = recv.meta.request_trace; 119 | t->response = recv.meta.response_trace; 
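WaitRequest above spins on response_count with memory_order_acquire while AddResponse and DirectProcess publish with memory_order_release; that pairing is what lets the waiter read the tracker's trace fields without holding tracker_mu_. The same handshake in isolation, as a minimal sketch (MiniTracker and its payload field are illustrative stand-ins, not StepMesh types):

#include <atomic>
#include <immintrin.h>  // _mm_pause

struct MiniTracker {
  int count = 1;
  std::atomic<int> response_count{0};
  int payload = 0;  // stands in for the request/response Trace fields
};

// Responder side: publish side effects, then bump the counter (release).
void Respond(MiniTracker* t) {
  t->payload = 42;  // written before the release increment
  t->response_count.fetch_add(1, std::memory_order_release);
}

// Waiter side: the acquire load pairs with the release increment above, so
// payload is guaranteed visible once the count matches.
void Wait(MiniTracker* t) {
  while (t->response_count.load(std::memory_order_acquire) < t->count) {
    _mm_pause();  // back off, as Customer::WaitRequest does after spinning
  }
}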
120 | #endif // STEPMESH_ENABLE_TRACE 121 | PS_VLOG(4) << "recv response " << recv.meta.timestamp << " " 122 | << recv.meta.DebugString(); 123 | t->response_count.fetch_add(1, std::memory_order_release); 124 | t->response_count_cache += 1; 125 | } 126 | } 127 | 128 | std::pair<Trace, Trace> Customer::FetchTrace(int timestamp) { 129 | #ifdef STEPMESH_ENABLE_TRACE 130 | std::unique_lock<std::mutex> lk(tracker_mu_); 131 | auto p = tracker_[timestamp]; 132 | return std::make_pair(p->request, p->response); 133 | #endif // STEPMESH_ENABLE_TRACE 134 | return std::make_pair(Trace(), Trace()); 135 | } 136 | 137 | } // namespace ps 138 | -------------------------------------------------------------------------------- /src/resender.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef PS_RESENDER_H_ 5 | #define PS_RESENDER_H_ 6 | #include <chrono> 7 | #include <thread> 8 | #include <unordered_map> 9 | #include <unordered_set> 10 | namespace ps { 11 | 12 | /** 13 | * \brief resend a message if no ack is received within a given time 14 | */ 15 | class Resender { 16 | public: 17 | /** 18 | * \param timeout timeout in milliseconds 19 | */ 20 | Resender(int timeout, int max_num_retry, Van* van) { 21 | timeout_ = timeout; 22 | max_num_retry_ = max_num_retry; 23 | van_ = van; 24 | monitor_ = new std::thread(&Resender::Monitoring, this); 25 | } 26 | ~Resender() { 27 | exit_ = true; 28 | monitor_->join(); 29 | delete monitor_; 30 | } 31 | 32 | /** 33 | * \brief add an outgoing message 34 | * 35 | */ 36 | void AddOutgoing(const Message& msg) { 37 | if (msg.meta.control.cmd == Control::ACK) return; 38 | PS_CHECK_NE(msg.meta.timestamp, Meta::kEmpty) << msg.DebugString(); 39 | auto key = GetKey(msg); 40 | std::lock_guard<std::mutex> lk(mu_); 41 | // already buffered, often because the monitor thread called Send 42 | if (send_buff_.find(key) != send_buff_.end()) return; 43 | 44 | auto& ent = send_buff_[key]; 45 | ent.msg = msg; 46 | ent.send = Now(); 47 | ent.num_retry = 0; 48 | } 49 | 50 | /** 51 | * \brief add an incoming message 52 | * \return true if msg has been added before or is an ACK message 53 | */ 54 | bool AddIncomming(const Message& msg) { 55 | // a message can be received multiple times 56 | if (msg.meta.control.cmd == Control::TERMINATE) { 57 | return false; 58 | } else if (msg.meta.control.cmd == Control::ACK) { 59 | mu_.lock(); 60 | auto key = msg.meta.control.msg_sig; 61 | auto it = send_buff_.find(key); 62 | if (it != send_buff_.end()) send_buff_.erase(it); 63 | mu_.unlock(); 64 | return true; 65 | } else { 66 | mu_.lock(); 67 | auto key = GetKey(msg); 68 | auto it = acked_.find(key); 69 | bool duplicated = it != acked_.end(); 70 | if (!duplicated) acked_.insert(key); 71 | mu_.unlock(); 72 | // send back ack message (even if it is duplicated) 73 | Message ack; 74 | ack.meta.recver = msg.meta.sender; 75 | ack.meta.sender = msg.meta.recver; 76 | ack.meta.control.cmd = Control::ACK; 77 | ack.meta.control.msg_sig = key; 78 | van_->Send(ack); 79 | // warning 80 | if (duplicated) 81 | PS_LOG(WARNING) << "Duplicated message: " << msg.DebugString(); 82 | return duplicated; 83 | } 84 | } 85 | 86 | private: 87 | using Time = std::chrono::milliseconds; 88 | // the buffer entry 89 | struct Entry { 90 | Message msg; 91 | Time send; 92 | int num_retry = 0; 93 | }; 94 | std::unordered_map<uint64_t, Entry> send_buff_; 95 | 96 | uint64_t GetKey(const Message& msg) { 97 | PS_CHECK_NE(msg.meta.timestamp, Meta::kEmpty) << msg.DebugString(); 98 | uint16_t id = msg.meta.app_id; 99 | uint8_t sender = 100 | 
msg.meta.sender == Node::kEmpty ? van_->my_node().id : msg.meta.sender; 101 | uint8_t recver = msg.meta.recver; 102 | return (static_cast(id) << 48) | 103 | (static_cast(sender) << 40) | 104 | (static_cast(recver) << 32) | (msg.meta.timestamp << 1) | 105 | msg.meta.request; 106 | } 107 | Time Now() { 108 | return std::chrono::duration_cast
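GetKey above packs five fields into one 64-bit key: the 16-bit app id in the top bits, sender and receiver ids truncated to 8 bits each, then the timestamp shifted left by one with the request flag in bit 0. A worked example of that layout, assuming node ids small enough that the uint8_t truncation is lossless; PackResendKey is an illustrative stand-in, not the library function:

#include <cstdint>

// [ id:16 ][ sender:8 ][ recver:8 ][ timestamp:31 ][ request:1 ]
uint64_t PackResendKey(uint16_t id, uint8_t sender, uint8_t recver,
                       uint64_t timestamp, bool request) {
  return (static_cast<uint64_t>(id) << 48) |
         (static_cast<uint64_t>(sender) << 40) |
         (static_cast<uint64_t>(recver) << 32) |
         (timestamp << 1) | (request ? 1 : 0);
}
// PackResendKey(1, 9, 8, 5, true) == 0x000109080000000B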