├── core ├── core ├── memory │ ├── stream_pool.cpp │ ├── stream_pool.h │ ├── weights_buffer.h │ ├── memory_pool.h │ ├── kv_cache_buffer.h │ ├── fixed_size_allocator.h │ ├── host_caching_allocator.cpp │ ├── device_caching_allocator.cpp │ └── host_caching_allocator.h ├── base │ ├── noncopyable.h │ ├── copyable.h │ ├── countdown_latch.h │ ├── countdown_latch.cc │ ├── exception.h │ ├── current_thread.h │ ├── exception.cc │ ├── log_file.h │ ├── thread.h │ ├── timezone.h │ ├── thread_pool.h │ ├── process_info.h │ ├── timestamp.cc │ ├── date.cc │ ├── file_util.h │ ├── date.h │ ├── log_file.cc │ ├── thread_pool.cc │ ├── timestamp.h │ ├── logging.h │ ├── log_stream.h │ └── file_util.cc ├── engine │ └── h2d_engine.h ├── aio │ ├── archer_aio_threadpool.h │ ├── archer_aio_thread.h │ ├── archer_aio_utils.h │ ├── archer_aio_threadpool.cpp │ ├── archer_tensor_index.h │ ├── archer_aio_thread.cpp │ ├── archer_tensor_handle.h │ ├── archer_prio_aio_handle.h │ └── archer_aio_utils.cpp ├── prefetch │ ├── task_thread.h │ ├── task_thread.cpp │ ├── archer_prefetch_handle.h │ └── task_scheduler.h ├── common │ ├── time.h │ ├── status.h │ ├── types.h │ └── pytorch.h ├── utils │ ├── archer_logger.cpp │ ├── archer_logger.h │ ├── cuda_utils.cpp │ ├── cuda_utils.h │ ├── threadsafe_queue.h │ ├── cache.h │ ├── logger.cpp │ ├── prefix_tree.h │ ├── lockfree_queue.h │ ├── simple_object_pool.h │ └── logger.h └── parallel │ └── expert_dispatcher.h ├── moe_infinity ├── ops │ ├── __init__.py │ ├── core │ ├── prefetch │ │ └── __init__.py │ └── op_builder ├── entrypoints │ ├── openai │ │ └── __init__.py │ └── __init__.py ├── common │ ├── __init__.py │ └── constants.py ├── runtime │ ├── __init__.py │ ├── state_dict.py │ └── hooks.py ├── __init__.py ├── models │ ├── modeling_grok │ │ ├── __init__.py │ │ └── configuration_grok1.py │ ├── modeling_deepseek_v3 │ │ └── __init__.py │ ├── modeling_arctic │ │ ├── __init__.py │ │ └── tokenization_arctic.py │ ├── modeling_deepseek │ │ ├── __init__.py │ │ └── 
tokenization_deepseek_fast.py │ ├── __init__.py │ ├── arctic.py │ ├── model_utils.py │ └── nllb_moe.py ├── utils │ ├── __init__.py │ ├── arguments.py │ ├── config.py │ └── checkpoints.py ├── distributed │ ├── __init__.py │ ├── devicemap_manager.py │ └── expert_prefetcher.py └── memory │ ├── __init__.py │ ├── expert_entry.py │ ├── expert_predictor.py │ └── expert_prefetcher.py ├── setup.cfg ├── MANIFEST.in ├── requirements-lint.txt ├── requirements.txt ├── tests ├── test_oai_completions.py ├── test_oai_chat_completions.py └── queues │ ├── CMakeLists.txt │ ├── test_lockfree_queue.cpp │ └── test_threadsafe_queue.cpp ├── CITATIONS.md ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── workflows │ ├── scripts │ │ ├── create-release.js │ │ ├── cuda-install.sh │ │ └── free-disk-space.sh │ ├── build-test.yml │ ├── pre-commit-format.yml │ ├── publish-test.yml │ └── publish.yml └── ISSUE_TEMPLATE │ ├── feature_request.yml │ └── bug_report.yml ├── pyproject.toml ├── .clang-format ├── examples └── readme_example.py ├── op_builder ├── all_ops.py ├── __init__.py └── prefetch.py ├── .pre-commit-config.yaml ├── RELEASE.md └── setup.py /core/core: -------------------------------------------------------------------------------- 1 | core/ -------------------------------------------------------------------------------- /moe_infinity/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /moe_infinity/ops/core: -------------------------------------------------------------------------------- 1 | ../../core -------------------------------------------------------------------------------- /moe_infinity/ops/prefetch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /moe_infinity/entrypoints/openai/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /moe_infinity/ops/op_builder: -------------------------------------------------------------------------------- 1 | ../../op_builder -------------------------------------------------------------------------------- /moe_infinity/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [options.data_files] 2 | . = requirements.txt 3 | -------------------------------------------------------------------------------- /moe_infinity/entrypoints/__init__.py: -------------------------------------------------------------------------------- 1 | from .big_modeling import MoE 2 | -------------------------------------------------------------------------------- /moe_infinity/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_offload import OffloadEngine 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include core *.cpp *.h *.cc 2 | recursive-include op_builder *.py 3 | -------------------------------------------------------------------------------- /moe_infinity/__init__.py: -------------------------------------------------------------------------------- 1 | from moe_infinity.entrypoints import MoE 2 | from moe_infinity.runtime import OffloadEngine 3 | 4 | __version__ = "0.0.1" 5 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_grok/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .configuration_grok1 import Grok1Config 2 | from .modeling_grok1 import Grok1ModelForCausalLM, MoeBlock, MoeMLP 3 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_deepseek_v3/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_deepseek import ( 2 | DeepseekV3ForCausalLM, 3 | DeepseekV3MLP, 4 | DeepseekV3MoE, 5 | MoEGate, 6 | ) 7 | -------------------------------------------------------------------------------- /moe_infinity/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoints import get_checkpoint_paths 2 | from .config import ArcherConfig 3 | from .hf_config import ( 4 | parse_expert_dtype, 5 | parse_expert_id, 6 | parse_moe_param, 7 | ) 8 | -------------------------------------------------------------------------------- /requirements-lint.txt: -------------------------------------------------------------------------------- 1 | clang-format==18.1.4 2 | isort==5.13.2 3 | 4 | # type checking 5 | mypy==1.8.0 6 | pre-commit 7 | ruff==0.6.9 8 | toml==0.10.2 9 | tomli==2.0.1 10 | types-PyYAML 11 | types-requests 12 | types-setuptools 13 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_arctic/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_arctic import ArcticConfig 2 | from .modeling_arctic import ( 3 | ArcticForCausalLM, 4 | ArcticMLP, 5 | ArcticMoE, 6 | apply_rotary_pos_emb, 7 | ) 8 | from .tokenization_arctic import ArcticTokenizer 9 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_deepseek/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.configuration_deepseek import DeepseekV2Config 2 | from .modeling_deepseek import ( 3 | DeepseekV2ForCausalLM, 4 | DeepseekV2MLP, 5 | DeepseekV2MoE, 6 | MoEGate, 7 | ) 8 | from .tokenization_deepseek_fast import DeepseekTokenizerFast 9 | -------------------------------------------------------------------------------- /moe_infinity/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | from .devicemap_manager import DeviceMapManager 7 | from .expert_executor import DistributedExpertExecutor 8 | from .expert_prefetcher import DistributedExpertPrefetcher 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | auto_gptq 3 | chardet 4 | datasets>=2.12.0 5 | fastapi 6 | hjson 7 | ninja 8 | numpy==1.22.4 9 | openai 10 | optimum>=1.17.1 11 | packaging>=20.0 12 | pre-commit 13 | py-cpuinfo 14 | pyarrow==12.0.0 15 | pydantic==1.10.12 16 | scipy 17 | sentencepiece 18 | sphinx 19 | torch>=2.1.1 20 | transformers>=4.37.1, <4.47 21 | uvicorn 22 | -------------------------------------------------------------------------------- /moe_infinity/memory/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | from .expert_cache import ExpertCache 7 | from .expert_predictor import ExpertPredictor 8 | from .expert_prefetcher import ExpertPrefetcher 9 | from .expert_priority_score import * 10 | from .expert_tracer import ExpertTracer 11 | -------------------------------------------------------------------------------- /tests/test_oai_completions.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | openai_api_key = "EMPTY" 4 | openai_api_base = "http://localhost:8000/v1" 5 | client = OpenAI( 6 | api_key=openai_api_key, 7 | base_url=openai_api_base, 8 | ) 9 | completion = client.completions.create( 10 | model="deepseek-ai/DeepSeek-V2-Lite-Chat", 11 | prompt="Write a story about a cat.", 12 | ) 13 | print("Completion result:", completion) 14 | -------------------------------------------------------------------------------- /core/memory/stream_pool.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #include "stream_pool.h" 7 | 8 | // Stream0 is used for H2D, Stream1 is used for Kernel, Stream2 is used for D2H 9 | // TorchStreamPool* kTorchStreamPool = TorchStreamPool::GetInstance(); 10 | std::unique_ptr kTorchStreamPool = 11 | std::make_unique(); 12 | -------------------------------------------------------------------------------- /core/base/noncopyable.h: -------------------------------------------------------------------------------- 1 | #ifndef MUDUO_BASE_NONCOPYABLE_H 2 | #define MUDUO_BASE_NONCOPYABLE_H 3 | 4 | namespace base { 5 | 6 | class noncopyable { 7 | protected: 8 | noncopyable() = default; 9 | ~noncopyable() = default; 10 | 11 | private: 12 | noncopyable(const noncopyable&) = delete; 13 | void operator=(const noncopyable&) = delete; 14 | }; 15 | 16 | } // namespace base 17 | 18 | #endif // MUDUO_BASE_NONCOPYABLE_H 19 | -------------------------------------------------------------------------------- /CITATIONS.md: -------------------------------------------------------------------------------- 1 | ```bibtex 2 | @misc{moe-infinity, 3 | author = {Leyang Xue and 4 | Yao Fu and 5 | Zhan Lu and 6 | Luo Mai and 7 | Mahesh Marina}, 8 | title = {MoE-Infinity: Efficient MoE Inference on Personal Machines with Sparsity-Aware Expert Cache}, 9 | archivePrefix= {arXiv}, 10 | eprint = {2401.14361}, 11 | year = {2024} 12 | } 13 | ``` 14 | -------------------------------------------------------------------------------- /core/base/copyable.h: -------------------------------------------------------------------------------- 1 | #ifndef MUDUO_BASE_COPYABLE_H 2 | #define MUDUO_BASE_COPYABLE_H 3 | 4 | namespace base { 5 | 6 | /// A tag class emphasises the objects are copyable. 7 | /// The empty base class optimization applies. 8 | /// Any derived class of copyable should be a value type. 
9 | class copyable { 10 | protected: 11 | copyable() = default; 12 | ~copyable() = default; 13 | }; 14 | 15 | }; // namespace base 16 | 17 | #endif // MUDUO_BASE_COPYABLE_H 18 | -------------------------------------------------------------------------------- /core/engine/h2d_engine.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "base/noncopyable.h" 4 | #include "base/thread_pool.h" 5 | 6 | class H2DEngine : public base::noncopyable { 7 | public: 8 | H2DEngine(int id, int num_threads = 1) { 9 | std::string thread_name = std::string("H2DEngine-") + std::to_string(id); 10 | // use move constructor to avoid copy 11 | thread_pool_ = std::make_unique(thread_name); 12 | thread_pool_->start(num_threads); 13 | }; 14 | 15 | private: 16 | std::unique_ptr thread_pool_; 17 | }; 18 | -------------------------------------------------------------------------------- /tests/test_oai_chat_completions.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | openai_api_key = "EMPTY" 4 | openai_api_base = "http://localhost:8000/v1" 5 | 6 | client = OpenAI( 7 | api_key=openai_api_key, 8 | base_url=openai_api_base, 9 | ) 10 | 11 | chat_response = client.chat.completions.create( 12 | model="deepseek-ai/DeepSeek-V2-Lite-Chat", 13 | messages=[ 14 | {"role": "system", "content": "You are a helpful assistant."}, 15 | {"role": "user", "content": "Tell me a joke"}, 16 | ], 17 | ) 18 | print("Chat response:", chat_response) 19 | -------------------------------------------------------------------------------- /moe_infinity/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | from .arctic import ArcticConfig, SyncArcticMoeBlock 7 | from .deepseek import DeepseekMoEBlock 8 | from .grok import SyncGrokMoeBlock 9 | from .mixtral import SyncMixtralSparseMoeBlock 10 | from .model_utils import ( 11 | apply_rotary_pos_emb, 12 | apply_rotary_pos_emb_deepseek, 13 | rotate_half, 14 | ) 15 | from .nllb_moe import SyncNllbMoeSparseMLP 16 | from .switch_transformers import SyncSwitchTransformersSparseMLP 17 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Briefly describe your changes. 3 | 4 | ## Motivation 5 | Explain why this change is needed and what problem it solves. 6 | If it fixes an issue, link it (e.g., `close #123`). 7 | 8 | ## Type of Change 9 | - [ ] Bug fix 10 | - [ ] New feature 11 | - [ ] Breaking change 12 | - [ ] Documentation update 13 | 14 | ## Checklist 15 | - [ ] I have read the [CONTRIBUTION](https://github.com/EfficientMoE/MoE-Infinity/blob/main/CONTRIBUTING.md) guide. 16 | - [ ] I have updated the tests (if applicable). 17 | - [ ] I have updated the documentation (if applicable). 18 | -------------------------------------------------------------------------------- /core/aio/archer_aio_threadpool.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | #include "archer_aio_thread.h" 11 | 12 | class ArcherAioThreadPool { 13 | public: 14 | explicit ArcherAioThreadPool(int num_threads); 15 | ~ArcherAioThreadPool(); 16 | 17 | void Start(); 18 | void Stop(); 19 | 20 | void Enqueue(AioCallback& callback, int thread_id = -1); 21 | void Wait(); 22 | 23 | private: 24 | int num_threads_; 25 | std::vector> threads_; 26 | }; 27 | -------------------------------------------------------------------------------- /core/prefetch/task_thread.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | #include "model/model_topology.h" 11 | 12 | typedef std::vector> NodeMoveVec; 13 | 14 | void SetThreadScheduling(std::thread& th, int policy, int priority); 15 | void SetThreadAffinity(std::thread& th, int cpu_id); 16 | void SetThreadAffinity(std::thread& th); 17 | void SetThreadAffinity(pid_t tid); 18 | 19 | static std::atomic_uint64_t kCPUCounter{0}; 20 | extern std::atomic_uint32_t kGPUCounter; 21 | -------------------------------------------------------------------------------- /moe_infinity/memory/expert_entry.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | 5 | 6 | @dataclass 7 | class ExpertTraceEntry: 8 | seq_id: str = None 9 | matrix: np.ndarray = None 10 | access: int = 0 11 | num_new_tokens: int = 0 12 | 13 | def __hash__(self): 14 | return hash(self.seq_id) 15 | 16 | 17 | @dataclass 18 | class ExpertCacheEntry: 19 | expert_idx: int = None 20 | layer_idx: int = None 21 | r: float = 0.0 22 | visit: int = 0 23 | timestamp: int = 0 24 | 25 | def __hash__(self): 26 | return hash((self.layer_idx, self.expert_idx)) 27 | 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools==75.3.2", "wheel", "torch"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [tool.ruff] 7 | line-length = 80 8 | exclude = [] 9 | 10 | [tool.ruff.lint] 11 | fixable = ["ALL"] 12 | unfixable = [ 13 | # star imports 14 | "F405", 15 | "F403", 16 | # lambda expression assignment 17 | "E731", 18 | # Loop control variable not used within loop body 19 | "B007", 20 | # raise distinguish errors 21 | "B904", 22 | # f-string format 23 | "UP032", 24 | ] 25 | select = [ 26 | # isort 27 | "I", 28 | ] 29 | ignore = [ 30 | # Loop control variable not used within loop body 31 | "B007" 32 | ] 33 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | 2 | BasedOnStyle: Google 3 | UseTab: Never 4 | IndentWidth: 2 5 | ColumnLimit: 80 6 | 7 | # Force pointers to the type for C++. 8 | DerivePointerAlignment: false 9 | PointerAlignment: Left 10 | 11 | # Reordering #include statements can (and currently will) introduce errors 12 | SortIncludes: false 13 | 14 | # Style choices 15 | AlignConsecutiveAssignments: false 16 | AlignConsecutiveDeclarations: false 17 | IndentPPDirectives: BeforeHash 18 | 19 | IncludeCategories: 20 | - Regex: '^' 21 | Priority: 2 22 | - Regex: '^<.*\.h>' 23 | Priority: 1 24 | - Regex: '^<.*' 25 | Priority: 2 26 | - Regex: '.*' 27 | Priority: 3 28 | -------------------------------------------------------------------------------- /.github/workflows/scripts/create-release.js: -------------------------------------------------------------------------------- 1 | // Uses Github's API to create the release and wait for result. 
2 | // We use a JS script since github CLI doesn't provide a way to wait for the 3 | // release's creation and returns immediately. 4 | 5 | module.exports = async (github, context, core) => { 6 | try { 7 | const response = await github.rest.repos.createRelease({ 8 | draft: false, 9 | generate_release_notes: true, 10 | name: process.env.RELEASE_TAG, 11 | owner: context.repo.owner, 12 | prerelease: false, 13 | repo: context.repo.repo, 14 | tag_name: process.env.RELEASE_TAG, 15 | }); 16 | 17 | core.setOutput('upload_url', response.data.upload_url); 18 | } catch (error) { 19 | core.setFailed(error.message); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /core/base/countdown_latch.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_COUNTDOWNLATCH_H 7 | #define MUDUO_BASE_COUNTDOWNLATCH_H 8 | 9 | #include "noncopyable.h" 10 | 11 | #include 12 | #include 13 | 14 | namespace base { 15 | 16 | class CountDownLatch : noncopyable { 17 | public: 18 | explicit CountDownLatch(int count); 19 | 20 | void wait(); 21 | 22 | void countDown(); 23 | 24 | int getCount() const; 25 | 26 | private: 27 | mutable std::mutex mutex_; 28 | std::condition_variable condition_; 29 | int count_; 30 | }; 31 | 32 | } // namespace base 33 | #endif // MUDUO_BASE_COUNTDOWNLATCH_H 34 | -------------------------------------------------------------------------------- /tests/queues/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(LockFreeQueueTests) 3 | 4 | # Set C++ standard 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED True) 7 | 8 | # Add GoogleTest 9 | find_package(GTest REQUIRED) 10 | 
include_directories(${GTEST_INCLUDE_DIRS}) 11 | 12 | # Include directories 13 | message("CMAKE_SOURCE_DIR: ${CMAKE_SOURCE_DIR}") 14 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../core) 15 | 16 | # Add test executables. The test sources live next to this file 17 | # (tests/queues/), so anchor them to CMAKE_CURRENT_SOURCE_DIR instead of 18 | # a "../" path that points into tests/, where these files do not exist. 19 | add_executable(test_lockfree_queue ${CMAKE_CURRENT_SOURCE_DIR}/test_lockfree_queue.cpp) 20 | add_executable(test_threadsafe_queue ${CMAKE_CURRENT_SOURCE_DIR}/test_threadsafe_queue.cpp) 21 | 22 | # Link GoogleTest and pthread 23 | target_link_libraries(test_lockfree_queue ${GTEST_LIBRARIES} pthread) 24 | target_link_libraries(test_threadsafe_queue ${GTEST_LIBRARIES} pthread) 25 | -------------------------------------------------------------------------------- /core/aio/archer_aio_thread.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "archer_aio_utils.h" 15 | 16 | class ArcherAioThread { 17 | public: 18 | explicit ArcherAioThread(int thread_id); 19 | ~ArcherAioThread(); 20 | 21 | void Start(); 22 | void Stop(); 23 | 24 | void Enqueue(AioCallback& callback); 25 | void Wait(); 26 | 27 | private: 28 | void Run(); 29 | 30 | private: 31 | int thread_id_; 32 | std::thread thread_; 33 | bool is_running_; 34 | 35 | std::list callbacks_; 36 | 37 | std::mutex mutex_; 38 | std::atomic pending_callbacks_; 39 | }; 40 | -------------------------------------------------------------------------------- /core/common/time.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | typedef std::chrono::high_resolution_clock::time_point TimePoint; 11 | 12 | #define TIME_NOW std::chrono::high_resolution_clock::now() 13 | #define MCIROSECONDS std::chrono::microseconds 14 | #define MILLISECONDS std::chrono::milliseconds 15 | #define SECONDS std::chrono::seconds 16 | 17 | #define MCIROSECONDS_SINCE_EPOCH \ 18 | std::chrono::duration_cast(TIME_NOW.time_since_epoch()).count() 19 | #define MILLISECONDS_SINCE_EPOCH \ 20 | std::chrono::duration_cast(TIME_NOW.time_since_epoch()).count() 21 | #define SECONDS_SINCE_EPOCH \ 22 | std::chrono::duration_cast(TIME_NOW.time_since_epoch()).count() 23 | -------------------------------------------------------------------------------- /.github/workflows/scripts/cuda-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Replace '.' with '-' ex: 11.8 -> 11-8 4 | cuda_version=$(echo $1 | tr "." "-") 5 | # Removes '-' and '.' 
ex: ubuntu-20.04 -> ubuntu2004 6 | OS=$(echo $2 | tr -d ".\-") 7 | 8 | sudo apt -qq update 9 | sudo apt install -y wget gnupg2 10 | 11 | # Installs CUDA 12 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb 13 | sudo dpkg -i cuda-keyring_1.1-1_all.deb 14 | rm cuda-keyring_1.1-1_all.deb 15 | sudo apt -qq update 16 | sudo apt --no-install-recommends -y install cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version} 17 | sudo apt clean 18 | 19 | # Test nvcc 20 | PATH=/usr/local/cuda-$1/bin:${PATH} 21 | nvcc --version 22 | 23 | # Log gcc, g++, c++ versions 24 | gcc --version 25 | g++ --version 26 | c++ --version 27 | -------------------------------------------------------------------------------- /core/base/countdown_latch.cc: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #include "countdown_latch.h" 7 | 8 | using namespace base; 9 | 10 | CountDownLatch::CountDownLatch(int count) 11 | : mutex_(), condition_(), count_(count) {} 12 | 13 | void CountDownLatch::wait() { 14 | std::unique_lock lock(mutex_); 15 | while (count_ > 0) { 16 | condition_.wait(lock); 17 | } 18 | } 19 | 20 | void CountDownLatch::countDown() { 21 | std::lock_guard lock(mutex_); 22 | --count_; 23 | if (count_ == 0) { 24 | condition_.notify_all(); 25 | } 26 | } 27 | 28 | int CountDownLatch::getCount() const { 29 | std::lock_guard lock(mutex_); 30 | return count_; 31 | } 32 | -------------------------------------------------------------------------------- /core/base/exception.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 
3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_EXCEPTION_H 7 | #define MUDUO_BASE_EXCEPTION_H 8 | 9 | #include 10 | #include 11 | #include "types.h" 12 | 13 | namespace base { 14 | 15 | class Exception : public std::exception { 16 | public: 17 | explicit Exception(const char* what); 18 | explicit Exception(const std::string& what); 19 | virtual ~Exception() throw(); 20 | virtual const char* what() const throw(); 21 | const char* stackTrace() const throw(); 22 | 23 | private: 24 | void fillStackTrace(); 25 | 26 | std::string message_; 27 | std::string stack_; 28 | }; 29 | 30 | } // namespace base 31 | 32 | #endif // MUDUO_BASE_EXCEPTION_H 33 | -------------------------------------------------------------------------------- /core/aio/archer_aio_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | typedef std::function AioCallback; 13 | 14 | int ArcherOpenFile(const char* filename); 15 | int ArcherCloseFile(const int fd); 16 | int ArcherReadFileBatch(const int fd, void* buffer, const size_t num_bytes, 17 | const size_t offset); 18 | int ArcherWriteFileBatch(const int fd, const void* buffer, 19 | const size_t num_bytes, const size_t offset); 20 | int ArcherReadFile(int fd, void* buffer, const size_t num_bytes, 21 | const size_t offset); 22 | int ArcherWriteFile(int fd, const void* buffer, size_t num_bytes, 23 | size_t offset); 24 | -------------------------------------------------------------------------------- /examples/readme_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import AutoTokenizer 4 | 5 | from moe_infinity import MoE 6 | 7 | user_home = os.path.expanduser("~") 8 | 9 | checkpoint = 
"deepseek-ai/DeepSeek-V2-Lite-Chat" 10 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote=True) 11 | 12 | config = { 13 | "offload_path": os.path.join(user_home, "moe-infinity"), 14 | "device_memory_ratio": 0.75, # 75% of the device memory is used for caching, change the value according to your device memory size on OOM 15 | } 16 | 17 | model = MoE(checkpoint, config) 18 | 19 | input_text = "translate English to German: How old are you?" 20 | input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda:0") 21 | 22 | output_ids = model.generate(input_ids) 23 | output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) 24 | 25 | print(output_text) 26 | -------------------------------------------------------------------------------- /.github/workflows/build-test.yml: -------------------------------------------------------------------------------- 1 | name: Build Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**.md' 10 | - 'examples/**' 11 | - 'tests/**' 12 | - 'docs/**' 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.ref }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | build: 20 | runs-on: ubuntu-20.04 21 | container: 22 | image: nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 23 | 24 | steps: 25 | - name: Checkout Source Code 26 | uses: actions/checkout@v3 27 | 28 | - name: Install Dependencies 29 | run: | 30 | apt update && apt install --no-install-recommends -y python3-pip python3-dev 31 | python3 -m pip install --upgrade pip 32 | python3 -m pip install build 33 | 34 | - name: Build Wheel 35 | run: | 36 | BUILD_OPS=1 python3 -m build 37 | -------------------------------------------------------------------------------- /core/common/status.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | enum StatusType : std::uint32_t { 12 | kOK = 0, 13 | kUnknown, 14 | kErrCuda, 15 | }; 16 | 17 | static const std::unordered_map kStatusStr = { 18 | {kOK, ""}, {kUnknown, "unknown: "}, {kErrCuda, "cuda error: "}}; 19 | 20 | class Status { 21 | public: 22 | Status() : status_(kOK), err_() {} 23 | bool OK() const { return status_ == kOK; } 24 | const uint32_t status() const { return status_; } 25 | const std::string& err() const { return err_; } 26 | void SetError(StatusType status, const std::string& msg) { 27 | status_ = status; 28 | if (kStatusStr.find(status) == kStatusStr.end()) status_ = kUnknown; 29 | err_ = kStatusStr.at(status_) + msg; 30 | } 31 | 32 | private: 33 | StatusType status_; 34 | std::string err_; 35 | }; 36 | -------------------------------------------------------------------------------- /core/utils/archer_logger.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) TorchMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // TorchMoE Team 5 | 6 | #include "archer_logger.h" 7 | 8 | #include 9 | // #include 10 | #include 11 | 12 | std::shared_ptr kLogger = nullptr; 13 | std::once_flag kLoggerFlag; 14 | 15 | static const char* kArcherLoggerName = "archer_logger"; 16 | 17 | void InitLogger() { 18 | std::call_once(kLoggerFlag, []() { 19 | kLogger = spdlog::get(kArcherLoggerName); 20 | kLogger = spdlog::stdout_color_mt(kArcherLoggerName); 21 | printf("SPDLOG_LEVEL : %s\n", getenv("SPDLOG_LEVEL")); 22 | if (getenv("SPDLOG_LEVEL")) { 23 | kLogger->set_level(spdlog::level::from_str(getenv("SPDLOG_LEVEL"))); 24 | } else { 25 | kLogger->set_level(spdlog::level::info); 26 | } 27 | kLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [%s:%#] %v"); 28 | kLogger->debug("create logger for MoE-Infinity"); 29 | }); 30 | } 31 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-format.yml: -------------------------------------------------------------------------------- 1 | name: Formatting 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | '**' 8 | merge_group: 9 | branches: [ main ] 10 | schedule: 11 | - cron: "0 0 * * *" 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | 19 | # formatting and basic install on cpu-only machine 20 | unit-tests: 21 | runs-on: ubuntu-20.04 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: environment 27 | run: | 28 | which python 29 | python --version 30 | 31 | - name: Install dependencies 32 | run: | 33 | # Previously we would do pip install .[dev] but this is causing out of 34 | # space errors start with torch 2.1.0 release 35 | pip install -r requirements-lint.txt 36 | 37 | - name: Formatting checks 38 | run: | 39 | pip show pre-commit clang-format 40 | pre-commit run --all-files 41 | 
-------------------------------------------------------------------------------- /core/utils/archer_logger.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) TorchMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // TorchMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | #include "noncopyable.h" 12 | 13 | extern std::shared_ptr kLogger; 14 | extern std::once_flag kLoggerFlag; 15 | 16 | extern void InitLogger(); 17 | 18 | #define ARCHER_LOG_INFO(...) SPDLOG_LOGGER_INFO(kLogger, __VA_ARGS__) 19 | #define ARCHER_LOG_ERROR(...) SPDLOG_LOGGER_ERROR(kLogger, __VA_ARGS__) 20 | #define ARCHER_LOG_WARN(...) SPDLOG_LOGGER_WARN(kLogger, __VA_ARGS__) 21 | #define ARCHER_LOG_DEBUG(...) kLogger->debug(__VA_ARGS__) 22 | #define ARCHER_LOG_TRACE(...) SPDLOG_LOGGER_TRACE(kLogger, __VA_ARGS__) 23 | #define ARCHER_LOG_CRITICAL(...) SPDLOG_LOGGER_CRITICAL(kLogger, __VA_ARGS__) 24 | #define ARCHER_LOG_FATAL(...) \ 25 | do { \ 26 | SPDLOG_LOGGER_CRITICAL(kLogger, __VA_ARGS__); \ 27 | throw std::runtime_error("Logged a FATAL error"); \ 28 | } while (0) 29 | -------------------------------------------------------------------------------- /core/base/current_thread.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 
3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_CURRENTTHREAD_H 7 | #define MUDUO_BASE_CURRENTTHREAD_H 8 | 9 | #include 10 | 11 | namespace base { 12 | namespace CurrentThread { 13 | // internal 14 | extern __thread int t_cachedTid; 15 | extern __thread char t_tidString[32]; 16 | extern __thread int t_tidStringLength; 17 | extern __thread const char* t_threadName; 18 | void cacheTid(); 19 | 20 | inline int tid() { 21 | if (__builtin_expect(t_cachedTid == 0, 0)) { 22 | cacheTid(); 23 | } 24 | return t_cachedTid; 25 | } 26 | 27 | inline const char* tidString() // for logging 28 | { 29 | return t_tidString; 30 | } 31 | 32 | inline int tidStringLength() // for logging 33 | { 34 | return t_tidStringLength; 35 | } 36 | 37 | inline const char* name() { return t_threadName; } 38 | 39 | bool isMainThread(); 40 | 41 | void sleepUsec(int64_t usec); 42 | } // namespace CurrentThread 43 | } // namespace base 44 | 45 | #endif 46 | -------------------------------------------------------------------------------- /core/aio/archer_aio_threadpool.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #include "archer_aio_threadpool.h" 7 | 8 | #include "utils/logger.h" 9 | 10 | ArcherAioThreadPool::ArcherAioThreadPool(int num_threads) 11 | : num_threads_(num_threads) { 12 | for (auto i = 0; i < num_threads_; ++i) { 13 | threads_.emplace_back(std::make_unique(i)); 14 | } 15 | } 16 | 17 | ArcherAioThreadPool::~ArcherAioThreadPool() { Stop(); } 18 | 19 | void ArcherAioThreadPool::Start() { 20 | for (auto& thread : threads_) { 21 | thread->Start(); 22 | } 23 | } 24 | 25 | void ArcherAioThreadPool::Stop() { 26 | for (auto& thread : threads_) { 27 | thread->Stop(); 28 | } 29 | } 30 | 31 | void ArcherAioThreadPool::Enqueue(AioCallback& callback, int thread_id) { 32 | if (thread_id < 0) { 33 | const auto thread_id = rand() % num_threads_; 34 | threads_[thread_id]->Enqueue(callback); 35 | } else { 36 | threads_[thread_id]->Enqueue(callback); 37 | } 38 | } 39 | 40 | void ArcherAioThreadPool::Wait() { 41 | for (auto& thread : threads_) { 42 | thread->Wait(); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /core/base/exception.cc: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 
3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #include "exception.h" 7 | 8 | // #include 9 | #include 10 | #include 11 | 12 | using namespace base; 13 | 14 | Exception::Exception(const char* msg) : message_(msg) { fillStackTrace(); } 15 | 16 | Exception::Exception(const std::string& msg) : message_(msg) { 17 | fillStackTrace(); 18 | } 19 | 20 | Exception::~Exception() throw() {} 21 | 22 | const char* Exception::what() const throw() { return message_.c_str(); } 23 | 24 | const char* Exception::stackTrace() const throw() { return stack_.c_str(); } 25 | 26 | void Exception::fillStackTrace() { 27 | const int len = 200; 28 | void* buffer[len]; 29 | int nptrs = ::backtrace(buffer, len); 30 | char** strings = ::backtrace_symbols(buffer, nptrs); 31 | if (strings) { 32 | for (int i = 0; i < nptrs; ++i) { 33 | // TODO demangle function name with abi::__cxa_demangle 34 | stack_.append(strings[i]); 35 | stack_.push_back('\n'); 36 | } 37 | free(strings); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /core/base/log_file.h: -------------------------------------------------------------------------------- 1 | #ifndef MUDUO_BASE_LOGFILE_H 2 | #define MUDUO_BASE_LOGFILE_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "types.h" 9 | 10 | namespace base { 11 | 12 | namespace FileUtil { 13 | class AppendFile; 14 | } 15 | 16 | class LogFile : noncopyable { 17 | public: 18 | LogFile(const std::string& basename, size_t rollSize, bool threadSafe = true, 19 | int flushInterval = 3, int checkEveryN = 1024); 20 | ~LogFile(); 21 | 22 | void append(const char* logline, int len); 23 | void flush(); 24 | bool rollFile(); 25 | 26 | private: 27 | void append_unlocked(const char* logline, int len); 28 | 29 | static std::string getLogFileName(const std::string& basename, time_t* now); 30 | 31 | const std::string basename_; 32 | const size_t rollSize_; 33 | const int flushInterval_; 34 | const int checkEveryN_; 
35 | 36 | int count_; 37 | 38 | std::unique_ptr mutex_; 39 | time_t startOfPeriod_; 40 | time_t lastRoll_; 41 | time_t lastFlush_; 42 | std::unique_ptr file_; 43 | 44 | const static int kRollPerSeconds_ = 60 * 60 * 24; 45 | }; 46 | 47 | } // namespace base 48 | #endif // MUDUO_BASE_LOGFILE_H 49 | -------------------------------------------------------------------------------- /core/common/types.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | typedef std::uint32_t TensorID; 11 | typedef std::size_t HashID; 12 | typedef std::size_t NodeID; 13 | typedef std::uint64_t GraphID; 14 | typedef std::uint64_t RequestID; 15 | 16 | #define KB 1024 17 | #define MB (KB * KB) 18 | #define GB (KB * KB * KB) 19 | 20 | #define DELETE_COPY_AND_ASSIGN(classname) \ 21 | classname(const classname&) = delete; \ 22 | classname& operator=(const classname&) = delete; \ 23 | classname(classname&&) = delete; \ 24 | classname& operator=(classname&&) = delete; 25 | 26 | #define STATIC_GET_INSTANCE(classname) \ 27 | static classname* GetInstance() { \ 28 | static std::once_flag flag; \ 29 | static classname* instance = nullptr; \ 30 | std::call_once(flag, []() { instance = new classname(); }); \ 31 | return instance; \ 32 | } 33 | 34 | template 35 | struct DoNothingDeleter { 36 | void operator()(T* ptr) const {} 37 | }; 38 | -------------------------------------------------------------------------------- /moe_infinity/utils/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | import torch 7 | 8 | 9 | def copy_args_to_device(device, args): 10 | new_args = () 11 | if isinstance(args, torch.Tensor): 12 | return args.to(device) 13 | for i in range(len(args)): 14 | if isinstance(args[i], torch.Tensor): 15 | new_args += (args[i].to(device, non_blocking=True),) 16 | elif isinstance(args[i], list) or isinstance(args[i], tuple): 17 | # move_args_to_device(device, *args[i]) 18 | new_args += (copy_args_to_device(device, args[i]),) 19 | elif isinstance(args[i], dict): 20 | new_args += (copy_kwargs_to_device(device, args[i]),) 21 | else: 22 | new_args += (args[i],) 23 | # print("new_args", device, new_args) 24 | return new_args 25 | 26 | 27 | def copy_kwargs_to_device(device, kwargs): 28 | new_kwargs = kwargs 29 | for key, value in kwargs.items(): 30 | if isinstance(value, torch.Tensor): 31 | new_kwargs[key] = value.to(device, non_blocking=True) 32 | elif isinstance(value, list) or isinstance(value, tuple): 33 | new_kwargs[key] = copy_args_to_device(device, value) 34 | else: 35 | new_kwargs[key] = value 36 | return new_kwargs 37 | -------------------------------------------------------------------------------- /moe_infinity/memory/expert_predictor.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | from moe_infinity.memory.expert_tracer import ExpertTracer 4 | from moe_infinity.utils import parse_moe_param 5 | 6 | 7 | class ExpertPredictor: 8 | def __init__(self, config: PretrainedConfig) -> None: 9 | self.num_layers, self.num_experts, self.num_encoder_layers = ( 10 | parse_moe_param(config) 11 | ) 12 | self.layer_decay_func = lambda x, l, L: -1 / (L + 1) * (x - l) + 1 13 | 14 | def add_tracer(self, tracer: ExpertTracer): 15 | self.tracer = tracer 16 | 17 | def predict(self, seq_id, expert_list, layer_idx): 18 | self.tracer.update_entry(seq_id, expert_list, layer_idx) 19 | current_entry 
= self.tracer.get_entry(seq_id) 20 | 21 | # start_time = time.time() 22 | expert_matrix = self.tracer.find_most_similar( 23 | current_entry.matrix, layer_idx 24 | ) 25 | # print("find_most_similar", time.time() - start_time) 26 | 27 | # expert_matrix = copy.deepcopy(entry) 28 | expert_matrix[:layer_idx, :] = 0 29 | 30 | for l in range(layer_idx, self.num_layers): 31 | expert_matrix[l] = ( 32 | expert_matrix[l] + 1e-8 33 | ) * self.layer_decay_func(l, layer_idx, self.num_layers) 34 | 35 | return expert_matrix 36 | -------------------------------------------------------------------------------- /core/base/thread.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_THREAD_H 7 | #define MUDUO_BASE_THREAD_H 8 | 9 | #include "countdown_latch.h" 10 | #include "types.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace base { 19 | 20 | class Thread : noncopyable { 21 | public: 22 | typedef std::function ThreadFunc; 23 | 24 | explicit Thread(ThreadFunc, const std::string& name = std::string()); 25 | // FIXME: make it movable in C++11 26 | ~Thread(); 27 | 28 | void start(); 29 | int join(); // return pthread_join() 30 | 31 | bool started() const { return started_; } 32 | // pthread_t pthreadId() const { return pthreadId_; } 33 | pid_t tid() const { return tid_; } 34 | const std::string& name() const { return name_; } 35 | 36 | static int numCreated() { return numCreated_.load(); } 37 | 38 | private: 39 | void setDefaultName(); 40 | 41 | bool started_; 42 | bool joined_; 43 | pthread_t pthreadId_; 44 | pid_t tid_; 45 | ThreadFunc func_; 46 | std::string name_; 47 | CountDownLatch latch_; 48 | 49 | static std::atomic_int32_t numCreated_; 50 | }; 51 | 52 | } // namespace base 53 | #endif 54 | 
-------------------------------------------------------------------------------- /core/memory/stream_pool.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | #include "base/noncopyable.h" 11 | #include "utils/cuda_utils.h" 12 | 13 | class TorchStreamPool : public base::noncopyable { 14 | public: 15 | std::vector& operator()(const int device_id) { 16 | return cuda_streams_[device_id]; 17 | } 18 | 19 | TorchStreamPool() { 20 | int num_devices = GetDeviceCount(); 21 | for (int i = 0; i < num_devices; ++i) { 22 | std::vector streams; 23 | for (int j = 0; j < 3; ++j) { 24 | streams.push_back(c10::cuda::getStreamFromPool(false, i)); 25 | } 26 | cuda_streams_.push_back(std::move(streams)); 27 | } 28 | } 29 | virtual ~TorchStreamPool() = default; 30 | 31 | private: 32 | std::vector> cuda_streams_; 33 | }; 34 | 35 | extern std::unique_ptr kTorchStreamPool; 36 | #define TORCH_STREAM_VIEW(device_id, stream_id) \ 37 | (*kTorchStreamPool)(device_id)[stream_id] 38 | #define TORCH_STREAM_H2D_VIEW(device_id) TORCH_STREAM_VIEW(device_id, 0) 39 | #define TORCH_STREAM_D2H_VIEW(device_id) TORCH_STREAM_VIEW(device_id, 1) 40 | #define TORCH_STREAM_COMPUTE_VIEW(device_id) TORCH_STREAM_VIEW(device_id, 2) 41 | -------------------------------------------------------------------------------- /op_builder/all_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Microsoft Corporation. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | 7 | # op_builder/all_ops.py 8 | # 9 | # Part of the DeepSpeed Project, under the Apache-2.0 License. 10 | # See https://github.com/microsoft/DeepSpeed/blob/master/LICENSE for license information. 
11 | # SPDX-License-Identifier: Apache-2.0 12 | 13 | # MoE-Infinity: deleted accelerator check. 14 | 15 | import importlib 16 | import os 17 | import pkgutil 18 | 19 | __op_builders__ = [] 20 | 21 | op_builder_dir = "op_builder" 22 | op_builder_module = importlib.import_module(op_builder_dir) 23 | 24 | for _, module_name, _ in pkgutil.iter_modules( 25 | [os.path.dirname(op_builder_module.__file__)] 26 | ): 27 | # avoid self references 28 | if module_name != "all_ops" and module_name != "builder": 29 | module = importlib.import_module( 30 | "{}.{}".format(op_builder_dir, module_name) 31 | ) 32 | for member_name in module.__dir__(): 33 | if ( 34 | member_name.endswith("Builder") 35 | and member_name != "OpBuilder" 36 | and member_name != "CUDAOpBuilder" 37 | ): 38 | # append builder to __op_builders__ list 39 | builder = getattr(module, member_name)() 40 | __op_builders__.append(builder) 41 | 42 | ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} 43 | -------------------------------------------------------------------------------- /core/base/timezone.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_TIMEZONE_H 7 | #define MUDUO_BASE_TIMEZONE_H 8 | 9 | #include 10 | #include "copyable.h" 11 | #include "time.h" 12 | 13 | namespace base { 14 | 15 | // TimeZone for 1970~2030 16 | class TimeZone : public copyable { 17 | public: 18 | explicit TimeZone(const char* zonefile); 19 | TimeZone(int eastOfUtc, const char* tzname); // a fixed timezone 20 | TimeZone() {} // an invalid timezone 21 | 22 | // default copy ctor/assignment/dtor are Okay. 
23 | 24 | bool valid() const { 25 | // 'explicit operator bool() const' in C++11 26 | return static_cast(data_); 27 | } 28 | 29 | struct tm toLocalTime(time_t secondsSinceEpoch) const; 30 | time_t fromLocalTime(const struct tm&) const; 31 | 32 | // gmtime(3) 33 | static struct tm toUtcTime(time_t secondsSinceEpoch, bool yday = false); 34 | // timegm(3) 35 | static time_t fromUtcTime(const struct tm&); 36 | // year in [1900..2500], month in [1..12], day in [1..31] 37 | static time_t fromUtcTime(int year, int month, int day, int hour, int minute, 38 | int seconds); 39 | 40 | struct Data; 41 | 42 | private: 43 | std::shared_ptr data_; 44 | }; 45 | 46 | } // namespace base 47 | #endif // MUDUO_BASE_TIMEZONE_H 48 | -------------------------------------------------------------------------------- /core/base/thread_pool.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_THREADPOOL_H 7 | #define MUDUO_BASE_THREADPOOL_H 8 | 9 | #include 10 | #include 11 | #include "thread.h" 12 | #include "types.h" 13 | 14 | #include 15 | #include 16 | 17 | namespace base { 18 | 19 | class ThreadPool : noncopyable { 20 | public: 21 | typedef std::function Task; 22 | 23 | explicit ThreadPool(const std::string& nameArg = std::string("ThreadPool")); 24 | ~ThreadPool(); 25 | 26 | // Must be called before start(). 
27 | void setMaxQueueSize(int maxSize) { maxQueueSize_ = maxSize; } 28 | void setThreadInitCallback(const Task& cb) { threadInitCallback_ = cb; } 29 | 30 | void start(int numThreads); 31 | void stop(); 32 | 33 | const std::string& name() const { return name_; } 34 | 35 | size_t queueSize() const; 36 | 37 | // Could block if maxQueueSize > 0 38 | void run(Task f); 39 | 40 | private: 41 | bool isFull() const; 42 | void runInThread(); 43 | Task take(); 44 | 45 | mutable std::mutex mutex_; 46 | std::condition_variable notEmpty_; 47 | std::condition_variable notFull_; 48 | std::string name_; 49 | Task threadInitCallback_; 50 | std::vector> threads_; 51 | std::deque queue_; 52 | size_t maxQueueSize_; 53 | bool running_; 54 | }; 55 | 56 | } // namespace base 57 | 58 | #endif 59 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from transformers.models.llama import LlamaTokenizerFast 4 | 5 | 6 | class DeepseekTokenizerFast(LlamaTokenizerFast): 7 | def convert_ids_to_tokens( 8 | self, ids: Union[int, List[int]], skip_special_tokens: bool = False 9 | ) -> Union[str, List[str]]: 10 | """ 11 | Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and 12 | added tokens. 13 | 14 | Args: 15 | ids (`int` or `List[int]`): 16 | The token id (or token ids) to convert to tokens. 17 | skip_special_tokens (`bool`, *optional*, defaults to `False`): 18 | Whether or not to remove special tokens in the decoding. 19 | 20 | Returns: 21 | `str` or `List[str]`: The decoded token(s). 
22 | """ 23 | if isinstance(ids, int): 24 | return self._convert_id_to_token(ids) 25 | tokens = [] 26 | for index in ids: 27 | index = int(index) 28 | if skip_special_tokens and index in self.all_special_ids: 29 | continue 30 | token = self._tokenizer.id_to_token(index) 31 | tokens.append(token if token is not None else "") 32 | return tokens 33 | 34 | def _convert_id_to_token(self, index: int) -> Optional[str]: 35 | token = self._tokenizer.id_to_token(int(index)) 36 | return token if token is not None else "" 37 | -------------------------------------------------------------------------------- /core/aio/archer_tensor_index.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "base/noncopyable.h" 17 | #include "common/pytorch.h" 18 | #include "common/types.h" 19 | 20 | static const std::uint32_t kTensorIndexVersion = 1; 21 | 22 | struct TensorStorageMeta { 23 | std::uint32_t file_id; 24 | std::int64_t offset; 25 | std::size_t size; 26 | std::vector shape; 27 | torch::TensorOptions options; 28 | TensorID id; 29 | 30 | // not for serialization 31 | torch::Tensor tensor; 32 | torch::Device device = DISK_DEVICE; 33 | 34 | std::string DebugString() const; 35 | }; 36 | 37 | std::ostream& operator<<(std::ostream& os, const TensorStorageMeta& obj); 38 | std::istream& operator>>(std::istream& is, TensorStorageMeta& obj); 39 | void write_options(std::ostream& os, const torch::TensorOptions& obj); 40 | void read_options(std::istream& is, torch::TensorOptions& obj); 41 | 42 | class ArcherTensorIndex 43 | : public std::unordered_map, 44 | public base::noncopyable { 45 | public: 46 | void Serialize(const char* path); 47 | void Deserialize(const char* path); 48 | 49 | ArcherTensorIndex() = default; 50 | 
~ArcherTensorIndex() = default; 51 | 52 | private: 53 | }; 54 | 55 | extern std::unique_ptr kTensorIndex; 56 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: meta 3 | hooks: 4 | - id: check-hooks-apply 5 | - id: check-useless-excludes 6 | 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.0.1 9 | hooks: 10 | - id: check-case-conflict 11 | # - id: check-json 12 | # - id: check-symlinks 13 | - id: check-yaml 14 | - id: destroyed-symlinks 15 | - id: end-of-file-fixer 16 | - id: fix-byte-order-marker 17 | - id: fix-encoding-pragma 18 | args: [--remove] 19 | - id: mixed-line-ending 20 | args: [--fix=lf] 21 | - id: requirements-txt-fixer 22 | - id: trailing-whitespace 23 | 24 | - repo: https://github.com/astral-sh/ruff-pre-commit 25 | # Ruff version. 26 | rev: v0.6.9 27 | hooks: 28 | - id: ruff 29 | args: [--fix] 30 | - id: ruff-format 31 | # args: [--check] 32 | 33 | 34 | - repo: https://github.com/pre-commit/mirrors-clang-format 35 | rev: v18.1.4 36 | hooks: 37 | - id: clang-format 38 | 39 | - repo: https://github.com/codespell-project/codespell 40 | rev: v2.3.0 41 | hooks: 42 | - id: codespell 43 | args: [ 44 | # Do not check files that are automatically generated 45 | '--skip=docs/Gemfile.lock,tests/unit/gpt2-merges.txt,tests/unit/gpt2-vocab.json', 46 | '--ignore-regex=\\n', # Do not count the 'n' in an escaped newline as part of a word 47 | '--ignore-words-list=youn,unsupport,noe,ccompiler', # Word used in error messages that need rewording 48 | --check-filenames, 49 | --check-hidden 50 | ] 51 | -------------------------------------------------------------------------------- /core/aio/archer_aio_thread.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #include "archer_aio_thread.h" 7 | 8 | #include "utils/logger.h" 9 | 10 | ArcherAioThread::ArcherAioThread(int thread_id) 11 | : thread_id_(thread_id), is_running_(false) { 12 | DLOG_INFO("Create ArcherAioThread for thread: ", thread_id_); 13 | } 14 | 15 | ArcherAioThread::~ArcherAioThread() { Stop(); } 16 | 17 | void ArcherAioThread::Start() { 18 | if (is_running_) { 19 | return; 20 | } 21 | 22 | is_running_ = true; 23 | pending_callbacks_ = 0; 24 | thread_ = std::thread(&ArcherAioThread::Run, this); 25 | } 26 | 27 | void ArcherAioThread::Stop() { 28 | if (!is_running_) { 29 | return; 30 | } 31 | 32 | is_running_ = false; 33 | thread_.join(); 34 | } 35 | 36 | void ArcherAioThread::Enqueue(AioCallback& callback) { 37 | std::lock_guard lock(mutex_); 38 | callbacks_.push_back(std::move(callback)); 39 | pending_callbacks_.fetch_add(1); 40 | } 41 | 42 | void ArcherAioThread::Wait() { 43 | // while (!callbacks_.empty()) { usleep(1000); } 44 | while (pending_callbacks_.load() != 0) { 45 | usleep(1000); 46 | } 47 | std::lock_guard lock(mutex_); 48 | callbacks_.clear(); 49 | } 50 | 51 | void ArcherAioThread::Run() { 52 | while (is_running_) { 53 | std::function callback; 54 | { 55 | std::lock_guard lock(mutex_); 56 | if (callbacks_.empty()) { 57 | continue; 58 | } 59 | callback = std::move(callbacks_.front()); 60 | callbacks_.pop_front(); 61 | } 62 | callback(); 63 | pending_callbacks_.fetch_sub(1); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /core/base/process_info.h: -------------------------------------------------------------------------------- 1 | // Copyright 2010, Shuo Chen. All rights reserved. 2 | // http://code.google.com/p/muduo/ 3 | // 4 | // Use of this source code is governed by a BSD-style license 5 | // that can be found in the License file. 
6 | 7 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 8 | // 9 | // This is a public header file, it must only include public header files. 10 | 11 | #ifndef MUDUO_BASE_PROCESSINFO_H 12 | #define MUDUO_BASE_PROCESSINFO_H 13 | 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | #include "string_piece.h" 20 | #include "timestamp.h" 21 | #include "types.h" 22 | 23 | namespace base { 24 | 25 | namespace ProcessInfo { 26 | pid_t pid(); 27 | std::string pidString(); 28 | uid_t uid(); 29 | std::string username(); 30 | uid_t euid(); 31 | Timestamp startTime(); 32 | int clockTicksPerSecond(); 33 | int pageSize(); 34 | bool isDebugBuild(); // constexpr 35 | 36 | std::string hostname(); 37 | std::string procname(); 38 | StringPiece procname(const std::string& stat); 39 | 40 | /// read /proc/self/status 41 | std::string procStatus(); 42 | 43 | /// read /proc/self/stat 44 | std::string procStat(); 45 | 46 | /// read /proc/self/task/tid/stat 47 | std::string threadStat(); 48 | 49 | /// readlink /proc/self/exe 50 | std::string exePath(); 51 | 52 | int openedFiles(); 53 | int maxOpenFiles(); 54 | 55 | struct CpuTime { 56 | double userSeconds; 57 | double systemSeconds; 58 | 59 | CpuTime() : userSeconds(0.0), systemSeconds(0.0) {} 60 | }; 61 | CpuTime cpuTime(); 62 | 63 | int numThreads(); 64 | std::vector threads(); 65 | } // namespace ProcessInfo 66 | 67 | } // namespace base 68 | 69 | #endif // MUDUO_BASE_PROCESSINFO_H 70 | -------------------------------------------------------------------------------- /core/base/timestamp.cc: -------------------------------------------------------------------------------- 1 | #include "timestamp.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "time.h" // struct tm 7 | 8 | #ifndef __STDC_FORMAT_MACROS 9 | #define __STDC_FORMAT_MACROS 10 | #endif 11 | 12 | #include 13 | 14 | using namespace base; 15 | 16 | static_assert(sizeof(Timestamp) == sizeof(int64_t), 17 | "Timestamp is same size as int64_t"); 18 | 19 | std::string 
Timestamp::toString() const { 20 | char buf[32] = {0}; 21 | int64_t seconds = microSecondsSinceEpoch_ / kMicroSecondsPerSecond; 22 | int64_t microseconds = microSecondsSinceEpoch_ % kMicroSecondsPerSecond; 23 | snprintf(buf, sizeof(buf) - 1, "%" PRId64 ".%06" PRId64 "", seconds, 24 | microseconds); 25 | return buf; 26 | } 27 | 28 | std::string Timestamp::toFormattedString(bool showMicroseconds) const { 29 | char buf[32] = {0}; 30 | time_t seconds = 31 | static_cast(microSecondsSinceEpoch_ / kMicroSecondsPerSecond); 32 | struct tm tm_time; 33 | gmtime_r(&seconds, &tm_time); 34 | 35 | if (showMicroseconds) { 36 | int microseconds = 37 | static_cast(microSecondsSinceEpoch_ % kMicroSecondsPerSecond); 38 | snprintf(buf, sizeof(buf), "%4d%02d%02d %02d:%02d:%02d.%06d", 39 | tm_time.tm_year + 1900, tm_time.tm_mon + 1, tm_time.tm_mday, 40 | tm_time.tm_hour, tm_time.tm_min, tm_time.tm_sec, microseconds); 41 | } else { 42 | snprintf(buf, sizeof(buf), "%4d%02d%02d %02d:%02d:%02d", 43 | tm_time.tm_year + 1900, tm_time.tm_mon + 1, tm_time.tm_mday, 44 | tm_time.tm_hour, tm_time.tm_min, tm_time.tm_sec); 45 | } 46 | return buf; 47 | } 48 | 49 | Timestamp Timestamp::now() { 50 | struct timeval tv; 51 | gettimeofday(&tv, NULL); 52 | int64_t seconds = tv.tv_sec; 53 | return Timestamp(seconds * kMicroSecondsPerSecond + tv.tv_usec); 54 | } 55 | -------------------------------------------------------------------------------- /moe_infinity/common/constants.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | MixtralForCausalLM, 3 | NllbMoeForConditionalGeneration, 4 | OPTForCausalLM, 5 | PretrainedConfig, 6 | SwitchTransformersForConditionalGeneration, 7 | ) 8 | 9 | from ..models.modeling_arctic import ( 10 | ArcticForCausalLM, 11 | ) # TODO: Replace this with huggingface transformers 12 | from ..models.modeling_deepseek import DeepseekV2ForCausalLM 13 | from ..models.modeling_deepseek_v3 import DeepseekV3ForCausalLM 14 | 
from ..models.modeling_grok.modeling_grok1 import ( 15 | Grok1ModelForCausalLM, 16 | ) # TODO: Replace this with huggingface transformers 17 | 18 | MODEL_MAPPING_NAMES = { 19 | "switch": SwitchTransformersForConditionalGeneration, 20 | "nllb": NllbMoeForConditionalGeneration, 21 | "mixtral": MixtralForCausalLM, 22 | "opt": OPTForCausalLM, 23 | "grok": Grok1ModelForCausalLM, 24 | "arctic": ArcticForCausalLM, 25 | "deepseek": DeepseekV2ForCausalLM, 26 | "deepseek_v3": DeepseekV3ForCausalLM, 27 | } 28 | 29 | MODEL_MAPPING_TYPES = { 30 | "switch": 0, 31 | "nllb": 2, 32 | "mixtral": 4, 33 | "grok": 4, 34 | "arctic": 4, 35 | "deepseek": 5, 36 | "deepseek_v3": 5, 37 | } 38 | 39 | 40 | def parse_expert_type(config: PretrainedConfig) -> int: 41 | architecture = config.architectures[0].lower() 42 | arch = None 43 | for supp_arch in MODEL_MAPPING_NAMES: 44 | if supp_arch in architecture: 45 | arch = supp_arch 46 | break 47 | if arch is None: 48 | raise RuntimeError( 49 | f"The `load_checkpoint_and_dispatch` function does not support the architecture {architecture}. " 50 | f"Please provide a model that is supported by the function. " 51 | f"Supported architectures are {list(MODEL_MAPPING_NAMES.keys())}." 52 | ) 53 | 54 | return MODEL_MAPPING_TYPES[arch] 55 | -------------------------------------------------------------------------------- /core/utils/cuda_utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #include "cuda_utils.h" 7 | #include "logger.h" 8 | 9 | int kNumDevices = GetDeviceCount(); 10 | 11 | bool IsDevicePointer(const void* ptr) { 12 | cudaPointerAttributes attr; 13 | cudaError_t err = cudaPointerGetAttributes(&attr, ptr); 14 | if (err != cudaSuccess) { 15 | DLOG_ERROR("cudaPointerGetAttributes failed: ", cudaGetErrorString(err)); 16 | return false; 17 | } 18 | return attr.type == cudaMemoryTypeDevice; 19 | } 20 | 21 | int GetDeviceCount() { 22 | int device_count = 0; 23 | cudaGetDeviceCount(&device_count); 24 | return device_count; 25 | } 26 | 27 | int GetDevice() { 28 | int device_id; 29 | cudaGetDevice(&device_id); 30 | return device_id; 31 | } 32 | 33 | std::size_t GetTotalDeviceMemory(int device_id) { 34 | size_t free_memory, total_memory; 35 | cudaSetDevice(device_id); 36 | cudaMemGetInfo(&free_memory, &total_memory); 37 | return total_memory; 38 | } 39 | 40 | std::size_t GetFreeDeviceMemory(int device_id) { 41 | size_t free_memory, total_memory; 42 | cudaSetDevice(device_id); 43 | cudaMemGetInfo(&free_memory, &total_memory); 44 | return free_memory; 45 | } 46 | 47 | int CudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) { 48 | return cudaMemcpy(dst, src, count, kind); 49 | } 50 | 51 | int CudaMemcpyAsync(void* dst, const void* src, size_t count, 52 | cudaMemcpyKind kind, cudaStream_t stream) { 53 | return cudaMemcpyAsync(dst, src, count, kind, stream); 54 | } 55 | 56 | void BlockingCudaCopy(int device, void* dst, const void* src, size_t size, 57 | cudaMemcpyKind kind, cudaStream_t stream) { 58 | CUDA_CHECK(cudaSetDevice(device)); 59 | CUDA_CHECK(cudaMemcpyAsync(dst, src, size, kind, stream)); 60 | CUDA_CHECK(cudaStreamSynchronize(stream)); 61 | } 62 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: 
-------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest an idea for MoE-Infinity 3 | title: "[Feature Request] " 4 | labels: [enhancement] 5 | assignees: [futurexy] 6 | body: 7 | - type: checkboxes 8 | id: prerequisites 9 | attributes: 10 | label: Prerequisites 11 | options: 12 | - label: I have searched existing issues and reviewed documentation. 13 | required: true 14 | 15 | - type: textarea 16 | id: problem 17 | attributes: 18 | label: Problem Description 19 | description: Is your feature request related to a problem? Please describe. 20 | placeholder: I'm always frustrated when [...] 21 | validations: 22 | required: true 23 | 24 | - type: textarea 25 | id: solution 26 | attributes: 27 | label: Proposed Solution 28 | description: Describe the solution you'd like. 29 | validations: 30 | required: true 31 | 32 | - type: textarea 33 | id: alternatives 34 | attributes: 35 | label: Alternatives Considered 36 | description: Describe any alternative solutions or features you've considered. 37 | 38 | - type: textarea 39 | id: context 40 | attributes: 41 | label: Additional Context 42 | description: Add any other context, examples, or screenshots about the feature request here. 43 | 44 | - type: dropdown 45 | id: importance 46 | attributes: 47 | label: Importance 48 | options: 49 | - Nice to have 50 | - Important 51 | - Critical 52 | validations: 53 | required: true 54 | 55 | - type: textarea 56 | id: statistics 57 | attributes: 58 | label: Usage Statistics (Optional) 59 | description: | 60 | We'd love to know how you're using MoE-Infinity! If you're comfortable, please share details like your affiliation and use case to help us improve the project. 
61 | validations: 62 | required: false 63 | -------------------------------------------------------------------------------- /core/common/pytorch.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include "aio/archer_prio_aio_handle.h" 10 | #include "base/noncopyable.h" 11 | 12 | #define CPU_DEVICE torch::Device(torch::kCPU) 13 | #define CUDA_DEVICE(index) torch::Device(torch::kCUDA, index) 14 | #define DISK_DEVICE torch::Device(torch::kMeta) 15 | #define DEFAULT_CUDA_DEVICE torch::Device(torch::kCUDA, 0) 16 | 17 | #define TENSOR_OPTIONS(dtype, target) \ 18 | torch::TensorOptions() \ 19 | .dtype(dtype) \ 20 | .device(target) \ 21 | .requires_grad(false) \ 22 | .memory_format(torch::MemoryFormat::Contiguous) 23 | 24 | #define FLOAT32_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kFloat32, target) 25 | #define FLOAT16_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kFloat16, target) 26 | #define INT32_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kInt32, target) 27 | #define INT64_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kInt64, target) 28 | #define BFLOAT16_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kBFloat16, target) 29 | 30 | #define FAKE_TENSOR_SIZES torch::IntArrayRef({1}) 31 | 32 | inline std::vector list_to_vector(py::list list) { 33 | std::vector vec; 34 | for (auto item : list) { 35 | vec.push_back(item.cast()); 36 | } 37 | return vec; 38 | } 39 | 40 | inline py::list vector_to_list(std::vector& vec) { 41 | py::list list; 42 | for (auto item : vec) { 43 | list.append(item); 44 | } 45 | return list; 46 | } 47 | 48 | inline size_t torch_shape_size(const std::vector& shape, int dtype) { 49 | auto torch_type = torch::ScalarType(dtype); 50 | auto itemsize = torch::empty({1}, torch_type).itemsize(); 51 | size_t size = 1; 52 | for (auto dim : shape) { 53 | size *= dim; 54 | } 55 | 
size *= itemsize; 56 | return size; 57 | } 58 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_arctic/tokenization_arctic.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for Arctic.""" 2 | 3 | from typing import Any, Dict, Optional 4 | 5 | from transformers.models.llama import LlamaTokenizer 6 | 7 | 8 | class ArcticTokenizer(LlamaTokenizer): 9 | def __init__( 10 | self, 11 | vocab_file, 12 | unk_token="", 13 | bos_token="", 14 | eos_token="", 15 | pad_token=None, 16 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 17 | add_bos_token=True, 18 | add_eos_token=False, 19 | clean_up_tokenization_spaces=False, 20 | use_default_system_prompt=False, 21 | spaces_between_special_tokens=False, 22 | legacy=False, 23 | add_prefix_space=True, 24 | **kwargs, 25 | ): 26 | # Same as LlamaTokenizer except default legacy=False. 27 | super().__init__( 28 | vocab_file, 29 | bos_token=bos_token, 30 | eos_token=eos_token, 31 | unk_token=unk_token, 32 | pad_token=pad_token, 33 | sp_model_kwargs=sp_model_kwargs, 34 | add_bos_token=add_bos_token, 35 | add_eos_token=add_eos_token, 36 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 37 | use_default_system_prompt=use_default_system_prompt, 38 | spaces_between_special_tokens=spaces_between_special_tokens, 39 | legacy=legacy, 40 | add_prefix_space=add_prefix_space, 41 | **kwargs, 42 | ) 43 | 44 | @property 45 | def default_chat_template(self): 46 | """ 47 | This template formats inputs in the standard Arctic format. 
48 | """ 49 | return ( 50 | "{% for message in messages %}" 51 | "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" 52 | "{% endfor %}" 53 | "{% if add_generation_prompt %}" 54 | "{{ '<|im_start|>assistant\n' }}" 55 | "{% endif %}" 56 | ) 57 | -------------------------------------------------------------------------------- /core/utils/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | inline void throwOnCudaError(cudaError_t error, const char* file, int line, 14 | const char* function, const char* call) { 15 | if (error != cudaSuccess) { 16 | std::stringstream ss; 17 | ss << "CUDA error " << error << " at " << file << ":" << line 18 | << " in function " << function << ": " << cudaGetErrorString(error) 19 | << "\nCall: " << call; 20 | throw std::runtime_error(ss.str()); 21 | } 22 | }; 23 | 24 | #define CUDA_CHECK(call) \ 25 | throwOnCudaError(call, __FILE__, __LINE__, __FUNCTION__, #call) 26 | 27 | int GetDevice(); 28 | bool IsDevicePointer(const void* ptr); 29 | int GetDeviceCount(); 30 | std::size_t GetTotalDeviceMemory(int device_id); 31 | std::size_t GetFreeDeviceMemory(int device_id); 32 | 33 | #define DEVICE_CACHE_LIMIT(gid) GetTotalDeviceMemory(gid) * 0.7 34 | extern int kNumDevices; 35 | 36 | int CudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind); 37 | int CudaMemcpyAsync(void* dst, const void* src, size_t count, 38 | cudaMemcpyKind kind, cudaStream_t stream = 0); 39 | void BlockingCudaCopy(int device, void* dst, const void* src, size_t size, 40 | cudaMemcpyKind kind, cudaStream_t stream); 41 | 42 | struct CUDADeviceAllocator { 43 | void* operator()(std::size_t size) { 44 | void* ptr; 45 | CUDA_CHECK(cudaMalloc(&ptr, size)); 46 | return ptr; 47 | } 48 | }; 49 | 
50 | struct CUDADeviceDeleter { 51 | void operator()(void* ptr) { CUDA_CHECK(cudaFree(ptr)); } 52 | }; 53 | 54 | struct CUDAHostAllocator { 55 | void* operator()(std::size_t size) { 56 | void* ptr; 57 | CUDA_CHECK(cudaHostAlloc(&ptr, size, cudaHostAllocDefault)); 58 | return ptr; 59 | } 60 | }; 61 | 62 | struct CUDAHostDeleter { 63 | void operator()(void* ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } 64 | }; 65 | -------------------------------------------------------------------------------- /op_builder/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Microsoft Corporation. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | 7 | # op_builder/__init__.py 8 | # 9 | # Part of the DeepSpeed Project, under the Apache-2.0 License. 10 | # See https://github.com/microsoft/DeepSpeed/blob/master/LICENSE for license information. 11 | # SPDX-License-Identifier: Apache-2.0 12 | 13 | # MoE-Infinity: replced builder_closure with PrefetchBuilder 14 | 15 | import importlib 16 | import os 17 | import pkgutil 18 | import sys 19 | 20 | from .builder import OpBuilder, get_default_compute_capabilities 21 | from .prefetch import PrefetchBuilder 22 | 23 | # Do not remove, required for abstract accelerator to detect if we have a deepspeed or 3p op_builder 24 | __deepspeed__ = True 25 | 26 | # List of all available op builders from deepspeed op_builder 27 | try: 28 | import moe_infinity.ops.op_builder # noqa: F401 29 | 30 | op_builder_dir = "moe_infinity.ops.op_builder" 31 | except ImportError: 32 | op_builder_dir = "op_builder" 33 | 34 | this_module = sys.modules[__name__] 35 | 36 | 37 | def builder_closure(member_name): 38 | return PrefetchBuilder 39 | 40 | 41 | # reflect builder names and add builder closure, such as 'TransformerBuilder()' creates op builder wrt current accelerator 42 | for _, module_name, _ in pkgutil.iter_modules( 43 | [os.path.dirname(this_module.__file__)] 44 | 
): 45 | if module_name != "all_ops" and module_name != "builder": 46 | module = importlib.import_module( 47 | f".{module_name}", package=op_builder_dir 48 | ) 49 | for member_name in module.__dir__(): 50 | if ( 51 | member_name.endswith("Builder") 52 | and member_name != "OpBuilder" 53 | and member_name != "CUDAOpBuilder" 54 | ): 55 | # assign builder name to variable with same name 56 | # the following is equivalent to i.e. TransformerBuilder = "TransformerBuilder" 57 | this_module.__dict__[member_name] = builder_closure(member_name) 58 | -------------------------------------------------------------------------------- /core/prefetch/task_thread.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #include "task_thread.h" 7 | 8 | #include 9 | 10 | #include 11 | 12 | #include "common/time.h" 13 | #include "utils/logger.h" 14 | 15 | std::atomic_uint32_t kGPUCounter{0}; 16 | 17 | void SetThreadScheduling(std::thread& th, int policy, int priority) { 18 | sched_param sch_params; 19 | sch_params.sched_priority = priority; 20 | if (pthread_setschedparam(th.native_handle(), policy, &sch_params)) { 21 | std::cerr << "Failed to set Thread scheduling : " << std::strerror(errno) 22 | << std::endl; 23 | assert(false); 24 | } 25 | } 26 | 27 | void SetThreadAffinity(std::thread& th, int cpu_id) { 28 | cpu_set_t cpuset; 29 | CPU_ZERO(&cpuset); 30 | CPU_SET(cpu_id, &cpuset); 31 | if (pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset)) { 32 | std::cerr << "Failed to set Thread affinity : " << std::strerror(errno) 33 | << std::endl; 34 | assert(false); 35 | } 36 | } 37 | 38 | void SetThreadAffinity(std::thread& th) { 39 | // get number of cpus 40 | int num_cpus = sysconf(_SC_NPROCESSORS_ONLN); 41 | kCPUCounter++; 42 | int cpu_id = kCPUCounter % num_cpus; 43 | cpu_set_t cpuset; 44 | CPU_ZERO(&cpuset); 45 | 
CPU_SET(cpu_id, &cpuset); 46 | 47 | if (pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset)) { 48 | std::cerr << "Failed to set Thread affinity : " << std::strerror(errno) 49 | << std::endl; 50 | assert(false); 51 | } 52 | } 53 | 54 | void SetThreadAffinity(pid_t tid) { 55 | // get number of cpus 56 | int num_cpus = sysconf(_SC_NPROCESSORS_ONLN); 57 | kCPUCounter++; 58 | int cpu_id = kCPUCounter % num_cpus; 59 | cpu_set_t cpuset; 60 | CPU_ZERO(&cpuset); 61 | CPU_SET(cpu_id, &cpuset); 62 | 63 | if (pthread_setaffinity_np(tid, sizeof(cpu_set_t), &cpuset)) { 64 | std::cerr << "Failed to set Thread affinity : " << std::strerror(errno) 65 | << std::endl; 66 | assert(false); 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /.github/workflows/scripts/free-disk-space.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | 18 | # 19 | # The Azure provided machines typically have the following disk allocation: 20 | # Total space: 85GB 21 | # Allocated: 67 GB 22 | # Free: 17 GB 23 | # This script frees up 28 GB of disk space by deleting unneeded packages and 24 | # large directories. 25 | # The Flink end to end tests download and generate more than 17 GB of files, 26 | # causing unpredictable behavior and build failures. 27 | # 28 | echo "==============================================================================" 29 | echo "Freeing up disk space on CI system" 30 | echo "==============================================================================" 31 | 32 | echo "Listing 100 largest packages" 33 | dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100 34 | df -h 35 | echo "Removing large packages" 36 | sudo apt-get remove -y '^ghc-8.*' 37 | sudo apt-get remove -y '^dotnet-.*' 38 | sudo apt-get remove -y '^llvm-.*' 39 | sudo apt-get remove -y 'php.*' 40 | sudo apt-get remove -y azure-cli google-cloud-sdk hhvm google-chrome-stable firefox powershell mono-devel 41 | sudo apt-get autoremove -y 42 | sudo apt-get clean 43 | df -h 44 | echo "Removing large directories" 45 | # deleting 15GB 46 | rm -rf /usr/share/dotnet/ 47 | rm -rf /opt/hostedtoolcache/ 48 | df -h 49 | -------------------------------------------------------------------------------- /moe_infinity/runtime/state_dict.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | from sllm_store.device_map_utils import _compute_device_placement_from_map_fast 5 | from sllm_store.utils import get_no_split_modules, get_tied_no_split_modules 6 | from transformers import PretrainedConfig, PreTrainedModel 7 | 8 | from moe_infinity.utils.hf_config import parse_expert_id 9 | 10 | 11 | def partition_offloading_state_dict( 12 | state_dict: Dict[str, Any], config: PretrainedConfig 13 | ): 14 | non_offloading_state_dict = {} 15 | 
offloading_state_dict = {} 16 | 17 | for key, value in state_dict.items(): 18 | layer_id, expert_id = parse_expert_id(key) 19 | if layer_id is None: 20 | non_offloading_state_dict[key] = value 21 | else: 22 | offloading_state_dict[key] = value 23 | 24 | return non_offloading_state_dict, offloading_state_dict 25 | 26 | 27 | def load_non_offloading_state_dict( 28 | model: PreTrainedModel, state_dict: Dict[str, Any] 29 | ): 30 | config = model.config 31 | no_split_modules = get_no_split_modules(model, model._no_split_modules) 32 | tied_no_split_modules = get_tied_no_split_modules(model, no_split_modules) 33 | 34 | device_map = _compute_device_placement_from_map_fast( 35 | no_split_modules, tied_no_split_modules, "auto" 36 | ) 37 | 38 | # model.load_state_dict(state_dict, strict=False) 39 | 40 | for key, param in state_dict.items(): 41 | levels = key.split(".") 42 | # If the key cannot be found in the model, skip it 43 | if not hasattr(model, levels[0]): 44 | continue 45 | weight = model.__getattr__(levels[0]) 46 | for l in levels[1:]: 47 | if not hasattr(weight, l): 48 | weight = None 49 | break 50 | weight = weight.__getattr__(l) 51 | if weight is not None: 52 | weight.data = param.to(config.torch_dtype) 53 | byte_size = ( 54 | sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9 55 | ) 56 | print(f"Model non offloading size: {byte_size:.2f} GB", flush=True) 57 | -------------------------------------------------------------------------------- /core/base/date.cc: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #include "date.h" 7 | #include // snprintf 8 | #include 9 | #include "time.h" // struct tm 10 | 11 | namespace base { 12 | namespace detail { 13 | 14 | char require_32_bit_integer_at_least[sizeof(int) >= sizeof(int32_t) ? 
1 : -1]; 15 | 16 | // algorithm and explanation see: 17 | // http://www.faqs.org/faqs/calendars/faq/part2/ 18 | // http://blog.csdn.net/Solstice 19 | 20 | int getJulianDayNumber(int year, int month, int day) { 21 | (void)require_32_bit_integer_at_least; // no warning please 22 | int a = (14 - month) / 12; 23 | int y = year + 4800 - a; 24 | int m = month + 12 * a - 3; 25 | return day + (153 * m + 2) / 5 + y * 365 + y / 4 - y / 100 + y / 400 - 32045; 26 | } 27 | 28 | struct Date::YearMonthDay getYearMonthDay(int julianDayNumber) { 29 | int a = julianDayNumber + 32044; 30 | int b = (4 * a + 3) / 146097; 31 | int c = a - ((b * 146097) / 4); 32 | int d = (4 * c + 3) / 1461; 33 | int e = c - ((1461 * d) / 4); 34 | int m = (5 * e + 2) / 153; 35 | Date::YearMonthDay ymd; 36 | ymd.day = e - ((153 * m + 2) / 5) + 1; 37 | ymd.month = m + 3 - 12 * (m / 10); 38 | ymd.year = b * 100 + d - 4800 + (m / 10); 39 | return ymd; 40 | } 41 | } // namespace detail 42 | const int Date::kJulianDayOf1970_01_01 = detail::getJulianDayNumber(1970, 1, 1); 43 | } // namespace base 44 | 45 | using namespace base; 46 | using namespace base::detail; 47 | 48 | Date::Date(int y, int m, int d) 49 | : julianDayNumber_(getJulianDayNumber(y, m, d)) {} 50 | 51 | Date::Date(const struct tm& t) 52 | : julianDayNumber_( 53 | getJulianDayNumber(t.tm_year + 1900, t.tm_mon + 1, t.tm_mday)) {} 54 | 55 | std::string Date::toIsoString() const { 56 | char buf[32]; 57 | YearMonthDay ymd(yearMonthDay()); 58 | snprintf(buf, sizeof buf, "%4d-%02d-%02d", ymd.year, ymd.month, ymd.day); 59 | return buf; 60 | } 61 | 62 | Date::YearMonthDay Date::yearMonthDay() const { 63 | return getYearMonthDay(julianDayNumber_); 64 | } 65 | -------------------------------------------------------------------------------- /core/utils/threadsafe_queue.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "base/noncopyable.h" 8 | 9 | 
template 10 | class ThreadSafeQueue : public base::noncopyable { 11 | public: 12 | ThreadSafeQueue() = default; 13 | 14 | // Disable copy constructor and assignment to avoid accidental data races. 15 | ThreadSafeQueue(const ThreadSafeQueue&) = delete; 16 | ThreadSafeQueue& operator=(const ThreadSafeQueue&) = delete; 17 | 18 | // Pushes an item into the queue (thread-safe). 19 | void Push(T& item) { 20 | { 21 | std::lock_guard lock(mutex_); 22 | queue_.push(std::move(item)); 23 | } 24 | cond_.notify_one(); 25 | } 26 | 27 | // Pops an item from the queue (blocking). 28 | bool Pop(T& item) { 29 | std::unique_lock lock(mutex_); 30 | cond_.wait(lock, [this] { return !queue_.empty(); }); 31 | 32 | item = std::move(queue_.front()); 33 | queue_.pop(); 34 | return true; 35 | } 36 | 37 | // Tries to pop an item without blocking. Returns false if the queue is empty. 38 | bool TryPop(T& item) { 39 | std::lock_guard lock(mutex_); 40 | if (queue_.empty()) { 41 | return false; 42 | } 43 | item = std::move(queue_.front()); 44 | queue_.pop(); 45 | return true; 46 | } 47 | 48 | // Returns true if the queue is empty. 
49 | bool Empty() const { 50 | std::lock_guard lock(mutex_); 51 | return queue_.empty(); 52 | } 53 | 54 | protected: 55 | std::queue queue_; 56 | mutable std::mutex mutex_; 57 | std::condition_variable cond_; 58 | }; 59 | 60 | // recycling queue implementation, popped item is pushed back to the queue 61 | template 62 | class ThreadSafeRecyclingQueue : public ThreadSafeQueue { 63 | public: 64 | ThreadSafeRecyclingQueue() = default; 65 | 66 | void Pop(T& item) override { 67 | ThreadSafeQueue::Pop(item); 68 | Push(item); 69 | } 70 | 71 | bool TryPop(T& item) override { 72 | bool success = ThreadSafeQueue::TryPop(item); 73 | Push(item); 74 | return success; 75 | } 76 | }; 77 | -------------------------------------------------------------------------------- /core/base/file_util.h: -------------------------------------------------------------------------------- 1 | // Copyright 2010, Shuo Chen. All rights reserved. 2 | // http://code.google.com/p/muduo/ 3 | // 4 | // Use of this source code is governed by a BSD-style license 5 | // that can be found in the License file. 6 | 7 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 8 | // 9 | // This is a public header file, it must only include public header files. 
10 | 11 | #ifndef MUDUO_BASE_FILEUTIL_H 12 | #define MUDUO_BASE_FILEUTIL_H 13 | 14 | #include "string_piece.h" 15 | 16 | namespace base { 17 | 18 | namespace FileUtil { 19 | 20 | // read small file < 64KB 21 | class ReadSmallFile : noncopyable { 22 | public: 23 | ReadSmallFile(StringArg filename); 24 | ~ReadSmallFile(); 25 | 26 | // return errno 27 | template 28 | int readToString(int maxSize, String* content, int64_t* fileSize, 29 | int64_t* modifyTime, int64_t* createTime); 30 | 31 | /// Read at maximum kBufferSize into buf_ 32 | // return errno 33 | int readToBuffer(int* size); 34 | 35 | const char* buffer() const { return buf_; } 36 | 37 | static const int kBufferSize = 64 * 1024; 38 | 39 | private: 40 | int fd_; 41 | int err_; 42 | char buf_[kBufferSize]; 43 | }; 44 | 45 | // read the file content, returns errno if error happens. 46 | template 47 | int readFile(StringArg filename, int maxSize, String* content, 48 | int64_t* fileSize = NULL, int64_t* modifyTime = NULL, 49 | int64_t* createTime = NULL) { 50 | ReadSmallFile file(filename); 51 | return file.readToString(maxSize, content, fileSize, modifyTime, createTime); 52 | } 53 | 54 | // not thread safe 55 | class AppendFile : noncopyable { 56 | public: 57 | explicit AppendFile(StringArg filename); 58 | 59 | ~AppendFile(); 60 | 61 | void append(const char* logline, const size_t len); 62 | 63 | void flush(); 64 | 65 | size_t writtenBytes() const { return writtenBytes_; } 66 | 67 | private: 68 | size_t write(const char* logline, size_t len); 69 | 70 | FILE* fp_; 71 | char buffer_[64 * 1024]; 72 | size_t writtenBytes_; 73 | }; 74 | } // namespace FileUtil 75 | 76 | } // namespace base 77 | 78 | #endif // MUDUO_BASE_FILEUTIL_H 79 | -------------------------------------------------------------------------------- /core/aio/archer_tensor_handle.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | #include "archer_prio_aio_handle.h" 11 | #include "archer_tensor_index.h" 12 | #include "base/noncopyable.h" 13 | 14 | extern const char* ARCHER_PARAM_NAME; 15 | extern const char* ARCHER_IHDEX_NAME; 16 | 17 | class ArcherTensorHandle : public base::noncopyable { 18 | public: 19 | explicit ArcherTensorHandle(const std::string& prefix); 20 | ~ArcherTensorHandle() = default; 21 | 22 | void StoreTensor(const std::uint32_t tensor_id, torch::Tensor& buffer); 23 | void RegisterTensor(std::uint32_t tensor_id, torch::Tensor& buffer); 24 | void SetTensor(std::uint32_t tensor_id, torch::Tensor& buffer); 25 | void SetTensor(std::uint32_t tensor_id, torch::Tensor& buffer, 26 | const torch::Device& device); 27 | 28 | void ReadTensor(const std::uint32_t tensor_id, void* memory_ptr, 29 | bool on_demand = false); 30 | 31 | void MoveTensor(const std::uint32_t tensor_id, 32 | const torch::Device& src_device, 33 | const torch::Device& dst_device); 34 | 35 | std::uint32_t GetTensorId(void* tensor) const; 36 | void UpdateTensorMap(void* old_data_ptr, void* new_data_ptr); 37 | 38 | bool IsTensorIndexInitialized() const { return is_serialized_; } 39 | 40 | int64_t GetTensorSizeAligned(const std::uint32_t tensor_id) const; 41 | torch::TensorOptions GetTensorOptions(const std::uint32_t tensor_id) const; 42 | 43 | private: 44 | // bool ValidateTensorMove(const std::uint32_t tensor_id, 45 | // const torch::Device& src_device, 46 | // const torch::Device& dst_device); 47 | std::string GetIndexFileName(const std::uint32_t file_id) const; 48 | 49 | private: 50 | std::string prefix_; 51 | ArcherPrioAioHandle prio_aio_handle_; 52 | std::uint32_t file_id_; 53 | std::int64_t file_offset_; 54 | std::unordered_map tensor_to_id_; 55 | 56 | std::mutex mutex_; 57 | 58 | bool is_serialized_ = false; 59 | }; 60 | 61 | extern std::unique_ptr kArcherTensorHandle; 62 | 
-------------------------------------------------------------------------------- /core/utils/cache.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | template 9 | class LFUCache { 10 | private: 11 | struct Node { 12 | KeyType key; 13 | ValueType value; 14 | int freq; 15 | Node(KeyType k, ValueType v) : key(k), value(v), freq(1) {} 16 | }; 17 | 18 | int capacity; 19 | int minFreq; 20 | std::unordered_map::iterator> keyNodeMap; 21 | std::unordered_map> freqListMap; 22 | 23 | public: 24 | LFUCache(int capacity) : capacity(capacity), minFreq(0) {} 25 | 26 | ValueType get(KeyType key) { 27 | if (!keyNodeMap.count(key)) { 28 | throw std::range_error("Key not found"); 29 | } 30 | 31 | auto node = keyNodeMap[key]; 32 | int freq = node->freq; 33 | freqListMap[freq].erase(node); 34 | if (freqListMap[freq].empty()) { 35 | freqListMap.erase(freq); 36 | if (minFreq == freq) minFreq += 1; 37 | } 38 | 39 | node->freq += 1; 40 | freqListMap[node->freq].push_front(*node); 41 | keyNodeMap[key] = freqListMap[node->freq].begin(); 42 | 43 | return node->value; 44 | } 45 | 46 | void put(KeyType key, ValueType value) { 47 | if (capacity == 0) return; 48 | 49 | if (keyNodeMap.count(key)) { 50 | auto node = keyNodeMap[key]; 51 | node->value = value; 52 | get(key); // update the node's frequency 53 | return; 54 | } 55 | 56 | if (keyNodeMap.size() == capacity) { 57 | auto node = freqListMap[minFreq].back(); 58 | keyNodeMap.erase(node.key); 59 | freqListMap[minFreq].pop_back(); 60 | if (freqListMap[minFreq].empty()) { 61 | freqListMap.erase(minFreq); 62 | } 63 | } 64 | 65 | minFreq = 1; 66 | Node newNode(key, value); 67 | freqListMap[minFreq].push_front(newNode); 68 | keyNodeMap[key] = freqListMap[minFreq].begin(); 69 | } 70 | 71 | void reset() { 72 | for (auto& freq_pair : freqListMap) { 73 | for (auto& node : freq_pair.second) { 74 | node.freq = 1; // reset frequency to 1 75 | 
freqListMap[1].push_back(node); 76 | } 77 | freq_pair.second.clear(); 78 | } 79 | freqListMap.erase(++freqListMap.begin(), freqListMap.end()); 80 | minFreq = 1; 81 | } 82 | }; 83 | -------------------------------------------------------------------------------- /moe_infinity/memory/expert_prefetcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | 7 | import numpy as np 8 | from transformers import PretrainedConfig 9 | 10 | from moe_infinity.utils import parse_moe_param 11 | 12 | 13 | class ExpertPrefetcher(object): 14 | cache_file_rd = None 15 | first_k_dense_replace: int = 0 16 | 17 | def __init__(self, config: PretrainedConfig): 18 | print(config) 19 | self.num_layers, self.num_experts, self.num_encoder_layers = ( 20 | parse_moe_param(config) 21 | ) 22 | 23 | def set_archer_engine(self, archer_engine): 24 | global _expert_prefetcher 25 | _expert_prefetcher = archer_engine 26 | self.archer_engine = archer_engine 27 | 28 | def prefetch_experts_list(self, layer_id, expert_list): 29 | tensor_ids = [] 30 | for j in expert_list: 31 | tensor_ids.append(self.expert_tensor_map[(layer_id, j)]) 32 | for tensor_id in tensor_ids: 33 | gpu_id = self.archer_engine.get_node_default_device([tensor_id]) 34 | self.archer_engine.enqueue_prefetch(tensor_id, gpu_id) 35 | 36 | def fetch_experts_lock_cache(self, layer_id, expert_list): 37 | tensor_ids = [] 38 | for j in expert_list: 39 | tensor_ids.append(self.expert_tensor_map[(layer_id, j)]) 40 | self.archer_engine.replace_cache_candidates(tensor_ids) 41 | 42 | def prefetch_experts(self, layer_id, expert_matrix): 43 | expert_list = [] 44 | # print("expert_tensor_map", self.expert_tensor_map) 45 | for i in range(layer_id, self.num_layers): 46 | for j in range(self.num_experts): 47 | if expert_matrix[i, j] > 0: 48 | expert_list.append( 49 | (self.expert_tensor_map[(i, j)], expert_matrix[i, 
j]) 50 | ) 51 | ordered_expert_list = sorted( 52 | expert_list, key=lambda x: x[1], reverse=True 53 | ) 54 | tensor_ids = [x[0] for x in ordered_expert_list] 55 | assert len(np.unique(tensor_ids)) == len(tensor_ids) 56 | self.archer_engine.replace_cache_candidates(tensor_ids) 57 | for tensor_id in tensor_ids: 58 | gpu_id = self.archer_engine.get_node_default_device([tensor_id]) 59 | self.archer_engine.enqueue_prefetch(tensor_id, gpu_id) 60 | -------------------------------------------------------------------------------- /moe_infinity/models/modeling_grok/configuration_grok1.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | 3 | 4 | class Grok1Config(PretrainedConfig): 5 | model_type = "grok-1" 6 | keys_to_ignore_at_inference = ["past_key_values"] 7 | 8 | def __init__( 9 | self, 10 | vocab_size=32000, 11 | hidden_size=4096, 12 | intermediate_size=32768, 13 | num_hidden_layers=32, 14 | num_attention_heads=32, 15 | num_key_value_heads=32, 16 | attn_output_multiplier=1.0, 17 | max_attn_value=1.0, 18 | max_position_embeddings=4096, 19 | embedding_multiplier_scale: float = 1.0, 20 | output_multiplier_scale: float = 1.0, 21 | rms_norm_eps=1e-5, 22 | use_cache=True, 23 | pad_token_id=None, 24 | bos_token_id=1, 25 | eos_token_id=2, 26 | tie_word_embeddings=True, 27 | num_experts_per_tok=2, 28 | num_experts=8, 29 | output_router_logits=False, 30 | router_aux_loss_coef=0.001, 31 | **kwargs, 32 | ): 33 | self.vocab_size = vocab_size 34 | self.attn_output_multiplier = attn_output_multiplier 35 | self.max_attn_value = max_attn_value 36 | self.max_position_embeddings = max_position_embeddings 37 | self.embedding_multiplier_scale = embedding_multiplier_scale 38 | self.output_multiplier_scale = output_multiplier_scale 39 | self.hidden_size = hidden_size 40 | self.intermediate_size = intermediate_size 41 | self.num_hidden_layers = num_hidden_layers 42 | self.num_attention_heads = 
num_attention_heads 43 | 44 | # for backward compatibility 45 | if num_key_value_heads is None: 46 | num_key_value_heads = num_attention_heads 47 | 48 | self.num_key_value_heads = num_key_value_heads 49 | self.rms_norm_eps = rms_norm_eps 50 | self.use_cache = use_cache 51 | 52 | self.num_experts_per_tok = num_experts_per_tok 53 | self.num_experts = num_experts 54 | self.output_router_logits = output_router_logits 55 | self.router_aux_loss_coef = router_aux_loss_coef 56 | super().__init__( 57 | pad_token_id=pad_token_id, 58 | bos_token_id=bos_token_id, 59 | eos_token_id=eos_token_id, 60 | tie_word_embeddings=tie_word_embeddings, 61 | **kwargs, 62 | ) 63 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File an issue about a bug in MoE-Infinity. 3 | title: "[BUG] " 4 | labels: [bug] 5 | assignees: [] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Please provide as much detail as possible to help us address the issue efficiently. If you're unsure if this is a bug, consider asking by creating an issue. 11 | 12 | - type: checkboxes 13 | id: prerequisites 14 | attributes: 15 | label: Prerequisites 16 | options: 17 | - label: I have read the [MoE-Infinity documentation](). 18 | required: true 19 | - label: I have searched the [Issue Tracker](https://github.com/EfficientMoE/MoE-Infinity/issues) to ensure this hasn't been reported before. 20 | required: true 21 | 22 | - type: textarea 23 | id: system-info 24 | attributes: 25 | label: System Information 26 | description: Please provide details about your environment (OS, Python version, GPU, etc.). 27 | validations: 28 | required: true 29 | 30 | - type: textarea 31 | id: description 32 | attributes: 33 | label: Problem Description 34 | description: Provide a clear description of the bug. 
35 | validations: 36 | required: true 37 | 38 | - type: textarea 39 | id: reproduction 40 | attributes: 41 | label: Steps to Reproduce 42 | description: Please provide code snippets and steps to reproduce the issue. 43 | value: | 44 | Code snippets: 45 | ```python 46 | 47 | ``` 48 | 49 | Steps to reproduce: 50 | 1. 51 | 2. 52 | 3. 53 | validations: 54 | required: true 55 | 56 | - type: textarea 57 | id: expected 58 | attributes: 59 | label: Expected Behavior 60 | description: What did you expect to happen? 61 | 62 | - type: textarea 63 | id: additional-context 64 | attributes: 65 | label: Additional Context 66 | description: Add any other relevant information, screenshots, or suggested fixes. 67 | 68 | - type: textarea 69 | id: statistics 70 | attributes: 71 | label: Usage Statistics (Optional) 72 | description: | 73 | We'd love to know how you're using MoE-Infinity! If you're comfortable, please share details like your affiliation and use case to help us improve the project. 74 | validations: 75 | required: false 76 | -------------------------------------------------------------------------------- /core/utils/logger.cpp: -------------------------------------------------------------------------------- 1 | // // Copyright (c) EfficientMoE. 
2 | // // SPDX-License-Identifier: Apache-2.0 3 | 4 | // // EfficientMoE Team 5 | 6 | // #include "logger.h" 7 | 8 | // #include 9 | // #include 10 | // #include 11 | 12 | // std::once_flag kLoggerFlag; 13 | // int kLogLevel = -1; 14 | // std::mutex kLogMutex; 15 | 16 | // int str2level(const char* level) 17 | // { 18 | // if (strcmp(level, "info") == 0) { 19 | // return kInfo; 20 | // } else if (strcmp(level, "error") == 0) { 21 | // return kError; 22 | // } else if (strcmp(level, "warn") == 0) { 23 | // return kWarn; 24 | // } else if (strcmp(level, "debug") == 0) { 25 | // return kDebug; 26 | // } else if (strcmp(level, "fatal") == 0) { 27 | // return kFatal; 28 | // } else { 29 | // return -1; 30 | // } 31 | // } 32 | 33 | // std::string level2str(int level) 34 | // { 35 | // switch (level) { 36 | // case kInfo: return "INFO"; 37 | // case kError: return "ERROR"; 38 | // case kWarn: return "WARN"; 39 | // case kDebug: return "DEBUG"; 40 | // case kFatal: return "FATAL"; 41 | // default: return "UNKNOWN"; 42 | // } 43 | // } 44 | 45 | // std::string formatstr() 46 | // { 47 | // // get actual values in the format 48 | // auto time = std::chrono::system_clock::now(); 49 | // auto ms = 50 | // std::chrono::duration_cast(time.time_since_epoch()) 51 | // % 1000; auto timer = std::chrono::system_clock::to_time_t(time); auto tm 52 | // = *std::localtime(&timer); 53 | 54 | // auto year = tm.tm_year + 1900; 55 | // auto month = tm.tm_mon + 1; 56 | // auto day = tm.tm_mday; 57 | // auto hour = tm.tm_hour; 58 | // auto min = tm.tm_min; 59 | // auto sec = tm.tm_sec; 60 | // auto msec = ms.count(); 61 | 62 | // char buf[128]; 63 | // sprintf(buf, "%04d-%02d-%02d %02d:%02d:%02d.%03ld", year, month, day, 64 | // hour, min, sec, msec); return std::string(buf) + " "; 65 | // } 66 | 67 | // void InitLogger() 68 | // { 69 | // std::call_once(kLoggerFlag, []() { 70 | // printf("SPDLOG_LEVEL : %s\n", getenv("SPDLOG_LEVEL")); 71 | // if (getenv("SPDLOG_LEVEL")) { 72 | // kLogLevel = 
str2level(getenv("SPDLOG_LEVEL")); 73 | // } else { 74 | // kLogLevel = kInfo; 75 | // } 76 | // }); 77 | // } 78 | -------------------------------------------------------------------------------- /core/aio/archer_prio_aio_handle.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "archer_aio_threadpool.h" 15 | #include "utils/simple_object_pool.h" 16 | 17 | static const std::size_t kAioAlignment = 4096; 18 | using IocbPtr = typename SimpleObjectPool::Pointer; 19 | 20 | struct AioRequest { 21 | // std::vector iocbs; 22 | std::vector callbacks; 23 | std::mutex mutex; 24 | std::condition_variable cv; 25 | std::atomic pending_callbacks; 26 | }; 27 | 28 | class ArcherPrioAioContext { 29 | public: 30 | explicit ArcherPrioAioContext(const int block_size); 31 | ~ArcherPrioAioContext(); 32 | 33 | void AcceptRequest(std::shared_ptr& io_request, bool high_prio); 34 | 35 | void Schedule(); 36 | std::vector PrepIocbs(const bool read_op, void* buffer, 37 | const int fd, const int block_size, 38 | const std::int64_t offset, 39 | const std::int64_t total_size); 40 | 41 | private: 42 | std::int64_t block_size_; 43 | 44 | std::mutex io_queue_high_mutex_; 45 | std::mutex io_queue_low_mutex_; 46 | 47 | std::deque> io_queue_high_; 48 | std::deque> io_queue_low_; 49 | 50 | std::unique_ptr thread_pool_; 51 | }; 52 | 53 | class ArcherPrioAioHandle { 54 | public: 55 | explicit ArcherPrioAioHandle(const std::string& prefix); 56 | ~ArcherPrioAioHandle(); 57 | 58 | std::int64_t Read(const std::string& filename, void* buffer, 59 | const bool high_prio, const std::int64_t num_bytes, 60 | const std::int64_t offset); 61 | std::int64_t Write(const std::string& filename, const void* buffer, 62 | const bool high_prio, const std::int64_t num_bytes, 63 | 
const std::int64_t offset); 64 | 65 | private: 66 | void Run(); // io submit thread function 67 | 68 | private: 69 | bool time_to_exit_; 70 | std::thread thread_; 71 | std::mutex file_set_mutex_; 72 | std::unordered_map file_set_; 73 | ArcherPrioAioContext aio_context_; 74 | }; 75 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Package Release Guide 2 | 3 | This document describes the process of releasing a new version of the MoE-Infinity-Rel package. 4 | 5 | ## Automated Release Process 6 | 7 | The release mechanism is fully automated through a GitHub Actions workflow, which is defined in the `.github/workflows/publish.yml` file. This workflow triggers upon the creation and publication of a new version tag formatted as `v*` within the repository. 8 | 9 | ### Steps to Release a New Version 10 | To release a new version, such as version 1.0.0, please adhere to the following procedure: 11 | 12 | 1. Update Version: Modify the version number in the setup.py file to reflect the new release version. 13 | 2. Commit Changes: Commit these changes with an appropriate commit message that summarizes the update, such as "Update version for 1.0.0 release". 14 | 3. Create and Push Tag: Tag the latest commit with the new version number and push the tag to the repository. Use the following commands to accomplish this: 15 | ```bash 16 | git tag v1.0.0 17 | git push origin v1.0.0 18 | ``` 19 | 20 | Upon the successful push of the tag, the workflow will creata a new release draft, build the package and publish it to the GitHub Package Registry and PyPI repositories. 21 | 22 | 23 | ## Manual Package Building and Publishing 24 | 25 | For developers who prefer to manually build and publish their package to PyPI, the following steps provide a detailed guide to execute this process effectively. 26 | 27 | 1. 
Start by cloning the repository and navigating to the root directory of the package: 28 | ```bash 29 | git clone https://github.com/EfficientMoE/MoE-Infinity.git 30 | cd MoE-Infinity 31 | ``` 32 | 2. Install the required dependencies to build the package: 33 | ```bash 34 | pip install -r requirements.txt 35 | pip install build 36 | ``` 37 | 3. Build the source distribution and wheel for the package using: 38 | ```bash 39 | BUILD_OPS=1 python -m build 40 | ``` 41 | This command generates the package files in the `dist/` directory. 42 | 4. Upload the built package to the PyPI repository using `twine`: 43 | ```bash 44 | twine upload dist/* 45 | ``` 46 | Ensure that you have the necessary credentials configured for `twine` to authenticate to PyPI. 47 | 48 | 49 | To build the package wheel for multiple Python versions, you should execute the build process individually for each version by specifying the corresponding Python interpreter. 50 | -------------------------------------------------------------------------------- /core/base/date.h: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #ifndef MUDUO_BASE_DATE_H 7 | #define MUDUO_BASE_DATE_H 8 | 9 | #include "copyable.h" 10 | #include "types.h" 11 | 12 | struct tm; 13 | 14 | namespace base { 15 | 16 | /// 17 | /// Date in Gregorian calendar. 18 | /// 19 | /// This class is immutable. 20 | /// It's recommended to pass it by value, since it's passed in register on x64. 
21 | /// 22 | class Date : public base::copyable 23 | // public boost::less_than_comparable, 24 | // public boost::equality_comparable 25 | { 26 | public: 27 | struct YearMonthDay { 28 | int year; // [1900..2500] 29 | int month; // [1..12] 30 | int day; // [1..31] 31 | }; 32 | 33 | static const int kDaysPerWeek = 7; 34 | static const int kJulianDayOf1970_01_01; 35 | 36 | /// 37 | /// Constructs an invalid Date. 38 | /// 39 | Date() : julianDayNumber_(0) {} 40 | 41 | /// 42 | /// Constructs a yyyy-mm-dd Date. 43 | /// 44 | /// 1 <= month <= 12 45 | Date(int year, int month, int day); 46 | 47 | /// 48 | /// Constructs a Date from Julian Day Number. 49 | /// 50 | explicit Date(int julianDayNum) : julianDayNumber_(julianDayNum) {} 51 | 52 | /// 53 | /// Constructs a Date from struct tm 54 | /// 55 | explicit Date(const struct tm&); 56 | 57 | // default copy/assignment/dtor are Okay 58 | 59 | void swap(Date& that) { std::swap(julianDayNumber_, that.julianDayNumber_); } 60 | 61 | bool valid() const { return julianDayNumber_ > 0; } 62 | 63 | /// 64 | /// Converts to yyyy-mm-dd format. 
65 | /// 66 | std::string toIsoString() const; 67 | 68 | struct YearMonthDay yearMonthDay() const; 69 | 70 | int year() const { return yearMonthDay().year; } 71 | 72 | int month() const { return yearMonthDay().month; } 73 | 74 | int day() const { return yearMonthDay().day; } 75 | 76 | // [0, 1, ..., 6] => [Sunday, Monday, ..., Saturday ] 77 | int weekDay() const { return (julianDayNumber_ + 1) % kDaysPerWeek; } 78 | 79 | int julianDayNumber() const { return julianDayNumber_; } 80 | 81 | private: 82 | int julianDayNumber_; 83 | }; 84 | 85 | inline bool operator<(Date x, Date y) { 86 | return x.julianDayNumber() < y.julianDayNumber(); 87 | } 88 | 89 | inline bool operator==(Date x, Date y) { 90 | return x.julianDayNumber() == y.julianDayNumber(); 91 | } 92 | 93 | } // namespace base 94 | #endif // MUDUO_BASE_DATE_H 95 | -------------------------------------------------------------------------------- /moe_infinity/distributed/devicemap_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | # The global device manager shared among all nodes, using grpc server to communicate with each other. 
7 | import random 8 | from typing import List, Tuple 9 | 10 | import torch.distributed as dist 11 | 12 | from moe_infinity.utils import ArcherConfig 13 | 14 | 15 | class DeviceMapManager: 16 | def __init__(self, archer_config: ArcherConfig) -> None: 17 | world_size = dist.get_world_size() 18 | device_per_node = archer_config.device_per_node 19 | 20 | total_device = world_size * device_per_node 21 | if total_device > 1: 22 | self.num_device_plan = [1] + [ 23 | x for x in range(2, total_device + 1, 2) 24 | ] 25 | else: 26 | self.num_device_plan = [1] 27 | 28 | self.device_per_node = device_per_node 29 | self.total_device = total_device 30 | self.world_size = world_size 31 | 32 | def set_expert_tensor_map(self, expert_tensor_map): 33 | self.expert_tensor_map = expert_tensor_map 34 | 35 | def set_archer_engine(self, archer_engine): 36 | self.archer_engine = archer_engine 37 | 38 | def get_target_device( 39 | self, expert_list: List[int] 40 | ) -> List[Tuple[int, int, int]]: 41 | num_experts = len(expert_list) 42 | num_device = self.total_device 43 | 44 | # index = np.argsort(expert_counts)[::-1] 45 | # expert_list = np.array(expert_list) 46 | # expert_list = expert_list[index] 47 | 48 | device_list = [] 49 | k = 0 50 | 51 | # scatter the experts to all GPUs 52 | r = num_device % self.world_size 53 | world_size = num_device // self.world_size 54 | 55 | if r > 0: 56 | world_size += 1 57 | 58 | base = 1 if self.world_size > 1 else 0 59 | 60 | gpu_ids = [id for id in range(self.device_per_node)] 61 | 62 | # scatter the experts to all GPUs 63 | while k < num_experts: 64 | random.shuffle(gpu_ids) 65 | for rank in range(base, min(world_size + base, self.world_size)): 66 | for gpu in gpu_ids: 67 | if k >= num_experts: 68 | break 69 | expert_id = expert_list[k] 70 | device_list.append((rank, gpu, expert_id)) 71 | k += 1 72 | 73 | return device_list 74 | -------------------------------------------------------------------------------- /core/memory/weights_buffer.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "fixed_size_allocator.h" 4 | #include "base/noncopyable.h" 5 | #include "common/pytorch.h" 6 | 7 | template , 8 | typename Deleter = std::default_delete> 9 | class WeightsBuffer : public base::noncopyable { 10 | using ThisAllocator = FixedSizeAllocator; 11 | 12 | public: 13 | explicit WeightsBuffers(size_t num_buffers, 14 | std::vector> shapes, 15 | int torch_dtype) 16 | : allocator_(nullptr) { 17 | size_t size = 0; 18 | for (const auto& shape : shapes) { 19 | size += torch_shape_size(shape, torch_dtype); 20 | } 21 | 22 | allocator_ = std::make_unique(num_buffers, size); 23 | } 24 | 25 | void* get_slot(int layer_id, int expert_id) { 26 | uint64_t key = (static_cast(layer_id) << 32) | expert_id; 27 | std::unique_lock lock(mutex_); 28 | if (buffer_map_.find(key) == buffer_map_.end()) { 29 | void* buffer = allocator_->get_slot(); 30 | if (buffer == nullptr) { 31 | // DLOG_FATAL << "No empty slot in WeightsBuffer, buffer: " << buffer; 32 | return nullptr; 33 | } 34 | buffer_map_[key] = buffer; 35 | } 36 | weight_in_buffer_[key] = false; 37 | return buffer_map_[key]; 38 | } 39 | 40 | void set_slot(int layer_id, int expert_id) { 41 | uint64_t key = (static_cast(layer_id) << 32) | expert_id; 42 | { 43 | std::lock_guard lock(mutex_); 44 | weight_in_buffer_[key] = true; 45 | } 46 | cv_.notify_all(); 47 | } 48 | 49 | void wait_slot(int layer_id, int expert_id) { 50 | uint64_t key = (static_cast(layer_id) << 32) | expert_id; 51 | std::unique_lock lock(mutex_); 52 | cv_.wait(lock, [this, key] { 53 | return buffer_map_.find(key) != buffer_map_.end() && 54 | weight_in_buffer_[key]; 55 | }); 56 | } 57 | 58 | void release_slot((int layer_id, int expert_id)) { 59 | uint64_t key = (static_cast(layer_id) << 32) | expert_id; 60 | std::lock_guard lock(mutex_); 61 | buffer_map_.erase(key); 62 | weight_in_buffer_[key] = false; 63 | } 64 | 65 | private: 66 | 
std::unordered_map buffer_map_; 67 | std::unordered_map weight_in_buffer_; 68 | std::mutex mutex_; 69 | std::condition_variable cv_; 70 | 71 | ThisAllocator allocator_; 72 | }; 73 | -------------------------------------------------------------------------------- /core/utils/prefix_tree.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include // For std::unique_ptr 8 | 9 | // Trie Node Definition (Template) 10 | template 11 | class TrieNode { 12 | public: 13 | std::unordered_map>> children; 14 | std::unique_ptr value; // Store any data type 15 | bool is_end_of_word = false; 16 | }; 17 | 18 | template 19 | class Trie { 20 | private: 21 | std::unique_ptr> root_; 22 | 23 | public: 24 | Trie() : root_(std::make_unique>()) {} 25 | 26 | // Insert a key-value pair (T can be torch::Tensor or other types) 27 | void Insert(const std::string& key, T value) { 28 | TrieNode* node = root_.get(); 29 | for (char ch : key) { 30 | if (node->children.find(ch) == node->children.end()) { 31 | node->children[ch] = std::make_unique>(); 32 | } 33 | node = node->children[ch].get(); 34 | } 35 | node->is_end_of_word = true; 36 | node->value = std::make_unique(std::move(value)); // Move value 37 | } 38 | 39 | // Retrieve all keys with a given prefix 40 | void SearchPrefix(TrieNode* node, std::vector& results, 41 | std::string current) { 42 | if (node->is_end_of_word) { 43 | results.push_back(current); 44 | } 45 | for (auto& [ch, child] : node->children) { 46 | SearchPrefix(child.get(), results, current + ch); 47 | } 48 | } 49 | 50 | // Get all keys with a given prefix 51 | std::vector GetKeysWithPrefix(const std::string& prefix) { 52 | TrieNode* node = root_.get(); 53 | std::vector results; 54 | 55 | for (char ch : prefix) { 56 | if (node->children.find(ch) == node->children.end()) { 57 | return results; // No match 58 | } 59 | node = node->children[ch].get(); 60 | } 61 | 
SearchPrefix(node, results, prefix); 62 | return results; 63 | } 64 | 65 | // Get a value by exact key 66 | T Get(const std::string& key) { 67 | TrieNode* node = root_.get(); 68 | for (char ch : key) { 69 | if (node->children.find(ch) == node->children.end()) { 70 | throw std::runtime_error("Key not found: " + key); 71 | } 72 | node = node->children[ch].get(); 73 | } 74 | if (!node->is_end_of_word) { 75 | throw std::runtime_error("Key not found: " + key); 76 | } 77 | return *(node->value); // Return the stored value 78 | } 79 | }; 80 | -------------------------------------------------------------------------------- /moe_infinity/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | import os 7 | from dataclasses import dataclass, field 8 | 9 | import torch 10 | from transformers import HfArgumentParser 11 | 12 | 13 | @dataclass 14 | class ArcherConfig: 15 | offload_path: str = field( 16 | default="", metadata={"help": "Path to parameter storage"} 17 | ) 18 | trace_capacity: int = field( 19 | default=1000, metadata={"help": "Capacity of trace"} 20 | ) 21 | trace_path: os.PathLike = field( 22 | default=None, metadata={"help": "Path to trace file"} 23 | ) 24 | # master_addr: str = field( 25 | # default="127.0.0.1", 26 | # metadata={"help": "Hosts for running archer"}, 27 | # ) 28 | # master_port: str = field( 29 | # default=29500, 30 | # metadata={"help": "Port for running archer"}, 31 | # ) 32 | # device_per_node: int = field( 33 | # default=1, 34 | # metadata={"help": "Number of devices per node"}, 35 | # ) 36 | prefetch: bool = field( 37 | default=False, metadata={"help": "Enable prefetching"} 38 | ) 39 | device_memory_ratio: float = field( 40 | default=0.9, 41 | metadata={"help": "Ratio of device memory to use"}, 42 | ) 43 | num_threads: int = field( 44 | default=8, metadata={"help": "Number of threads for each 
GPU exec"} 45 | ) 46 | host_memory_ratio: float = field( 47 | default=0.9, 48 | metadata={"help": "Ratio of host memory to use"}, 49 | ) 50 | 51 | @classmethod 52 | def load_from_file(self, config_path): 53 | parser = HfArgumentParser(self) 54 | self = parser.parse_json_file(json_file=config_path)[0] 55 | return self 56 | 57 | @classmethod 58 | def load_from_json(self, config_json): 59 | parser = HfArgumentParser(self) 60 | self = parser.parse_dict(config_json)[0] 61 | return self 62 | 63 | def __post_init__(self): 64 | self.perfect_cache_file = os.path.join( 65 | self.offload_path, "perfect_cache" 66 | ) 67 | 68 | self.device_per_node = ( 69 | torch.cuda.device_count() 70 | ) # always run on heterogeneous nodes 71 | 72 | if self.trace_path is not None: 73 | self.trace_path = os.path.abspath(self.trace_path) 74 | if os.path.isdir(self.trace_path): 75 | raise ValueError( 76 | "The trace path should be a file, not a directory." 77 | ) 78 | -------------------------------------------------------------------------------- /core/prefetch/archer_prefetch_handle.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include "aio/archer_tensor_handle.h" 9 | #include "model/model_topology.h" 10 | #include "parallel/expert_dispatcher.h" 11 | 12 | class ArcherPrefetchHandle { 13 | public: 14 | ArcherPrefetchHandle(const std::string& prefix, 15 | const double device_memory_ratio = 0.8); 16 | ~ArcherPrefetchHandle(); 17 | 18 | bool IsTensorOffloaded(const std::uint32_t tensor_id); 19 | 20 | void AcquireTensor(std::uint64_t& request_id, torch::Tensor& buffer); 21 | void ReleaseTensor(std::uint64_t& request_id, torch::Tensor& buffer); 22 | void PrefetchTensors(std::uint64_t& request_id, 23 | const std::vector& buffer); 24 | void FetchTensors(std::uint64_t& request_id, 25 | const std::vector& buffer); 26 | 27 | void ReplaceCacheCandidates(const std::vector& tensor_ids); 28 | void EnqueuePrefetch(const uint32_t tensor_id, int gpu_id); 29 | 30 | void OffloadTensor(torch::Tensor& tensor, const std::uint32_t tensor_id); 31 | void RegisterTensor(torch::Tensor& tensor, const std::uint32_t tensor_id); 32 | void RegisterModule(torch::nn::Module& module); 33 | void RegisterTensor(torch::Tensor* tensor); 34 | 35 | int GetNodeDefaultDevice(std::vector tensor_ids) const; 36 | int GetNodeDevice(std::vector tensor_ids) const; 37 | 38 | void SetTensorDevice(torch::Tensor& tensor, torch::Device device) const; 39 | 40 | torch::Tensor GetTrace(); 41 | torch::Tensor GetHitRate(); 42 | void SetTrace(const torch::Tensor& trace); 43 | void TraceRequest(const std::uint64_t request_id, const TensorID tensor_id); 44 | void SetTopology(const std::vector< 45 | std::tuple>>>& 46 | topology); 47 | void UpdateTensorMap(std::uint64_t old_ptr, std::uint64_t new_ptr); 48 | bool IsTensorIndexInitialized() const; 49 | bool IsTensorOnDevice(const torch::Tensor& tensor) const; 50 | bool IsTensorOnDevice(const TensorID tensor_id) const; 51 | 52 | void CleanUpResources(); 53 | 54 | // void SetNodeCachePriority(const 
std::uint64_t corr_id, const float 55 | // priority); 56 | 57 | private: 58 | std::string prefix_; 59 | std::unordered_map> 60 | node_id_to_tensor_ids_; 61 | std::unordered_set tensors_to_delete_; 62 | uint64_t last_layer_id_; 63 | NodePtr last_node_; 64 | bool has_cleaned_up_resources_; 65 | 66 | std::unordered_map> 67 | request_id_to_nodes_; 68 | 69 | std::mutex mutex_; 70 | }; 71 | -------------------------------------------------------------------------------- /core/utils/lockfree_queue.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "base/noncopyable.h" 10 | 11 | template 12 | class LockFreeQueue : public base::noncopyable { 13 | public: 14 | LockFreeQueue() { 15 | Node* dummy = new Node(); // Dummy node 16 | head_.store(dummy); 17 | tail_.store(dummy); 18 | } 19 | 20 | ~LockFreeQueue() { 21 | while (Node* old_head = head_.load()) { 22 | head_.store(old_head->next); 23 | delete old_head; 24 | } 25 | } 26 | 27 | void Push(T& value) { 28 | std::shared_ptr new_data = std::make_shared(std::move(value)); 29 | Node* new_node = new Node(); 30 | new_node->data = new_data; 31 | 32 | do { 33 | Node* old_tail = tail_.load(); 34 | Node* next = old_tail->next; 35 | if (old_tail == tail_.load()) { 36 | if (next == nullptr) { 37 | if (old_tail->next.compare_exchange_weak(next, new_node)) { 38 | tail_.compare_exchange_weak(old_tail, new_node); 39 | return; 40 | } 41 | } else { 42 | tail_.compare_exchange_weak(old_tail, next); 43 | } 44 | } 45 | } while (true); 46 | } 47 | 48 | bool Pop(T& value) { 49 | Node* old_head; 50 | 51 | do { 52 | old_head = head_.load(); // Read current head 53 | Node* next = old_head->next.load(std::memory_order_acquire); 54 | if (old_head == tail_.load(std::memory_order_acquire)) { 55 | return false; // Queue is empty 56 | } 57 | 58 | } while (!head_.compare_exchange_weak(old_head, old_head->next, 59 | 
std::memory_order_release)); 60 | 61 | value = *(old_head->next.load()->data); // Read value 62 | delete old_head; // Free old node 63 | return true; 64 | } 65 | 66 | bool Empty() const { return head_.load() == tail_.load(); } 67 | 68 | bool Full() const { 69 | return false; // Queue is unbounded 70 | } 71 | 72 | protected: 73 | struct Node { 74 | std::shared_ptr data; 75 | std::atomic next; 76 | 77 | Node() : next(nullptr) {} 78 | }; 79 | 80 | std::atomic head_; 81 | std::atomic tail_; 82 | }; 83 | 84 | template 85 | class LockFreeRecyclingQueue : public LockFreeQueue { 86 | public: 87 | LockFreeRecyclingQueue() = default; 88 | 89 | void Pop(T& item) override { 90 | LockFreeQueue::Pop(item); 91 | Push(item); 92 | } 93 | 94 | bool TryPop(T& item) override { 95 | bool success = LockFreeQueue::TryPop(item); 96 | Push(item); 97 | return success; 98 | } 99 | }; 100 | -------------------------------------------------------------------------------- /moe_infinity/runtime/hooks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Callable 3 | 4 | import torch 5 | 6 | from moe_infinity.models import ( 7 | apply_rotary_pos_emb, 8 | apply_rotary_pos_emb_deepseek, 9 | ) 10 | 11 | 12 | def do_nothing_decorator(orig_func: Callable) -> Callable: 13 | @functools.wraps(orig_func) 14 | def do_nothing(*args, **kwargs): 15 | pass 16 | 17 | return do_nothing 18 | 19 | 20 | def empty_param_init_decorator(orig_param_init: Callable) -> Callable: 21 | @functools.wraps(orig_param_init) 22 | def empty_param_init(cls, *args, **kwargs): 23 | orig_param_init(cls, *args, **kwargs) 24 | 25 | for name, param in cls.named_parameters(recurse=False): 26 | param.data = torch.zeros(1, dtype=param.dtype, device=param.device) 27 | 28 | for name, buf in cls.named_buffers(recurse=False): 29 | buf.data = torch.zeros(1, dtype=buf.dtype, device=buf.device) 30 | 31 | return empty_param_init 32 | 33 | 34 | def activate_empty_init(): 35 
| # for all the modules in torch.nn, add post_init method 36 | # assert False, torch.nn.modules.__dict__ 37 | for name, module in torch.nn.modules.__dict__.items(): 38 | if not isinstance(module, type): 39 | continue 40 | if not issubclass(module, torch.nn.modules.module.Module): 41 | continue 42 | if name in [ 43 | "Module", 44 | "Sequential", 45 | "ModuleDict", 46 | "ModuleList", 47 | "ParameterList", 48 | "ParameterDict", 49 | ]: 50 | continue 51 | module._old_init = module.__init__ 52 | module.__init__ = empty_param_init_decorator(module.__init__) 53 | 54 | if hasattr(module, "reset_parameters"): 55 | module._old_reset_parameters = module.reset_parameters 56 | module.reset_parameters = do_nothing_decorator( 57 | module.reset_parameters 58 | ) 59 | 60 | 61 | def deactivate_empty_init(): 62 | for name, module in torch.nn.modules.__dict__.items(): 63 | if not isinstance(module, type): 64 | continue 65 | if not issubclass(module, torch.nn.modules.module.Module): 66 | continue 67 | if name in [ 68 | "Module", 69 | "Sequential", 70 | "ModuleDict", 71 | "ModuleList", 72 | "ParameterList", 73 | "ParameterDict", 74 | ]: 75 | continue 76 | if hasattr(module, "_old_init"): 77 | module.__init__ = module._old_init 78 | del module._old_init 79 | 80 | if hasattr(module, "_old_reset_parameters"): 81 | module.reset_parameters = module._old_reset_parameters 82 | del module._old_reset_parameters 83 | -------------------------------------------------------------------------------- /moe_infinity/models/arctic.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from moe_infinity.utils import ArcherConfig 8 | 9 | from .modeling_arctic import ArcticConfig, ArcticMLP 10 | 11 | 12 | class SyncArcticMoeBlock(nn.Module): 13 | archer_config: ArcherConfig = None 14 | layer_id: int = None 15 | 16 | def __init__(self, config: 
ArcticConfig, layer_id: int, **kwargs): 17 | super().__init__() 18 | 19 | self.hidden_dim = config.hidden_size 20 | self.num_experts = config.num_local_experts 21 | self.layer_id = layer_id 22 | self.top_k = config.num_experts_per_tok 23 | self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 24 | 25 | self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) 26 | self.experts = nn.ModuleList( 27 | [ArcticMLP(config) for i in range(self.num_experts)] 28 | ) 29 | 30 | self.archer_tracer = None 31 | self.archer_engine = None 32 | self.expert_tensor_ids: Dict[int, int] = None 33 | 34 | def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: 35 | batch_size, sequence_length, hidden_dim = hidden_states.shape 36 | hidden_states = hidden_states.view(-1, hidden_dim) 37 | # router_logits: (batch * sequence_length, n_experts) 38 | router_logits = self.gate(hidden_states) 39 | 40 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 41 | routing_weights, selected_experts = torch.topk( 42 | routing_weights, self.top_k, dim=-1 43 | ) 44 | # we cast back to the input dtype 45 | routing_weights = routing_weights.to(hidden_states.dtype) 46 | 47 | expert_index = selected_experts.reshape( 48 | batch_size, sequence_length, self.top_k 49 | ) 50 | for i in range(batch_size): 51 | seq_id = self.seq_id_list[i] 52 | expert_matrix = self.expert_predictor.predict( 53 | seq_id, expert_index[i], self.layer_id 54 | ) 55 | self.expert_prefetcher.prefetch_experts( 56 | self.layer_id, expert_matrix 57 | ) 58 | 59 | final_hidden_states = torch.zeros( 60 | (batch_size * sequence_length, hidden_dim), 61 | dtype=hidden_states.dtype, 62 | device=hidden_states.device, 63 | ) 64 | # One hot encode the selected experts to create an expert mask 65 | # this will be used to easily index which expert is going to be sollicitated 66 | expert_mask = torch.nn.functional.one_hot( 67 | selected_experts, num_classes=self.num_experts 68 | ).permute(2, 1, 0) 
69 | return final_hidden_states, expert_mask 70 | -------------------------------------------------------------------------------- /core/memory/memory_pool.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | #include "base/noncopyable.h" 8 | #include "common/pytorch.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "device_caching_allocator.h" 15 | #include "host_caching_allocator.h" 16 | #include "utils/logger.h" 17 | 18 | std::size_t GetTotalSystemMemory(); 19 | 20 | #ifndef DEVICE_MEMORY_RATIO 21 | #define DEVICE_MEMORY_RATIO 0.8 22 | #endif 23 | 24 | #ifndef HOST_MEMORY_RATIO 25 | #define HOST_MEMORY_RATIO 0.8 26 | #endif 27 | 28 | class HostMemoryPool : public base::noncopyable { 29 | public: 30 | void* AllocateMemory(const std::size_t key, const std::int64_t size, 31 | const torch::Device& device); 32 | int FreeMemory(const std::size_t key, void* data, const std::int64_t size, 33 | const torch::Device& device); 34 | std::int64_t GetFreeMemory(); 35 | std::int64_t GetMemoryCapacity(); 36 | 37 | HostMemoryPool(); 38 | virtual ~HostMemoryPool() { 39 | auto allocator = c10::HostCachingAllocator::get(); 40 | for (auto& [key, data_ptr] : allocated_id_) { 41 | if (data_ptr != nullptr) { 42 | allocator->free(data_ptr); 43 | } 44 | } 45 | allocated_id_.clear(); 46 | } 47 | 48 | private: 49 | std::unordered_map allocated_id_; 50 | std::int64_t free_memory_; 51 | std::int64_t memory_capacity_; 52 | std::mutex mutex_; 53 | }; 54 | 55 | class DeviceMemoryPool : public base::noncopyable { 56 | public: 57 | void* AllocateMemory(const std::size_t key, const std::int64_t size, 58 | const torch::Device& device); 59 | int FreeMemory(const std::size_t key, void* data, const std::int64_t size, 60 | const torch::Device& device); 61 | 62 | void SetMemoryRatio(const double ratio); 63 | 
std::int64_t GetFreeMemory(const torch::Device& device); 64 | std::int64_t GetMemoryCapacity(const torch::Device& device); 65 | 66 | DeviceMemoryPool(); 67 | virtual ~DeviceMemoryPool() { 68 | // auto allocator = c10::cuda::CUDACachingAllocator::get(); 69 | // for (auto& allocated_id : allocated_id_) { 70 | // for (auto& [key, data_ptr] : allocated_id) { 71 | // if (data_ptr != nullptr) { allocator->raw_deallocate(data_ptr); } 72 | // } 73 | // } 74 | // allocated_id_.clear(); 75 | // free_memory_.clear(); 76 | // memory_capacity_.clear(); 77 | } 78 | 79 | private: 80 | std::vector> allocated_id_; 81 | std::vector free_memory_; 82 | std::vector memory_capacity_; 83 | std::mutex mutex_; 84 | }; 85 | 86 | extern std::unique_ptr kHostMemoryPool; 87 | extern std::unique_ptr kDeviceMemoryPool; 88 | -------------------------------------------------------------------------------- /core/base/log_file.cc: -------------------------------------------------------------------------------- 1 | #include "log_file.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "file_util.h" 10 | #include "process_info.h" 11 | 12 | using namespace base; 13 | 14 | LogFile::LogFile(const std::string& basename, size_t rollSize, bool threadSafe, 15 | int flushInterval, int checkEveryN) 16 | : basename_(basename), 17 | rollSize_(rollSize), 18 | flushInterval_(flushInterval), 19 | checkEveryN_(checkEveryN), 20 | count_(0), 21 | mutex_(threadSafe ?
new std::mutex : NULL), 22 | startOfPeriod_(0), 23 | lastRoll_(0), 24 | lastFlush_(0) { 25 | rollFile(); 26 | } 27 | 28 | LogFile::~LogFile() {} 29 | 30 | void LogFile::append(const char* logline, int len) { 31 | if (mutex_) { 32 | std::lock_guard lock(*mutex_); 33 | append_unlocked(logline, len); 34 | } else { 35 | append_unlocked(logline, len); 36 | } 37 | } 38 | 39 | void LogFile::flush() { 40 | if (mutex_) { 41 | std::lock_guard lock(*mutex_); 42 | file_->flush(); 43 | } else { 44 | file_->flush(); 45 | } 46 | } 47 | 48 | void LogFile::append_unlocked(const char* logline, int len) { 49 | file_->append(logline, len); 50 | 51 | if (file_->writtenBytes() > rollSize_) { 52 | rollFile(); 53 | } else { 54 | ++count_; 55 | if (count_ >= checkEveryN_) { 56 | count_ = 0; 57 | time_t now = ::time(NULL); 58 | time_t thisPeriod = now / kRollPerSeconds_ * kRollPerSeconds_; // local, not a member 59 | if (thisPeriod != startOfPeriod_) { 60 | rollFile(); 61 | } else if (now - lastFlush_ > flushInterval_) { 62 | lastFlush_ = now; 63 | file_->flush(); 64 | } 65 | } 66 | } 67 | } 68 | 69 | bool LogFile::rollFile() { 70 | time_t now = 0; 71 | std::string filename = getLogFileName(basename_, &now); 72 | time_t start = now / kRollPerSeconds_ * kRollPerSeconds_; 73 | 74 | if (now > lastRoll_) { 75 | lastRoll_ = now; 76 | lastFlush_ = now; 77 | startOfPeriod_ = start; 78 | file_.reset(new FileUtil::AppendFile(filename)); 79 | return true; 80 | } 81 | return false; 82 | } 83 | 84 | std::string LogFile::getLogFileName(const std::string& basename, time_t* now) { 85 | std::string filename; 86 | filename.reserve(basename.size() + 64); 87 | filename = basename; 88 | 89 | char timebuf[32]; 90 | struct tm tm; 91 | *now = time(NULL); 92 | // gmtime_r(now, &tm); // FIXME: localtime_r ?
93 | localtime_r(now, &tm); 94 | strftime(timebuf, sizeof timebuf, ".%Y%m%d-%H%M%S.", &tm); 95 | filename += timebuf; 96 | 97 | filename += ProcessInfo::hostname(); 98 | 99 | char pidbuf[32]; 100 | snprintf(pidbuf, sizeof pidbuf, ".%d", ProcessInfo::pid()); 101 | filename += pidbuf; 102 | 103 | filename += ".log"; 104 | 105 | return filename; 106 | } 107 | -------------------------------------------------------------------------------- /core/utils/simple_object_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | /** Pool for objects that cannot be used from different threads simultaneously. 9 | * Allows to create an object for each thread. 10 | * Pool has unbounded size and objects are not destroyed before destruction of 11 | * pool. 12 | * 13 | * Use it in cases when thread local storage is not appropriate 14 | * (when maximum number of simultaneously used objects is less 15 | * than number of running/sleeping threads, that has ever used object, 16 | * and creation/destruction of objects is expensive). 17 | */ 18 | template 19 | class SimpleObjectPool { 20 | protected: 21 | /// Hold all available objects in stack. 22 | std::mutex mutex; 23 | std::stack> stack; 24 | 25 | /// Specialized deleter for std::unique_ptr. 26 | /// Returns underlying pointer back to stack thus reclaiming its ownership. 27 | struct Deleter { 28 | SimpleObjectPool* parent; 29 | 30 | Deleter(SimpleObjectPool* parent_ = nullptr) 31 | : parent{parent_} {} /// NOLINT 32 | 33 | void operator()(T* owning_ptr) const { 34 | std::lock_guard lock{parent->mutex}; 35 | parent->stack.emplace(owning_ptr); 36 | } 37 | }; 38 | 39 | public: 40 | using Pointer = std::unique_ptr; 41 | 42 | /// Extracts and returns a pointer from the stack if it's not empty, 43 | /// creates a new one by calling provided f() otherwise.
44 | template 45 | Pointer get(Factory&& f) { 46 | std::unique_lock lock(mutex); 47 | 48 | if (stack.empty()) { 49 | lock.unlock(); 50 | return {f(), this}; 51 | } 52 | 53 | auto object = stack.top().release(); 54 | stack.pop(); 55 | 56 | return {object, this}; 57 | } 58 | 59 | /// Returns `count` pointers: reuses objects from the stack while any remain, 60 | /// then creates the rest (outside the lock) by calling provided f(). 61 | template 62 | std::vector getMany(size_t count, Factory&& f) { 63 | std::unique_lock lock(mutex); 64 | 65 | std::vector result; 66 | result.reserve(count); 67 | 68 | while (count > 0) { 69 | if (stack.empty()) { 70 | lock.unlock(); 71 | while (count > 0) { 72 | result.emplace_back(f(), this); 73 | --count; 74 | } 75 | return result; 76 | } 77 | 78 | auto object = stack.top().release(); 79 | stack.pop(); 80 | result.emplace_back(object, this); 81 | --count; 82 | } 83 | 84 | return result; 85 | } 86 | 87 | /// Like get(), but creates object using default constructor. 88 | Pointer getDefault() { 89 | return get([] { return new T; }); 90 | } 91 | 92 | /// Like getMany(), but creates objects using default constructor.
93 | std::vector getDefaultMany(size_t count) { 94 | return getMany(count, [] { return new T; }); 95 | } 96 | }; 97 | -------------------------------------------------------------------------------- /core/memory/kv_cache_buffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "fixed_size_allocator.h" 4 | #include "base/noncopyable.h" 5 | #include "common/pytorch.h" 6 | 7 | template , 8 | typename Deleter = std::default_delete> 9 | class KVCacheBuffer : public base::noncopyable { 10 | using ThisAllocator = FixedSizeAllocator; 11 | 12 | public: 13 | explicit KVCacheBuffer(size_t context_length, std::vector key_shape, 14 | std::vector value_shape, int torch_dtype) 15 | : key_allocator_(nullptr), value_allocator_(nullptr) { 16 | size_t key_size = torch_shape_size(key_shape, torch_dtype); 17 | size_t value_size = torch_shape_size(value_shape, torch_dtype); 18 | 19 | key_allocator_ = std::make_unique(context_length, key_size); 20 | value_allocator_ = 21 | std::make_unique(context_length, value_size); 22 | } 23 | 24 | std::tuple get_slot(int layer_id, int microbatch_id) { 25 | uint64_t key = (static_cast(layer_id) << 32) | microbatch_id; 26 | std::unique_lock lock(mutex_); 27 | if (kv_buffer_map_.find(key) == kv_buffer_map_.end()) { 28 | void* key_ptr = key_allocator_->get_slot(); 29 | void* value_ptr = value_allocator_->get_slot(); 30 | 31 | if (key_ptr == nullptr || value_ptr == nullptr) { 32 | DLOG_FATAL << "No empty slot in KVCacheBuffer, key_ptr: " << key_ptr 33 | << ", value_ptr: " << value_ptr; 34 | return std::make_tuple(nullptr, nullptr); 35 | } 36 | 37 | kv_buffer_map_[key] = std::make_tuple(key_ptr, value_ptr); 38 | } 39 | kv_in_buffer_[key] = false; 40 | return kv_buffer_map_[key]; 41 | } 42 | 43 | void set_slot(int layer_id, int microbatch_id) { 44 | uint64_t key = (static_cast(layer_id) << 32) | microbatch_id; 45 | { 46 | std::lock_guard lock(mutex_); 47 | kv_in_buffer_[key] = true; 48 | } 
49 | cv_.notify_all(); 50 | } 51 | 52 | void wait_slot(int layer_id, int microbatch_id) { 53 | uint64_t key = (static_cast(layer_id) << 32) | microbatch_id; 54 | std::unique_lock lock(mutex_); 55 | cv_.wait(lock, [this, key] { 56 | return kv_in_buffer_.find(key) != kv_in_buffer_.end() && 57 | kv_in_buffer_[key]; 58 | }); 59 | } 60 | 61 | void release_slot(int layer_id, int microbatch_id) { 62 | uint64_t key = (static_cast(layer_id) << 32) | microbatch_id; 63 | std::lock_guard lock(mutex_); 64 | kv_buffer_map_.erase(key); 65 | kv_in_buffer_[key] = false; 66 | } 67 | 68 | private: 69 | std::unordered_map> kv_buffer_map_; 70 | std::unordered_map kv_in_buffer_; 71 | std::mutex mutex_; 72 | std::condition_variable cv_; 73 | 74 | std::unique_ptr key_allocator_; 75 | std::unique_ptr value_allocator_; 76 | }; 77 | -------------------------------------------------------------------------------- /moe_infinity/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rotate_half(x): 5 | """Rotates half the hidden dims of the input.""" 6 | x1 = x[..., : x.shape[-1] // 2] 7 | x2 = x[..., x.shape[-1] // 2 :] 8 | return torch.cat((-x2, x1), dim=-1) 9 | 10 | 11 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 12 | device = position_ids.device 13 | position_ids = position_ids.to(cos.device) 14 | cos = cos[position_ids].unsqueeze(unsqueeze_dim).to(q.device) 15 | sin = sin[position_ids].unsqueeze(unsqueeze_dim).to(q.device) 16 | # print("cos.shape", cos.device, "sin.shape", sin.device, "q.shape", q.device, "k.shape", k.device) 17 | q_embed = (q * cos) + (rotate_half(q) * sin) 18 | k_embed = (k * cos) + (rotate_half(k) * sin) 19 | position_ids = position_ids.to(device) 20 | return q_embed, k_embed 21 | 22 | 23 | def apply_rotary_pos_emb_deepseek( 24 | q, k, cos, sin, position_ids, unsqueeze_dim=1 25 | ): 26 | """Applies Rotary Position Embedding to the query and key tensors. 
27 | 28 | Args: 29 | q (`torch.Tensor`): The query tensor. 30 | k (`torch.Tensor`): The key tensor. 31 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 32 | sin (`torch.Tensor`): The sine part of the rotary embedding. 33 | position_ids (`torch.Tensor`): 34 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 35 | used to pass offsetted position ids when working with a KV-cache. 36 | unsqueeze_dim (`int`, *optional*, defaults to 1): 37 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 38 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 39 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 40 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 41 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 42 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 43 | Returns: 44 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
45 | """ 46 | device = position_ids.device 47 | position_ids = position_ids.to(cos.device) 48 | cos = cos[position_ids].unsqueeze(unsqueeze_dim).to(q.device) 49 | sin = sin[position_ids].unsqueeze(unsqueeze_dim).to(q.device) 50 | 51 | b, h, s, d = q.shape 52 | q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 53 | 54 | b, h, s, d = k.shape 55 | k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 56 | 57 | q_embed = (q * cos) + (rotate_half(q) * sin) 58 | k_embed = (k * cos) + (rotate_half(k) * sin) 59 | position_ids = position_ids.to(device) 60 | return q_embed, k_embed 61 | -------------------------------------------------------------------------------- /moe_infinity/distributed/expert_prefetcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | import torch.distributed as dist 7 | from torch.distributed import rpc 8 | from transformers import PretrainedConfig 9 | 10 | from moe_infinity.utils import parse_moe_param 11 | 12 | 13 | def _call_expert_prefetcher(method, *args, **kwargs): 14 | global _expert_prefetcher 15 | func = getattr(_expert_prefetcher, method) 16 | return func(*args, **kwargs) 17 | 18 | 19 | class DistributedExpertPrefetcher(object): 20 | cache_file_rd = None 21 | 22 | def __init__(self, config: PretrainedConfig): 23 | print(config) 24 | self.num_layers, self.num_experts, self.num_encoder_layers = ( 25 | parse_moe_param(config) 26 | ) 27 | 28 | def set_archer_engine(self, archer_engine): 29 | global _expert_prefetcher 30 | _expert_prefetcher = archer_engine 31 | self.archer_engine = archer_engine 32 | 33 | def set_device_map_manager(self, device_map_manager): 34 | self.device_map_manager = device_map_manager 35 | 36 | def set_archer_prefetch(self, archer_prefetch): 37 | self.archer_prefetch = archer_prefetch 38 | 39 | def prefetch_experts(self, layer_id, expert_matrix): 
40 | expert_list = [] 41 | # print("expert_tensor_map", self.expert_tensor_map) 42 | for i in range(layer_id, self.num_layers): 43 | for j in range(self.num_experts): 44 | if expert_matrix[i, j] > 0: 45 | expert_list.append( 46 | (self.expert_tensor_map[(i, j)], expert_matrix[i, j]) 47 | ) 48 | ordered_expert_list = sorted( 49 | expert_list, key=lambda x: x[1], reverse=True 50 | ) 51 | tensor_ids = [x[0] for x in ordered_expert_list] 52 | device_list = self.device_map_manager.get_target_device(tensor_ids) 53 | 54 | if len(tensor_ids) > 0: 55 | self._replace_cache_candidates(tensor_ids) 56 | for meta in device_list: 57 | rank, gpu_id, tensor_id = meta 58 | if rank == dist.get_rank(): 59 | self.archer_engine.enqueue_prefetch(tensor_id, gpu_id) 60 | else: 61 | rpc.rpc_async( 62 | f"worker_{rank}", 63 | _call_expert_prefetcher, 64 | args=("enqueue_prefetch", tensor_id, gpu_id), 65 | ) 66 | 67 | def _replace_cache_candidates(self, tensor_ids): 68 | futures = [] 69 | for k in range(dist.get_world_size()): 70 | if k == dist.get_rank(): 71 | self.archer_engine.replace_cache_candidates(tensor_ids) 72 | else: 73 | future = rpc.rpc_async( 74 | f"worker_{k}", 75 | _call_expert_prefetcher, 76 | args=("replace_cache_candidates", tensor_ids), 77 | ) 78 | futures.append(future) 79 | -------------------------------------------------------------------------------- /core/utils/logger.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include "base/logging.h" 9 | 10 | inline void print(base::LogStream& stream) {} 11 | 12 | template 13 | inline void print(base::LogStream& stream, T first, Args... args) { 14 | stream << first; 15 | if constexpr (sizeof...(args) > 0) { 16 | stream << " "; 17 | print(stream, args...); // Recursive call 18 | } 19 | } 20 | 21 | #define DLOG_TRACE(...) 
\ 22 | do { \ 23 | if (base::Logger::logLevel() <= base::Logger::TRACE) \ 24 | print(base::Logger(__FILE__, __LINE__, base::Logger::TRACE, __func__) \ 25 | .stream(), \ 26 | __VA_ARGS__); \ 27 | } while (0) 28 | 29 | #define DLOG_DEBUG(...) \ 30 | do { \ 31 | if (base::Logger::logLevel() <= base::Logger::DEBUG) \ 32 | print(base::Logger(__FILE__, __LINE__, base::Logger::DEBUG, __func__) \ 33 | .stream(), \ 34 | __VA_ARGS__); \ 35 | } while (0) 36 | 37 | #define DLOG_INFO(...) \ 38 | do { \ 39 | if (base::Logger::logLevel() <= base::Logger::INFO) \ 40 | print(base::Logger(__FILE__, __LINE__).stream(), __VA_ARGS__); \ 41 | } while (0) 42 | 43 | #define DLOG_ERROR(...) \ 44 | do { \ 45 | if (base::Logger::logLevel() <= base::Logger::ERROR) \ 46 | print(base::Logger(__FILE__, __LINE__, base::Logger::ERROR).stream(), \ 47 | __VA_ARGS__); \ 48 | } while (0); 49 | 50 | #define DLOG_WARN(...) \ 51 | do { \ 52 | if (base::Logger::logLevel() <= base::Logger::WARN) \ 53 | print(base::Logger(__FILE__, __LINE__, base::Logger::WARN).stream(), \ 54 | __VA_ARGS__); \ 55 | } while (0) 56 | 57 | #define DLOG_FATAL(...) 
\ 58 | do { \ 59 | if (base::Logger::logLevel() <= base::Logger::FATAL) \ 60 | print(base::Logger(__FILE__, __LINE__, base::Logger::FATAL).stream(), \ 61 | __VA_ARGS__); \ 62 | } while (0) 63 | -------------------------------------------------------------------------------- /core/memory/fixed_size_allocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "base/noncopyable.h" 9 | #include "utils/logger.h" 10 | #include "utils/cuda_utils.h" 11 | #include "common/types.h" 12 | 13 | template , 14 | typename Deleter = std::default_delete> 15 | class FixedSizeAllocator : public base::noncopyable { 16 | public: 17 | // Constructor - carve num_slots slots of slot_size bytes out of chunks 18 | // obtained from the custom allocator. 19 | explicit FixedSizeAllocator(int num_slots, size_t slot_size) 20 | : num_slots_(num_slots), slot_size_(slot_size) { 21 | if (slot_size_ == 0) { 22 | DLOG_WARN("slot size is 0 in FixedSizeAllocator"); 23 | return; 24 | } 25 | 26 | // Chunk size is lcm(slot_size, 2MB): 2MB-aligned and an exact multiple of 27 | // slot_size, so no slot straddles two separately allocated chunks. This 28 | // must also hold when slot_size is itself a multiple of 2MB. 29 | chunk_size_ = 2 * 1024 * 1024; 30 | chunk_size_ = std::lcm(slot_size_, chunk_size_); 31 | 32 | // Round up so the trailing slots also get a (partially used) chunk. 33 | const size_t total_bytes = static_cast<size_t>(num_slots_) * slot_size_; 34 | const size_t num_chunks = (total_bytes + chunk_size_ - 1) / chunk_size_; 35 | for (size_t i = 0; i < num_chunks; ++i) { 36 | void* raw_ptr = allocator(chunk_size_); 37 | if (raw_ptr == nullptr) { 38 | DLOG_FATAL("Failed to allocate memory in FixedSizeAllocator"); 39 | return; // publish no slots rather than slots inside a missing chunk 40 | } 41 | std::unique_ptr ptr(nullptr); 42 | ptr.reset(reinterpret_cast(raw_ptr)); 43 | chunks_.push_back(std::move(ptr)); 44 | } 45 | 46 | for (int i = 0; i < num_slots_; ++i) { 47 | size_t j = i * slot_size_ / chunk_size_; 48 | void* raw_ptr = chunks_[j].get() + (i * slot_size_ % chunk_size_); 49 | slot_map_[reinterpret_cast(raw_ptr)] = false; 50 | } 51 | 52 | DLOG_INFO("FixedSizeAllocator created: num_slots=", num_slots, 53 | ", slot_size=", slot_size, ", chunk_size=", chunk_size_); 54 | } 55 |
56 | // Access underlying pointer 57 | // T* get() const { return ptr.get(); } 58 | // Claims the first free slot and returns it, or nullptr if all slots are in 59 | // use. Non-const because it marks the returned slot as taken in slot_map_. 60 | T* get_slot() { 61 | for (auto& pair : slot_map_) { 62 | if (!pair.second) { 63 | pair.second = true; 64 | return reinterpret_cast(pair.first); 65 | } 66 | } 67 | // DLOG_WARN("No empty slot in FixedSizeAllocator"); 68 | return nullptr; 69 | } 70 | // Marks a slot obtained from get_slot() as free again. nullptr is ignored; 71 | // a pointer this allocator never handed out is reported and not recorded. 72 | void release_slot(T* slot) { 73 | if (slot == nullptr) { 74 | // DLOG_WARN("Invalid slot in FixedSizeAllocator"); 75 | return; 76 | } 77 | if (slot_map_.find(slot) == slot_map_.end()) { 78 | DLOG_FATAL("Invalid slot in FixedSizeAllocator"); 79 | return; 80 | } 81 | slot_map_[reinterpret_cast(slot)] = false; 82 | } 83 | 84 | private: 85 | std::vector> chunks_; // The allocated memory 86 | Allocator allocator; // Custom allocator 87 | std::unordered_map slot_map_; 88 | int num_slots_; 89 | size_t slot_size_; 90 | size_t chunk_size_; 91 | }; 92 | 93 | typedef FixedSizeAllocator 94 | CUDADeviceFixedSizeAllocator; 95 | typedef FixedSizeAllocator 96 | CUDAHostFixedSizeAllocator; 97 | -------------------------------------------------------------------------------- /tests/queues/test_lockfree_queue.cpp: -------------------------------------------------------------------------------- 1 | #include "utils/lockfree_queue.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | TEST(LockFreeQueueTest, SingleThreadedPushPop) { 8 | LockFreeQueue queue; 9 | int value; 10 | 11 | int a = 1; 12 | queue.Push(a); 13 | ASSERT_TRUE(queue.Pop(value)); 14 | ASSERT_EQ(value, 1); 15 | } 16 | 17 | TEST(LockFreeQueueTest, SequentialPushParallelPop) { 18 | LockFreeQueue queue; 19 | int value; 20 | 21 | // Sequential push 22 | for (int i = 0; i < 10; ++i) { 23 | queue.Push(i); 24 | } 25 | 26 | // Parallel pop 27 | std::vector threads; 28 | std::vector results(10); 29 | for (int i = 0; i < 10; ++i) { 30 | threads.emplace_back([&queue, &results, i]() { 31 | int val; 32 | while (!queue.Pop(val)) { 33 | // Busy wait 34 | } 35 | results[i] = val; 36 | }); 37 | } 38 | for (auto& t :
threads) { 39 | t.join(); 40 | } 41 | 42 | // Verify results 43 | std::sort(results.begin(), results.end()); 44 | for (int i = 0; i < 10; ++i) { 45 | ASSERT_EQ(results[i], i); 46 | } 47 | } 48 | 49 | TEST(LockFreeQueueTest, ParallelPushSequentialPop) { 50 | LockFreeQueue queue; 51 | int value; 52 | 53 | // Parallel push 54 | std::vector threads; 55 | for (int i = 0; i < 10; ++i) { 56 | threads.emplace_back([&queue, i]() { 57 | int val = i; 58 | queue.Push(val); 59 | }); 60 | } 61 | for (auto& t : threads) { 62 | t.join(); 63 | } 64 | 65 | // Sequential pop 66 | std::vector results(10); 67 | for (int i = 0; i < 10; ++i) { 68 | ASSERT_TRUE(queue.Pop(value)); 69 | results[i] = value; 70 | } 71 | ASSERT_FALSE(queue.Pop(value)); // Queue should be empty 72 | 73 | // Verify results 74 | std::sort(results.begin(), results.end()); 75 | for (int i = 0; i < 10; ++i) { 76 | ASSERT_EQ(results[i], i); 77 | } 78 | } 79 | 80 | TEST(LockFreeQueueTest, ParallelPushParallelPop) { 81 | LockFreeQueue queue; 82 | int value; 83 | 84 | // Parallel push 85 | std::vector push_threads; 86 | for (int i = 0; i < 10; ++i) { 87 | push_threads.emplace_back([&queue, i]() { 88 | int val = i; 89 | queue.Push(val); 90 | }); 91 | } 92 | 93 | // Parallel pop 94 | std::vector pop_threads; 95 | std::vector results(10); 96 | for (int i = 0; i < 10; ++i) { 97 | pop_threads.emplace_back([&queue, &results, i]() { 98 | int val; 99 | while (!queue.Pop(val)) { 100 | // Busy wait 101 | } 102 | results[i] = val; 103 | }); 104 | } 105 | 106 | for (auto& t : push_threads) { 107 | t.join(); 108 | } 109 | for (auto& t : pop_threads) { 110 | t.join(); 111 | } 112 | 113 | // Verify results 114 | std::sort(results.begin(), results.end()); 115 | for (int i = 0; i < 10; ++i) { 116 | ASSERT_EQ(results[i], i); 117 | } 118 | } 119 | 120 | int main(int argc, char** argv) { 121 | ::testing::InitGoogleTest(&argc, argv); 122 | return RUN_ALL_TESTS(); 123 | } 124 | 
-------------------------------------------------------------------------------- /tests/queues/test_threadsafe_queue.cpp: -------------------------------------------------------------------------------- 1 | #include "utils/threadsafe_queue.h" 2 | #include 3 | #include 4 | #include 5 | 6 | TEST(ThreadSafeQueueTest, SingleThreadedPushPop) { 7 | ThreadSafeQueue queue; 8 | int value; 9 | 10 | int a = 1; 11 | queue.Push(a); 12 | ASSERT_TRUE(queue.Pop(value)); 13 | ASSERT_EQ(value, 1); 14 | } 15 | 16 | TEST(ThreadSafeQueueTest, SequentialPushParallelPop) { 17 | ThreadSafeQueue queue; 18 | int value; 19 | 20 | // Sequential push 21 | for (int i = 0; i < 10; ++i) { 22 | queue.Push(i); 23 | } 24 | 25 | // Parallel pop 26 | std::vector threads; 27 | std::vector results(10); 28 | for (int i = 0; i < 10; ++i) { 29 | threads.emplace_back([&queue, &results, i]() { 30 | int val; 31 | while (!queue.Pop(val)) { 32 | // Busy wait 33 | } 34 | results[i] = val; 35 | }); 36 | } 37 | for (auto& t : threads) { 38 | t.join(); 39 | } 40 | 41 | // Verify results 42 | std::sort(results.begin(), results.end()); 43 | for (int i = 0; i < 10; ++i) { 44 | ASSERT_EQ(results[i], i); 45 | } 46 | } 47 | 48 | TEST(ThreadSafeQueueTest, ParallelPushSequentialPop) { 49 | ThreadSafeQueue queue; 50 | int value; 51 | 52 | // Parallel push 53 | std::vector threads; 54 | for (int i = 0; i < 10; ++i) { 55 | threads.emplace_back([&queue, i]() { 56 | int val = i; 57 | queue.Push(val); 58 | }); 59 | } 60 | for (auto& t : threads) { 61 | t.join(); 62 | } 63 | 64 | // Sequential pop 65 | std::vector results(10); 66 | for (int i = 0; i < 10; ++i) { 67 | ASSERT_TRUE(queue.Pop(value)); 68 | results[i] = value; 69 | } 70 | ASSERT_FALSE(queue.TryPop(value)); // Queue should be empty 71 | 72 | // Verify results 73 | std::sort(results.begin(), results.end()); 74 | for (int i = 0; i < 10; ++i) { 75 | ASSERT_EQ(results[i], i); 76 | } 77 | } 78 | 79 | TEST(ThreadSafeQueueTest, ParallelPushParallelPop) { 80 | ThreadSafeQueue 
queue; 81 | int value; 82 | 83 | // Parallel push 84 | std::vector push_threads; 85 | for (int i = 0; i < 10; ++i) { 86 | push_threads.emplace_back([&queue, i]() { 87 | int val = i; 88 | queue.Push(val); 89 | }); 90 | } 91 | 92 | // Parallel pop 93 | std::vector pop_threads; 94 | std::vector results(10); 95 | for (int i = 0; i < 10; ++i) { 96 | pop_threads.emplace_back([&queue, &results, i]() { 97 | int val; 98 | while (!queue.Pop(val)) { 99 | // Busy wait 100 | } 101 | results[i] = val; 102 | }); 103 | } 104 | 105 | for (auto& t : push_threads) { 106 | t.join(); 107 | } 108 | for (auto& t : pop_threads) { 109 | t.join(); 110 | } 111 | 112 | // Verify results 113 | std::sort(results.begin(), results.end()); 114 | for (int i = 0; i < 10; ++i) { 115 | ASSERT_EQ(results[i], i); 116 | } 117 | } 118 | 119 | int main(int argc, char** argv) { 120 | ::testing::InitGoogleTest(&argc, argv); 121 | return RUN_ALL_TESTS(); 122 | } 123 | -------------------------------------------------------------------------------- /op_builder/prefetch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Microsoft Corporation. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | 7 | # op_builder/async_io.py 8 | # 9 | # Part of the DeepSpeed Project, under the Apache-2.0 License. 10 | # See https://github.com/microsoft/DeepSpeed/blob/master/LICENSE for license information. 
11 | # SPDX-License-Identifier: Apache-2.0 12 | 13 | # MoE-Infinity: replaced AsyncIOBuilder with PrefetchBuilder 14 | 15 | import glob 16 | import os 17 | 18 | from .builder import OpBuilder 19 | 20 | 21 | class PrefetchBuilder(OpBuilder): 22 | BUILD_VAR = "MOE_BUILD_PREFETCH" 23 | NAME = "prefetch" 24 | 25 | def __init__(self): 26 | super().__init__(name=self.NAME) 27 | 28 | def absolute_name(self): 29 | return f"moe_infinity.ops.prefetch.{self.NAME}_op" 30 | 31 | def sources(self): 32 | return [ 33 | "core/utils/logger.cpp", 34 | "core/utils/cuda_utils.cpp", 35 | "core/model/model_topology.cpp", 36 | "core/prefetch/archer_prefetch_handle.cpp", 37 | "core/prefetch/task_scheduler.cpp", 38 | "core/prefetch/task_thread.cpp", 39 | "core/memory/memory_pool.cpp", 40 | "core/memory/stream_pool.cpp", 41 | "core/memory/host_caching_allocator.cpp", 42 | "core/memory/device_caching_allocator.cpp", 43 | "core/python/py_archer_prefetch.cpp", 44 | "core/parallel/expert_dispatcher.cpp", 45 | "core/parallel/expert_module.cpp", 46 | "core/aio/archer_aio_thread.cpp", 47 | "core/aio/archer_prio_aio_handle.cpp", 48 | "core/aio/archer_aio_utils.cpp", 49 | "core/aio/archer_aio_threadpool.cpp", 50 | "core/aio/archer_tensor_handle.cpp", 51 | "core/aio/archer_tensor_index.cpp", 52 | "core/base/thread.cc", 53 | "core/base/exception.cc", 54 | "core/base/date.cc", 55 | "core/base/process_info.cc", 56 | "core/base/logging.cc", 57 | "core/base/log_file.cc", 58 | "core/base/timestamp.cc", 59 | "core/base/file_util.cc", 60 | "core/base/countdown_latch.cc", 61 | "core/base/timezone.cc", 62 | "core/base/log_stream.cc", 63 | "core/base/thread_pool.cc", 64 | ] 65 | 66 | def include_paths(self): 67 | return ["core"] 68 | 69 | def cxx_args(self): 70 | # -O0 for improved debugging, since performance is bound by I/O 71 | CPU_ARCH = self.cpu_arch() 72 | SIMD_WIDTH = self.simd_width() 73 | return [ 74 | "-g", 75 | "-Wall", 76 | "-O3", 77 | "-std=c++17", 78 | "-shared", 79 | "-fPIC", 80 | "-Wno-reorder", 
81 | CPU_ARCH, 82 | "-fopenmp", 83 | SIMD_WIDTH, 84 | "-I/usr/local/cuda/include", 85 | "-L/usr/local/cuda/lib64", 86 | "-lcuda", 87 | "-lcudart", 88 | "-lcublas", 89 | "-lpthread", 90 | ] 91 | 92 | def extra_ldflags(self): 93 | return [] 94 | 95 | def is_compatible(self, verbose=True): 96 | return super().is_compatible(verbose) 97 | -------------------------------------------------------------------------------- /.github/workflows/publish-test.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Test PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | paths-ignore: 9 | - '**.md' 10 | - 'examples/**' 11 | - 'tests/**' 12 | - 'docs/**' 13 | 14 | permissions: 15 | contents: write 16 | 17 | jobs: 18 | setup-version: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Generate version number 22 | run: | 23 | VERSION_HASH=$(date +"%Y%m%d%H%M%S") 24 | echo "Generated version hash: $VERSION_HASH" 25 | echo $VERSION_HASH > version.txt 26 | 27 | - name: Upload version number as artifact 28 | uses: actions/upload-artifact@v4 29 | with: 30 | name: version 31 | path: version.txt 32 | 33 | wheel: 34 | name: Build Wheel 35 | runs-on: ${{ matrix.os }} 36 | permissions: write-all 37 | strategy: 38 | fail-fast: false 39 | matrix: 40 | os: ['ubuntu-20.04'] 41 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] 42 | cuda-version: ['12.1'] 43 | 44 | steps: 45 | - name: Checkout Source Code 46 | uses: actions/checkout@v3 47 | 48 | - name: Download version value artifact 49 | uses: actions/download-artifact@v4 50 | with: 51 | name: version 52 | path: artifact 53 | 54 | - name: Free disk space 55 | run: | 56 | rm -rf /usr/local/cuda-* /opt/cuda 57 | rm -rf /usr/local/cuda 58 | bash -x .github/workflows/scripts/free-disk-space.sh 59 | 60 | - name: Set up Python ${{ matrix.python-version }} 61 | uses: actions/setup-python@v4 62 | with: 63 | python-version: ${{ matrix.python-version }} 64 | 65 | - name: Install Dependencies 
66 | run: | 67 | python3 -m pip install --upgrade pip 68 | python3 -m pip install build 69 | 70 | - name: Install CUDA ${{ matrix.cuda-version }} 71 | run: | 72 | bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} 73 | 74 | - name: Build Wheel 75 | shell: bash 76 | run: | 77 | VERSION_HASH=$(cat artifact/version.txt) 78 | MOEINF_VERSION=0.0.1dev${VERSION_HASH} BUILD_OPS=1 python3 -m build --wheel 79 | wheel_name=$(ls dist/*whl | xargs -n 1 basename) 80 | asset_name=${wheel_name//"linux"/"manylinux1"} 81 | echo "wheel_name=${wheel_name}" >> $GITHUB_ENV 82 | echo "asset_name=${asset_name}" >> $GITHUB_ENV 83 | 84 | - name: Build Source 85 | if: ${{ matrix.python-version == '3.8' }} 86 | run: | 87 | VERSION_HASH=$(cat artifact/version.txt) 88 | MOEINF_VERSION=0.0.1dev${VERSION_HASH} python3 -m build --sdist 89 | 90 | - name: Rename Wheel 91 | run: | 92 | mv dist/${{ env.wheel_name }} dist/${{ env.asset_name }} 93 | 94 | - name: Publish Package to Test PyPI 95 | uses: pypa/gh-action-pypi-publish@release/v1.8 96 | with: 97 | repository-url: https://test.pypi.org/legacy/ 98 | skip-existing: true 99 | env: 100 | TWINE_USERNAME: ${{ secrets.TEST_PYPI_USERNAME }} 101 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_PASSWORD }} 102 | -------------------------------------------------------------------------------- /core/base/thread_pool.cc: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a BSD-style license 2 | // that can be found in the License file. 
3 | // 4 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 5 | 6 | #include "thread_pool.h" 7 | 8 | #include "exception.h" 9 | 10 | #include 11 | #include 12 | 13 | using namespace base; 14 | 15 | ThreadPool::ThreadPool(const std::string& nameArg) 16 | : mutex_(), 17 | notEmpty_(), 18 | notFull_(), 19 | name_(nameArg), 20 | maxQueueSize_(0), 21 | running_(false) {} 22 | 23 | ThreadPool::~ThreadPool() { 24 | if (running_) { 25 | stop(); 26 | } 27 | } 28 | 29 | void ThreadPool::start(int numThreads) { 30 | assert(threads_.empty()); 31 | running_ = true; 32 | threads_.reserve(numThreads); 33 | for (int i = 0; i < numThreads; ++i) { 34 | char id[32]; 35 | snprintf(id, sizeof id, "%d", i + 1); 36 | threads_.emplace_back(new base::Thread( 37 | std::bind(&ThreadPool::runInThread, this), name_ + id)); 38 | threads_[i]->start(); 39 | } 40 | if (numThreads == 0 && threadInitCallback_) { 41 | threadInitCallback_(); 42 | } 43 | } 44 | 45 | void ThreadPool::stop() { 46 | { 47 | std::lock_guard lock(mutex_); 48 | running_ = false; 49 | } 50 | notEmpty_.notify_all(); notFull_.notify_all(); // also wake producers blocked in run() on a full queue 51 | for (auto& thr : threads_) { 52 | thr->join(); 53 | } 54 | } 55 | 56 | size_t ThreadPool::queueSize() const { 57 | std::lock_guard lock(mutex_); 58 | return queue_.size(); 59 | } 60 | 61 | void ThreadPool::run(Task task) { 62 | if (threads_.empty()) { 63 | task(); 64 | } else { 65 | std::unique_lock lock(mutex_); 66 | notFull_.wait(lock, [this] { return !isFull() || !running_; }); 67 | if (!running_) return; // pool stopped while waiting for room: no worker will ever take the task 68 | assert(!isFull()); 69 | queue_.push_back(std::move(task)); 70 | notEmpty_.notify_one(); 71 | } 72 | } 73 | 74 | ThreadPool::Task ThreadPool::take() { 75 | std::unique_lock lock(mutex_); 76 | // the predicate overload of wait() re-checks the condition after spurious wakeups 77 | notEmpty_.wait(lock, [this] { return !queue_.empty() || !running_; }); 78 | Task task; 79 | if (!queue_.empty()) { 80 | task = std::move(queue_.front()); 81 | queue_.pop_front(); 82 | if (maxQueueSize_ > 0) { 83 | notFull_.notify_one(); 84 | } 85 | } 86 | return task; 87 | } 88 | 89 | bool
ThreadPool::isFull() const { 90 | // mutex_.assertLocked(); FIXME: assertLocked() is not a member of std::mutex 91 | return maxQueueSize_ > 0 && queue_.size() >= maxQueueSize_; 92 | } 93 | 94 | void ThreadPool::runInThread() { 95 | try { 96 | if (threadInitCallback_) { 97 | threadInitCallback_(); 98 | } 99 | while (running_) { 100 | Task task(take()); 101 | if (task) { 102 | task(); 103 | } 104 | } 105 | } catch (const Exception& ex) { 106 | fprintf(stderr, "exception caught in ThreadPool %s\n", name_.c_str()); 107 | fprintf(stderr, "reason: %s\n", ex.what()); 108 | fprintf(stderr, "stack trace: %s\n", ex.stackTrace()); 109 | abort(); 110 | } catch (const std::exception& ex) { 111 | fprintf(stderr, "exception caught in ThreadPool %s\n", name_.c_str()); 112 | fprintf(stderr, "reason: %s\n", ex.what()); 113 | abort(); 114 | } catch (...) { 115 | fprintf(stderr, "unknown exception caught in ThreadPool %s\n", 116 | name_.c_str()); 117 | throw; // rethrow 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Create Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | release: 13 | name: Create Release 14 | runs-on: ubuntu-20.04 15 | outputs: 16 | upload_url: ${{ steps.create_release.outputs.upload_url }} 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v3 20 | 21 | - name: Extract branch info 22 | shell: bash 23 | run: | 24 | echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 25 | 26 | - name: Create Release 27 | id: create_release 28 | uses: actions/github-script@v6 29 | env: 30 | RELEASE_TAG: ${{ env.release_tag }} 31 | with: 32 | github-token: "${{ secrets.GITHUB_TOKEN }}" 33 | script: | 34 | const script = require('.github/workflows/scripts/create-release.js') 35 | await script(github, context, core) 36 | 37 | wheel: 38 | 
name: Build Wheel 39 | runs-on: ${{ matrix.os }} 40 | needs: release 41 | strategy: 42 | fail-fast: false 43 | matrix: 44 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] 45 | cuda-version: ['12.1'] 46 | os: ['ubuntu-20.04'] 47 | steps: 48 | - name: Checkout Source Code 49 | uses: actions/checkout@v3 50 | 51 | - name: Free Disk Space 52 | run: | 53 | rm -rf /usr/local/cuda-* /opt/cuda 54 | rm -rf /usr/local/cuda 55 | bash -x .github/workflows/scripts/free-disk-space.sh 56 | 57 | - name: Set up Python ${{ matrix.python-version }} 58 | uses: actions/setup-python@v4 59 | with: 60 | python-version: ${{ matrix.python-version }} 61 | 62 | - name: Install Dependencies 63 | run: | 64 | python3 -m pip install --upgrade pip 65 | python3 -m pip install build 66 | 67 | - name: Install CUDA ${{ matrix.cuda-version }} 68 | run: | 69 | bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} 70 | 71 | - name: Build Wheel 72 | shell: bash 73 | run: | 74 | BUILD_OPS=1 python3 -m build --wheel 75 | wheel_name=$(ls dist/*whl | xargs -n 1 basename) 76 | asset_name=${wheel_name//"linux"/"manylinux1"} 77 | echo "wheel_name=${wheel_name}" >> $GITHUB_ENV 78 | echo "asset_name=${asset_name}" >> $GITHUB_ENV 79 | 80 | - name: Build Source 81 | if: ${{ matrix.python-version == '3.8' }} 82 | run: | 83 | python3 -m build --sdist 84 | 85 | - name: Rename Wheel 86 | run: | 87 | mv dist/${{ env.wheel_name }} dist/${{ env.asset_name }} 88 | 89 | - name: Upload Release Asset 90 | uses: actions/upload-release-asset@v1 91 | env: 92 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 93 | with: 94 | upload_url: ${{ needs.release.outputs.upload_url }} 95 | asset_path: ./dist/${{ env.asset_name }} 96 | asset_name: ${{ env.asset_name }} 97 | asset_content_type: application/* 98 | 99 | - name: Publish Package to PyPI 100 | uses: pypa/gh-action-pypi-publish@release/v1.8 101 | with: 102 | skip-existing: true 103 | env: 104 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 105 | TWINE_PASSWORD: ${{
secrets.PYPI_PASSWORD }} 106 | -------------------------------------------------------------------------------- /core/base/timestamp.h: -------------------------------------------------------------------------------- 1 | #ifndef MUDUO_BASE_TIMESTAMP_H 2 | #define MUDUO_BASE_TIMESTAMP_H 3 | 4 | #include 5 | 6 | #include "copyable.h" 7 | #include "types.h" 8 | 9 | // #include 10 | 11 | namespace base { 12 | 13 | /// 14 | /// Time stamp in UTC, in microseconds resolution. 15 | /// 16 | /// This class is immutable. 17 | /// It's recommended to pass it by value, since it's passed in register on x64. 18 | /// 19 | class Timestamp : public copyable { 20 | public: 21 | /// 22 | /// Constructs an invalid Timestamp. 23 | /// 24 | Timestamp() : microSecondsSinceEpoch_(0) {} 25 | 26 | /// 27 | /// Constructs a Timestamp at specific time 28 | /// 29 | /// @param microSecondsSinceEpoch 30 | explicit Timestamp(int64_t microSecondsSinceEpochArg) 31 | : microSecondsSinceEpoch_(microSecondsSinceEpochArg) {} 32 | 33 | void swap(Timestamp& that) { 34 | std::swap(microSecondsSinceEpoch_, that.microSecondsSinceEpoch_); 35 | } 36 | 37 | // default copy/assignment/dtor are Okay 38 | 39 | std::string toString() const; 40 | std::string toFormattedString(bool showMicroseconds = true) const; 41 | 42 | bool valid() const { return microSecondsSinceEpoch_ > 0; } 43 | 44 | // for internal usage. 45 | int64_t microSecondsSinceEpoch() const { return microSecondsSinceEpoch_; } 46 | time_t secondsSinceEpoch() const { 47 | return static_cast(microSecondsSinceEpoch_ / 48 | kMicroSecondsPerSecond); 49 | } 50 | 51 | /// 52 | /// Get time of now. 
53 | /// 54 | static Timestamp now(); 55 | static Timestamp invalid() { return Timestamp(); } 56 | 57 | static Timestamp fromUnixTime(time_t t) { return fromUnixTime(t, 0); } 58 | 59 | static Timestamp fromUnixTime(time_t t, int microseconds) { 60 | return Timestamp(static_cast(t) * kMicroSecondsPerSecond + 61 | microseconds); 62 | } 63 | 64 | static const int kMicroSecondsPerSecond = 1000 * 1000; 65 | 66 | private: 67 | int64_t microSecondsSinceEpoch_; 68 | }; 69 | 70 | inline bool operator<(Timestamp lhs, Timestamp rhs) { 71 | return lhs.microSecondsSinceEpoch() < rhs.microSecondsSinceEpoch(); 72 | } 73 | 74 | inline bool operator==(Timestamp lhs, Timestamp rhs) { 75 | return lhs.microSecondsSinceEpoch() == rhs.microSecondsSinceEpoch(); 76 | } 77 | 78 | /// 79 | /// Gets time difference of two timestamps, result in seconds. 80 | /// 81 | /// @param high, low 82 | /// @return (high-low) in seconds 83 | /// @c double has 52-bit precision, enough for one-microsecond 84 | /// resolution for next 100 years. 85 | inline double timeDifference(Timestamp high, Timestamp low) { 86 | int64_t diff = high.microSecondsSinceEpoch() - low.microSecondsSinceEpoch(); 87 | return static_cast(diff) / Timestamp::kMicroSecondsPerSecond; 88 | } 89 | 90 | inline int64_t timeDifferenceUs(Timestamp high, Timestamp low) { 91 | int64_t diff = high.microSecondsSinceEpoch() - low.microSecondsSinceEpoch(); 92 | return diff; 93 | } 94 | 95 | /// 96 | /// Add @c seconds to given timestamp. 
97 | /// 98 | /// @return timestamp+seconds as Timestamp 99 | /// 100 | inline Timestamp addTime(Timestamp timestamp, double seconds) { 101 | int64_t delta = 102 | static_cast(seconds * Timestamp::kMicroSecondsPerSecond); 103 | return Timestamp(timestamp.microSecondsSinceEpoch() + delta); 104 | } 105 | 106 | } // namespace base 107 | #endif // MUDUO_BASE_TIMESTAMP_H 108 | -------------------------------------------------------------------------------- /moe_infinity/utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import os 17 | from typing import Union 18 | 19 | from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME 20 | 21 | 22 | def get_checkpoint_paths(checkpoint: Union[str, os.PathLike]): 23 | """ 24 | Returns the paths to the checkpoint files in a given folder. 25 | 26 | Args: 27 | checkpoint_folder (`str` or `os.PathLike`): 28 | The folder where we will look for the checkpoint files. 
29 | """ 30 | checkpoint_files = None 31 | index_filename = None 32 | if os.path.isfile(checkpoint): 33 | if str(checkpoint).endswith(".json"): 34 | index_filename = checkpoint 35 | else: 36 | checkpoint_files = [checkpoint] 37 | elif os.path.isdir(checkpoint): 38 | # check if the whole state dict is present 39 | potential_state_bin = [ 40 | f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME 41 | ] 42 | potential_state_safetensor = [ 43 | f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME 44 | ] 45 | if len(potential_state_bin) == 1: 46 | checkpoint_files = [ 47 | os.path.join(checkpoint, potential_state_bin[0]) 48 | ] 49 | elif len(potential_state_safetensor) == 1: 50 | checkpoint_files = [ 51 | os.path.join(checkpoint, potential_state_safetensor[0]) 52 | ] 53 | else: 54 | # otherwise check for sharded checkpoints 55 | potential_index = [ 56 | f for f in os.listdir(checkpoint) if f.endswith(".index.json") 57 | ] 58 | if len(potential_index) == 0: 59 | raise ValueError( 60 | f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file" 61 | ) 62 | elif len(potential_index) == 1: 63 | index_filename = os.path.join(checkpoint, potential_index[0]) 64 | else: 65 | raise ValueError( 66 | f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones." 67 | ) 68 | else: 69 | raise ValueError( 70 | "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded " 71 | f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}." 
72 | ) 73 | 74 | if index_filename is not None: 75 | checkpoint_folder = os.path.split(index_filename)[0] 76 | with open(index_filename, "r") as f: 77 | index = json.loads(f.read()) 78 | 79 | if "weight_map" in index: 80 | index = index["weight_map"] 81 | checkpoint_files = sorted(list(set(index.values()))) 82 | checkpoint_files = [ 83 | os.path.join(checkpoint_folder, f) for f in checkpoint_files 84 | ] 85 | 86 | return checkpoint_files 87 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | import io 7 | import os 8 | import sys 9 | 10 | from setuptools import find_packages, setup 11 | 12 | torch_available = True 13 | try: 14 | import torch # noqa: F401 15 | except ImportError: 16 | torch_available = False 17 | print( 18 | "[WARNING] Unable to import torch, pre-compiling ops will be disabled. " 19 | "Please visit https://pytorch.org/ to see how to properly install torch on your system." 
20 | ) 21 | 22 | ROOT_DIR = os.path.dirname(__file__) 23 | 24 | sys.path.insert(0, ROOT_DIR) 25 | # sys.path.insert(0, os.path.join(ROOT_DIR, 'src')) 26 | 27 | from torch.utils import cpp_extension 28 | 29 | from op_builder.all_ops import ALL_OPS 30 | 31 | RED_START = "\033[31m" 32 | RED_END = "\033[0m" 33 | ERROR = f"{RED_START} [ERROR] {RED_END}" 34 | 35 | 36 | def fetch_requirements(path): 37 | with open(path, "r") as fd: 38 | return [r.strip() for r in fd.readlines()] 39 | 40 | 41 | def get_path(*filepath) -> str: 42 | return os.path.join(ROOT_DIR, *filepath) 43 | 44 | 45 | def abort(msg): 46 | print(f"{ERROR} {msg}") 47 | assert False, msg 48 | 49 | 50 | def read_readme() -> str: 51 | """Read the README file if present.""" 52 | p = get_path("README.md") 53 | if os.path.isfile(p): 54 | return io.open(get_path("README.md"), "r", encoding="utf-8").read() 55 | else: 56 | return "" 57 | 58 | 59 | install_requires = fetch_requirements("requirements.txt") 60 | 61 | ext_modules = [] 62 | 63 | BUILD_OP_DEFAULT = int(os.environ.get("BUILD_OPS", 0)) 64 | 65 | if BUILD_OP_DEFAULT: 66 | assert torch_available, "Unable to pre-compile ops without torch installed. Please install torch before attempting to pre-compile ops." 
67 | compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) 68 | install_ops = dict.fromkeys(ALL_OPS.keys(), False) 69 | for op_name, builder in ALL_OPS.items(): 70 | if builder is not None: 71 | op_compatible = builder.is_compatible() 72 | compatible_ops[op_name] = op_compatible 73 | if not op_compatible: 74 | abort(f"Unable to pre-compile {op_name}") 75 | ext_modules.append(builder.builder()) 76 | 77 | cmdclass = { 78 | "build_ext": cpp_extension.BuildExtension.with_options(use_ninja=True) 79 | } 80 | 81 | print(f"find_packages: {find_packages()}") 82 | 83 | # install all files in the package, rather than just the egg 84 | setup( 85 | name="moe_infinity", 86 | version=os.getenv("MOEINF_VERSION", "0.0.1"), 87 | packages=find_packages( 88 | exclude=["op_builder", "op_builder.*", "moe_infinity.ops.core.*"] 89 | ), 90 | package_data={ 91 | "moe_infinity.ops.prefetch": ["**/*.so"], 92 | "moe_infinity": ["ops/core/**"], 93 | }, 94 | include_package_data=True, 95 | install_requires=install_requires, 96 | author="EfficientMoE Team", 97 | long_description=read_readme(), 98 | long_description_content_type="text/markdown", 99 | url="https://github.com/EfficientMoE/MoE-Infinity", 100 | project_urls={"Homepage": "https://github.com/EfficientMoE/MoE-Infinity"}, 101 | classifiers=[ 102 | "Programming Language :: Python :: 3.8", 103 | "Programming Language :: Python :: 3.9", 104 | "Programming Language :: Python :: 3.10", 105 | "Programming Language :: Python :: 3.11", 106 | "License :: OSI Approved :: Apache Software License", 107 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 108 | ], 109 | license="Apache License 2.0", 110 | python_requires=">=3.8", 111 | ext_modules=ext_modules, 112 | cmdclass=cmdclass, 113 | ) 114 | -------------------------------------------------------------------------------- /core/prefetch/task_scheduler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "base/noncopyable.h" 17 | #include "common/pytorch.h" 18 | #include "model/model_topology.h" 19 | 20 | #define SKIP_TO_NEXT_ITERATION \ 21 | std::this_thread::sleep_for(std::chrono::microseconds(10)); \ 22 | continue; 23 | 24 | #define NUM_PRIORITY 20UL 25 | 26 | struct Task { 27 | bool on_demand = false; 28 | NodePtr node; 29 | std::vector remove_nodes; 30 | std::uint32_t priority; 31 | std::uint64_t request_id; 32 | torch::Device src_device = DISK_DEVICE; 33 | torch::Device dst_device = DISK_DEVICE; 34 | cudaStream_t stream = nullptr; 35 | 36 | bool remove_layer = false; 37 | 38 | std::string DebugString() { 39 | std::stringstream ss; 40 | ss << "Task: node: " << node->str() << ", on_demand: " << on_demand 41 | << ", priority: " << priority << "[" << src_device.str() << "->" 42 | << dst_device.str() << "]"; 43 | return ss.str(); 44 | } 45 | }; 46 | typedef std::shared_ptr TaskPtr; 47 | 48 | class ArcherTaskPool : public base::noncopyable { 49 | public: 50 | void StartExec(const std::uint64_t& request_id, const NodePtr& node); 51 | void FetchExec(const std::uint64_t& request_id, const NodePtr& node); 52 | void StopExec(const std::uint64_t& request_id, const NodePtr& node); 53 | void EnqueueTask(const TaskPtr& task); 54 | 55 | void ClearQueue() { 56 | std::lock_guard lock(unified_mutex_); 57 | for (std::uint32_t priority = 1; priority < NUM_PRIORITY; priority++) { 58 | unified_queue_[priority].clear(); 59 | } 60 | } 61 | 62 | bool RemoveCachedSparseNode(const NodePtr& node, int device_id = -1); 63 | bool RemoveCachedDenseNode(const NodePtr& node); 64 | // void RemoveCachedNode(const NodePtr& node); 65 | 66 | void ReplaceCacheCandidates(const NodePtrList& candidates) { 67 | std::lock_guard lock(unified_mutex_); 68 | { 69 | std::lock_guard 
lock(this->candidates_mutex_); 70 | candidates_.clear(); 71 | for (auto& node : candidates) { 72 | candidates_.insert(node); 73 | } 74 | } 75 | 76 | for (std::uint32_t priority = 1; priority < NUM_PRIORITY; priority++) { 77 | unified_queue_[priority].clear(); 78 | } 79 | } 80 | 81 | DELETE_COPY_AND_ASSIGN(ArcherTaskPool); 82 | STATIC_GET_INSTANCE(ArcherTaskPool); 83 | 84 | ArcherTaskPool(); 85 | ~ArcherTaskPool() { 86 | std::cout << "ArcherTaskPool destructor" << std::endl; 87 | main_thread_stop_flag_.store(true); 88 | // wait for all threads to stop 89 | for (auto& thread_list : exec_threads_) { 90 | for (auto& thread : thread_list) { 91 | thread.join(); 92 | } 93 | } 94 | } 95 | 96 | private: 97 | void GPUThreadFunc(int gpu_id, int thread_id); 98 | 99 | void SetNodeDevice(const TaskPtr& task); 100 | 101 | std::string DebugString(const std::vector>& queue); 102 | 103 | private: 104 | std::vector> unified_queue_; // For ordered prefetch 105 | std::vector> gpu_min_priority_; 106 | std::unordered_map exec_queue_; 107 | std::mutex exec_mutex_; 108 | std::mutex unified_mutex_; 109 | std::mutex candidates_mutex_; 110 | 111 | std::vector> exec_threads_; 112 | 113 | std::unordered_set candidates_; 114 | 115 | std::atomic main_thread_stop_flag_; 116 | }; 117 | 118 | extern std::unique_ptr kTaskPool; 119 | -------------------------------------------------------------------------------- /core/aio/archer_aio_utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 
2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #include "archer_aio_utils.h" 7 | #include 8 | #include 9 | #include 10 | #include "utils/logger.h" 11 | 12 | const size_t kBlockSize = 1 * 1024 * 1024; 13 | const size_t kQueueDepth = std::thread::hardware_concurrency() / 14 | 4; // set to 1/4 total number of cores in the system 15 | 16 | int ArcherOpenFile(const char* filename) { 17 | const int flags = (O_RDWR | O_CREAT | O_DIRECT); 18 | const int mode = 0660; 19 | const auto fd = open(filename, flags, mode); 20 | if (fd < 0) { 21 | DLOG_FATAL("Failed to open file: ", filename); 22 | return -1; 23 | } 24 | return fd; 25 | } 26 | 27 | int ArcherCloseFile(const int fd) { 28 | const auto ret = close(fd); 29 | if (ret < 0) { 30 | DLOG_FATAL("Failed to close file: ", fd); 31 | return -1; 32 | } 33 | return 0; 34 | } 35 | 36 | int ArcherReadFileBatch(const int fd, void* buffer, const size_t num_bytes, 37 | const size_t offset) { 38 | std::vector> futures; 39 | const auto num_blocks = 40 | std::ceil(static_cast(num_bytes) / kBlockSize); 41 | // const auto num_threads = std::thread::hardware_concurrency(); 42 | 43 | for (auto i = 0; i < num_blocks; ++i) { 44 | const auto shift = i * kBlockSize; 45 | const auto xfer_buffer = (char*)buffer + shift; 46 | const auto xfer_offset = offset + shift; 47 | auto byte_count = kBlockSize; 48 | if ((shift + kBlockSize) > num_bytes) { 49 | byte_count = num_bytes - shift; 50 | } 51 | 52 | futures.emplace_back(std::async(std::launch::async, pread, fd, xfer_buffer, 53 | byte_count, xfer_offset)); 54 | } 55 | 56 | for (auto& future : futures) { 57 | const auto ret = future.get(); 58 | if (ret < 0) { 59 | DLOG_FATAL("Failed to read file: ", fd); 60 | return -1; 61 | } 62 | } 63 | 64 | return 0; 65 | } 66 | 67 | int ArcherWriteFileBatch(const int fd, const void* buffer, 68 | const size_t num_bytes, const size_t offset) { 69 | std::vector> futures; 70 | const auto num_blocks = 71 | std::ceil(static_cast(num_bytes) 
/ kBlockSize); 72 | 73 | for (auto i = 0; i < num_blocks; ++i) { 74 | const auto shift = i * kBlockSize; 75 | const auto xfer_buffer = (char*)buffer + shift; 76 | const auto xfer_offset = offset + shift; 77 | auto byte_count = kBlockSize; 78 | if ((shift + kBlockSize) > num_bytes) { 79 | byte_count = num_bytes - shift; 80 | } 81 | 82 | futures.emplace_back(std::async(std::launch::async, pwrite, fd, xfer_buffer, 83 | byte_count, xfer_offset)); 84 | } 85 | 86 | for (auto& future : futures) { 87 | const auto ret = future.get(); 88 | if (ret < 0) { 89 | DLOG_FATAL("Failed to write file: ", fd, ", errno: ", errno, 90 | ", msg: ", strerror(errno)); 91 | return -1; 92 | } 93 | } 94 | 95 | return 0; 96 | } 97 | 98 | int ArcherReadFile(int fd, void* buffer, const size_t num_bytes, 99 | const size_t offset) { 100 | const auto ret = pread(fd, buffer, num_bytes, offset); 101 | if (ret < 0) { 102 | DLOG_FATAL("Failed to read file: ", fd, ", errno: ", errno, 103 | ", msg: ", strerror(errno)); 104 | return -1; 105 | } 106 | 107 | return 0; 108 | } 109 | 110 | int ArcherWriteFile(int fd, const void* buffer, size_t num_bytes, 111 | size_t offset) { 112 | const auto ret = pwrite(fd, buffer, num_bytes, offset); 113 | if (ret < 0) { 114 | DLOG_FATAL("Failed to write file: ", fd, ", errno: ", errno, 115 | ", msg: ", strerror(errno)); 116 | return -1; 117 | } 118 | 119 | return 0; 120 | } 121 | -------------------------------------------------------------------------------- /core/memory/host_caching_allocator.cpp: -------------------------------------------------------------------------------- 1 | //===- c10/mobile/CPUCachingAllocator.cpp ----------------===// 2 | // 3 | // Part of the Pytorch Project, under the BSD 3-Clause License. 4 | // See https://github.com/pytorch/pytorch/blob/main/LICENSE for license 5 | // information. SPDX-License-Identifier: BSD 3-Clause 6 | 7 | // MoE-Infinity: modified from c10::CPUCachingAllocator. 
8 | // replaced c10::alloc_cpu with cudaHostAlloc 9 | 10 | #include "host_caching_allocator.h" 11 | #include 12 | #include 13 | 14 | namespace c10 { 15 | namespace HostCachingAllocator { 16 | 17 | std::mutex HostCachingAllocator::mutex_; 18 | ska::flat_hash_map HostCachingAllocator::allocation_map_; 19 | 20 | inline void* HostCachingAllocator::allocate_and_cache(const size_t bytes) { 21 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) 22 | void* ptr; 23 | auto cuda_err = cudaHostAlloc(&ptr, bytes, cudaHostAllocDefault); 24 | if (cuda_err != cudaSuccess) { 25 | free_cached(); 26 | cuda_err = cudaHostAlloc(&ptr, bytes, cudaHostAllocDefault); 27 | if (cuda_err != cudaSuccess) { 28 | throw std::runtime_error("cudaHostAlloc failed"); 29 | } 30 | } 31 | 32 | allocation_map_[ptr] = bytes; 33 | return ptr; 34 | } 35 | 36 | void* HostCachingAllocator::allocate(const size_t bytes) { 37 | std::lock_guard guard(mutex_); 38 | const auto& it = available_map_.find(bytes); 39 | if (it == available_map_.end() || it->second.empty()) { 40 | return allocate_and_cache(bytes); 41 | } 42 | return it->second.pop_back_val(); 43 | } 44 | 45 | void HostCachingAllocator::free(void* ptr) { 46 | // NB: since we are not really freeing the memory 47 | // the cases such as quantization code freeing original weights 48 | // on mobile, will not quite work, as we likely will hold 49 | // onto that memory. 50 | // NB: We can also enable max memory cached for better memory 51 | // management such that free will actually free the memory if 52 | // we are nearing or above the watermark. 
53 | std::lock_guard guard(mutex_); 54 | // If this allocation was done before caching allocator was enabled 55 | // then free regularly 56 | const auto& it = allocation_map_.find(ptr); 57 | if (it == allocation_map_.end()) { 58 | // c10::free_cpu(ptr); 59 | cudaFreeHost(ptr); 60 | return; 61 | } 62 | const size_t alloc_size = it->second; 63 | available_map_[alloc_size].push_back(ptr); 64 | } 65 | 66 | void HostCachingAllocator::record_free(void* ptr) { 67 | // This function captures the case when the allocated memory 68 | // is being freed outside the scope of this allocator. 69 | // At the moment only way to capture this is to have the allocator, 70 | // that uses this CachingAllocator as the backing allocator, 71 | // call this function explicitly upon freeing memory while 72 | // outside the scope of caching allocator. 73 | // If the memory is freed in some other way, then we will likely 74 | // have undefined behavior or page fault. But this can be 75 | // the case without caching allocator as well. 76 | std::lock_guard guard(mutex_); 77 | const auto& it = allocation_map_.find(ptr); 78 | if (it != allocation_map_.end()) { 79 | allocation_map_.erase(it); 80 | } 81 | } 82 | 83 | void HostCachingAllocator::free_cached() { 84 | for (const auto& it : available_map_) { 85 | for (const auto ptr : it.second) { 86 | // c10::free_cpu(ptr); 87 | cudaFreeHost(ptr); 88 | // When cached memory is return to OS, it must be removed 89 | // from allocation_map. 
90 | allocation_map_.erase(ptr); 91 | } 92 | } 93 | available_map_.clear(); 94 | } 95 | 96 | HostCachingAllocator::~HostCachingAllocator() { free_cached(); } 97 | 98 | HostCachingAllocator* caching_allocator = new HostCachingAllocator(); 99 | 100 | HostCachingAllocator* get() { 101 | if (caching_allocator == nullptr) { 102 | caching_allocator = new HostCachingAllocator(); 103 | } 104 | return caching_allocator; 105 | } 106 | 107 | } // namespace HostCachingAllocator 108 | 109 | } // namespace c10 110 | -------------------------------------------------------------------------------- /core/parallel/expert_dispatcher.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) EfficientMoE. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // EfficientMoE Team 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "base/noncopyable.h" 17 | #include "base/thread.h" 18 | #include "expert_module.h" 19 | 20 | enum MUTEX_TYPE { 21 | INPUT_MUTEX = 0, 22 | OUTPUT_MUTEX = 1, 23 | EXEC_MUTEX = 2, 24 | PENDING_MUTEX = 3 25 | }; 26 | 27 | class ExpertDispatcher : public base::noncopyable { 28 | public: 29 | typedef struct { 30 | int layer_idx = -1; 31 | int expert_idx = -1; 32 | int gpu_id = -1; 33 | bool remote = false; 34 | } CallArgs; 35 | typedef struct { 36 | torch::Tensor hidden_states = 37 | torch::empty({0}); // shallow copy, real tensor in python code 38 | ExpertNodePtr expert_node = nullptr; 39 | int out_gpu_id = -1; 40 | torch::ScalarType out_dtype = torch::kFloat32; 41 | bool evict = false; 42 | bool hit = false; 43 | } ExecArgs; 44 | typedef std::tuple CallResult; 45 | 46 | public: 47 | explicit ExpertDispatcher(int num_experts, int num_layers, int dtype, 48 | int expert_type, int num_threads = 8); 49 | ~ExpertDispatcher() { 50 | main_thread_stop_flag_.store(true); 51 | for (auto& thread : threads_) { 52 | thread->join(); 53 | } 54 | 55 | 
for (auto& stream : fetch_streams_) { 56 | cudaStreamDestroy(stream); 57 | } 58 | for (auto& stream : exec_streams_) { 59 | cudaStreamDestroy(stream); 60 | } 61 | for (auto& stream : out_streams_) { 62 | cudaStreamDestroy(stream); 63 | } 64 | } 65 | 66 | void SetInputs(const torch::Tensor& hidden_states, 67 | const torch::Tensor& router_mask) { 68 | hidden_states_ = hidden_states.clone(); 69 | router_mask_ = router_mask.clone(); 70 | } 71 | 72 | void EnqueueExpert(int layer_idx, int expert_idx, int gpu_id = -1, 73 | bool remote = false); 74 | 75 | void RegisterExpert(int layer_idx, int expert_idx, 76 | const std::vector& tensor_ids); 77 | void ClearExpertCacheCounts(); 78 | void SetExpectedQueue(int expected_pending = 0) { 79 | pending_.store(expected_pending); 80 | } 81 | 82 | std::vector WaitExpert() { return Wait(); } 83 | void SetNode(int layer_idx, int expert_idx, const NodePtr& node) { 84 | experts_[expert_idx][layer_idx]->node = node; 85 | } 86 | 87 | private: 88 | void Enqueue(CallArgs& args); 89 | std::vector Wait(); 90 | void Start() { start_ = true; } 91 | 92 | void GPUFetchFunc(int gpu_id); 93 | void GPUExecFunc(int gpu_id); 94 | 95 | // void GPUThreadFunc(int gpu_id); 96 | 97 | void OutputFunc(ExecArgs args, torch::Tensor output, int gpu_id); 98 | 99 | private: 100 | std::vector> threads_; 101 | std::mutex mutex_; 102 | std::vector> input_queue_; 103 | std::vector> exec_queue_; 104 | std::vector output_queue_; 105 | std::vector> experts_; 106 | std::atomic num_enqueued_; 107 | bool start_; 108 | int expert_type_; 109 | std::atomic main_thread_stop_flag_; 110 | 111 | std::atomic pending_; 112 | 113 | std::mutex pending_mutex_; 114 | std::condition_variable pending_cv_; 115 | 116 | std::vector input_mutex_; 117 | std::vector exec_mutex_; 118 | std::vector input_cv_; 119 | std::vector exec_cv_; 120 | 121 | std::mutex output_mutex_; 122 | // std::mutex exec_mutex_; 123 | std::mutex gpu_overload_mutex_; 124 | 125 | std::vector fetch_streams_; 126 | 
std::vector exec_streams_; 127 | std::vector out_streams_; 128 | 129 | std::vector gpu_overload_; 130 | 131 | torch::Tensor hidden_states_; 132 | torch::Tensor router_mask_; 133 | 134 | std::vector cache_sizes_; 135 | 136 | int cache_capacity_ = 0; 137 | }; 138 | -------------------------------------------------------------------------------- /core/base/logging.h: -------------------------------------------------------------------------------- 1 | #ifndef MUDUO_BASE_LOGGING_H 2 | #define MUDUO_BASE_LOGGING_H 3 | 4 | #include "log_stream.h" 5 | #include "timestamp.h" 6 | 7 | namespace base { 8 | 9 | class TimeZone; 10 | 11 | class Logger { 12 | public: 13 | enum LogLevel { 14 | TRACE, 15 | DEBUG, 16 | INFO, 17 | WARN, 18 | ERROR, 19 | FATAL, 20 | NUM_LOG_LEVELS, 21 | }; 22 | 23 | // compile time calculation of basename of source file 24 | class SourceFile { 25 | public: 26 | template 27 | inline SourceFile(const char (&arr)[N]) : data_(arr), size_(N - 1) { 28 | const char* slash = strrchr(data_, '/'); // builtin function 29 | if (slash) { 30 | data_ = slash + 1; 31 | size_ -= static_cast(data_ - arr); 32 | } 33 | } 34 | 35 | explicit SourceFile(const char* filename) : data_(filename) { 36 | const char* slash = strrchr(filename, '/'); 37 | if (slash) { 38 | data_ = slash + 1; 39 | } 40 | size_ = static_cast(strlen(data_)); 41 | } 42 | 43 | const char* data_; 44 | int size_; 45 | }; 46 | 47 | Logger(SourceFile file, int line); 48 | Logger(SourceFile file, int line, LogLevel level); 49 | Logger(SourceFile file, int line, LogLevel level, const char* func); 50 | Logger(SourceFile file, int line, bool toAbort); 51 | ~Logger(); 52 | 53 | LogStream& stream() { return impl_.stream_; } 54 | 55 | static LogLevel logLevel(); 56 | static void setLogLevel(LogLevel level); 57 | static void setLogLevel(const std::string& level); 58 | 59 | typedef void (*OutputFunc)(const char* msg, int len); 60 | typedef void (*FlushFunc)(); 61 | static void setOutput(OutputFunc); 62 | static void 
setFlush(FlushFunc); 63 | static void setTimeZone(const TimeZone& tz); 64 | 65 | private: 66 | class Impl { 67 | public: 68 | typedef Logger::LogLevel LogLevel; 69 | Impl(LogLevel level, int old_errno, const SourceFile& file, int line); 70 | void formatTime(); 71 | void finish(); 72 | 73 | Timestamp time_; 74 | LogStream stream_; 75 | LogLevel level_; 76 | int line_; 77 | SourceFile basename_; 78 | }; 79 | 80 | Impl impl_; 81 | }; 82 | 83 | extern Logger::LogLevel g_logLevel; 84 | 85 | inline Logger::LogLevel Logger::logLevel() { return g_logLevel; } 86 | 87 | // 88 | // CAUTION: do not write: 89 | // 90 | // if (good) 91 | // LOG_INFO << "Good news"; 92 | // else 93 | // LOG_WARN << "Bad news"; 94 | // 95 | // this expends to 96 | // 97 | // if (good) 98 | // if (logging_INFO) 99 | // logInfoStream << "Good news"; 100 | // else 101 | // logWarnStream << "Bad news"; 102 | // 103 | #define LOG_TRACE \ 104 | if (base::Logger::logLevel() <= base::Logger::TRACE) \ 105 | base::Logger(__FILE__, __LINE__, base::Logger::TRACE, __func__).stream() 106 | #define LOG_DEBUG \ 107 | if (base::Logger::logLevel() <= base::Logger::DEBUG) \ 108 | base::Logger(__FILE__, __LINE__, base::Logger::DEBUG, __func__).stream() 109 | #define LOG_INFO \ 110 | if (base::Logger::logLevel() <= base::Logger::INFO) \ 111 | base::Logger(__FILE__, __LINE__).stream() 112 | #define LOG_WARN base::Logger(__FILE__, __LINE__, base::Logger::WARN).stream() 113 | #define LOG_ERROR base::Logger(__FILE__, __LINE__, base::Logger::ERROR).stream() 114 | #define LOG_FATAL base::Logger(__FILE__, __LINE__, base::Logger::FATAL).stream() 115 | #define LOG_SYSERR base::Logger(__FILE__, __LINE__, false).stream() 116 | #define LOG_SYSFATAL base::Logger(__FILE__, __LINE__, true).stream() 117 | 118 | const char* strerror_tl(int savedErrno); 119 | 120 | // Taken from glog/logging.h 121 | // 122 | // Check that the input is non NULL. This very useful in constructor 123 | // initializer lists. 
124 | 125 | #define CHECK_NOTNULL(val) \ 126 | ::base::CheckNotNull(__FILE__, __LINE__, "'" #val "' Must be non NULL", (val)) 127 | 128 | // A small helper for CHECK_NOTNULL(). 129 | template 130 | T* CheckNotNull(Logger::SourceFile file, int line, const char* names, T* ptr) { 131 | if (ptr == NULL) { 132 | Logger(file, line, Logger::FATAL).stream() << names; 133 | } 134 | return ptr; 135 | } 136 | 137 | } // namespace base 138 | 139 | #endif // MUDUO_BASE_LOGGING_H 140 | -------------------------------------------------------------------------------- /core/memory/device_caching_allocator.cpp: -------------------------------------------------------------------------------- 1 | //===- c10/mobile/CPUCachingAllocator.cpp ----------------===// 2 | // 3 | // Part of the Pytorch Project, under the BSD 3-Clause License. 4 | // See https://github.com/pytorch/pytorch/blob/main/LICENSE for license 5 | // information. SPDX-License-Identifier: BSD 3-Clause 6 | 7 | // MoE-Infinity: modified from c10::CPUCachingAllocator. 
8 | // replaced c10::alloc_cpu with cudaDeviceAlloc 9 | 10 | #include "device_caching_allocator.h" 11 | #include 12 | #include 13 | #include "utils/logger.h" 14 | 15 | namespace c10 { 16 | namespace DeviceCachingAllocator { 17 | 18 | std::mutex DeviceCachingAllocator::mutex_; 19 | ska::flat_hash_map DeviceCachingAllocator::allocation_map_; 20 | 21 | inline void* DeviceCachingAllocator::allocate_and_cache(const size_t bytes) { 22 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) 23 | void* ptr; 24 | auto cuda_err = cudaMalloc(&ptr, bytes); 25 | if (cuda_err != cudaSuccess) { 26 | free_cached(); 27 | cuda_err = cudaMalloc(&ptr, bytes); 28 | if (cuda_err != cudaSuccess) { 29 | DLOG_ERROR("cudaMalloc failed", bytes, cuda_err); 30 | throw std::runtime_error("cudaMalloc failed"); 31 | } 32 | } 33 | 34 | allocation_map_[ptr] = bytes; 35 | return ptr; 36 | } 37 | 38 | void* DeviceCachingAllocator::allocate(const size_t bytes) { 39 | std::lock_guard guard(mutex_); 40 | const auto& it = available_map_.find(bytes); 41 | if (it == available_map_.end() || it->second.empty()) { 42 | return allocate_and_cache(bytes); 43 | } 44 | return it->second.pop_back_val(); 45 | } 46 | 47 | void DeviceCachingAllocator::free(void* ptr) { 48 | // NB: since we are not really freeing the memory 49 | // the cases such as quantization code freeing original weights 50 | // on mobile, will not quite work, as we likely will hold 51 | // onto that memory. 52 | // NB: We can also enable max memory cached for better memory 53 | // management such that free will actually free the memory if 54 | // we are nearing or above the watermark. 
55 | std::lock_guard guard(mutex_); 56 | // If this allocation was done before caching allocator was enabled 57 | // then free regularly 58 | const auto& it = allocation_map_.find(ptr); 59 | if (it == allocation_map_.end()) { 60 | // c10::free_cpu(ptr); 61 | cudaFree(ptr); 62 | return; 63 | } 64 | const size_t alloc_size = it->second; 65 | available_map_[alloc_size].push_back(ptr); 66 | } 67 | 68 | void DeviceCachingAllocator::record_free(void* ptr) { 69 | // This function captures the case when the allocated memory 70 | // is being freed outside the scope of this allocator. 71 | // At the moment only way to capture this is to have the allocator, 72 | // that uses this CachingAllocator as the backing allocator, 73 | // call this function explicitly upon freeing memory while 74 | // outside the scope of caching allocator. 75 | // If the memory is freed in some other way, then we will likely 76 | // have undefined behavior or page fault. But this can be 77 | // the case without caching allocator as well. 78 | std::lock_guard guard(mutex_); 79 | const auto& it = allocation_map_.find(ptr); 80 | if (it != allocation_map_.end()) { 81 | allocation_map_.erase(it); 82 | } 83 | } 84 | 85 | void DeviceCachingAllocator::free_cached() { 86 | for (const auto& it : available_map_) { 87 | for (const auto ptr : it.second) { 88 | // c10::free_cpu(ptr); 89 | cudaFree(ptr); 90 | // When cached memory is return to OS, it must be removed 91 | // from allocation_map. 
92 | allocation_map_.erase(ptr); 93 | } 94 | } 95 | available_map_.clear(); 96 | } 97 | 98 | DeviceCachingAllocator::~DeviceCachingAllocator() { free_cached(); } 99 | 100 | std::array caching_allocators; 101 | 102 | DeviceCachingAllocator* get(int device_id) { 103 | if (device_id < 0 || device_id >= 8) { 104 | throw std::runtime_error("Invalid device_id"); 105 | } 106 | if (caching_allocators[device_id] == nullptr) { 107 | caching_allocators[device_id] = new DeviceCachingAllocator(); 108 | } 109 | return caching_allocators[device_id]; 110 | } 111 | 112 | } // namespace DeviceCachingAllocator 113 | 114 | } // namespace c10 115 | -------------------------------------------------------------------------------- /moe_infinity/models/nllb_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) EfficientMoE. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # EfficientMoE Team 5 | 6 | from typing import Dict, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import NllbMoeConfig 11 | from transformers.models.nllb_moe.modeling_nllb_moe import ( 12 | NllbMoeDenseActDense, 13 | NllbMoeTop2Router, 14 | ) 15 | 16 | from moe_infinity.utils import ArcherConfig 17 | 18 | GPU_IDX_COUNTER = 0 19 | 20 | 21 | class SyncNllbMoeSparseMLP(nn.Module): 22 | archer_config: ArcherConfig = None 23 | layer_id: int = None 24 | 25 | def __init__( 26 | self, 27 | config: NllbMoeConfig, 28 | ffn_dim: int, 29 | expert_class: nn.Module = NllbMoeDenseActDense, 30 | ): 31 | super().__init__() 32 | self.router = NllbMoeTop2Router(config) 33 | self.moe_token_dropout = config.moe_token_dropout 34 | self.token_dropout = nn.Dropout(self.moe_token_dropout) 35 | 36 | self.num_experts = config.num_experts 37 | 38 | self.experts = nn.ModuleDict() 39 | for idx in range(self.num_experts): 40 | self.experts[f"expert_{idx}"] = expert_class(config, ffn_dim) 41 | 42 | self.archer_tracer = None 43 | self.archer_engine = None 44 | 
self.expert_tensor_ids: Dict[int, int] = None 45 | 46 | def forward( 47 | self, 48 | hidden_states: torch.Tensor, 49 | padding_mask: Optional[torch.Tensor] = None, 50 | ): 51 | batch_size, sequence_length, hidden_dim = hidden_states.shape 52 | 53 | top_1_mask, router_probs = self.router(hidden_states, padding_mask) 54 | combining_weights = router_probs.reshape( 55 | (batch_size, sequence_length, self.num_experts) 56 | ) 57 | router_mask = combining_weights.bool() 58 | 59 | next_states = torch.zeros_like(hidden_states) 60 | top_1_expert_index = torch.argmax(top_1_mask, dim=-1) 61 | 62 | # logits_except_top_1 = router_probs.masked_fill( 63 | # top_1_mask.bool(), float("-inf") 64 | # ) 65 | # top_2_expert_index = torch.argmax(logits_except_top_1, dim=-1) 66 | # top_2_mask = torch.nn.functional.one_hot(top_2_expert_index, num_classes=self.num_experts) 67 | 68 | # expert_index = torch.stack( 69 | # [top_1_expert_index, top_2_expert_index], dim=-1 70 | # ) 71 | # expert_index = expert_index.reshape(batch_size, sequence_length, 2) 72 | # for i in range(batch_size): 73 | # seq_id = self.seq_id_list[i] 74 | # expert_matrix = self.expert_predictor.predict( 75 | # seq_id, expert_index[i], self.layer_id 76 | # ) 77 | # self.expert_prefetcher.prefetch_experts( 78 | # self.layer_id, expert_matrix 79 | # ) 80 | 81 | results = self.expert_executor.dispatch_local( 82 | hidden_states, router_mask, self.layer_id 83 | ) 84 | for output, _, idx, _ in results: 85 | token_indices = router_mask[..., idx].bool() 86 | weights = combining_weights[..., idx] 87 | # print(router_mask.shape, combining_weights.shape, hidden_states.shape, flush=True) 88 | # print(output.shape, weights.shape, token_indices.shape, next_states.shape, flush=True) 89 | # print(output.shape, weights[token_indices].shape, next_states[token_indices].shape, flush=True) 90 | next_states[token_indices] += torch.einsum( 91 | "b,be->be", weights[token_indices], output.to(weights.device) 92 | ) 93 | 94 | # for expert_id, expert 
in self.experts.items(): 95 | # idx = int(expert_id.split("_")[-1]) 96 | # token_indices = router_mask[:, :, idx].bool() 97 | # weights = combining_weights[..., idx] 98 | 99 | # if token_indices.any(): 100 | # expert_output = expert(hidden_states[token_indices]).to(weights.device) 101 | # next_states[token_indices] += torch.einsum("b,be->be", weights[token_indices], expert_output) 102 | 103 | next_states[next_states == 0] = hidden_states[next_states == 0] 104 | hidden_states = next_states 105 | 106 | return hidden_states, ( 107 | router_probs.to("cuda:0", non_blocking=True), 108 | top_1_expert_index.to("cuda:0", non_blocking=True), 109 | ) 110 | -------------------------------------------------------------------------------- /core/base/log_stream.h: -------------------------------------------------------------------------------- 1 | #ifndef MUDUO_BASE_LOGSTREAM_H 2 | #define MUDUO_BASE_LOGSTREAM_H 3 | 4 | #include 5 | #include // memcpy 6 | 7 | #include 8 | 9 | #include "string_piece.h" 10 | #include "types.h" 11 | 12 | namespace base { 13 | 14 | namespace detail { 15 | 16 | const int kSmallBuffer = 4000; 17 | const int kLargeBuffer = 4000 * 1000; 18 | 19 | template 20 | class FixedBuffer : noncopyable { 21 | public: 22 | FixedBuffer() : cur_(data_) { setCookie(cookieStart); } 23 | 24 | ~FixedBuffer() { setCookie(cookieEnd); } 25 | 26 | void append(const char* /*restrict*/ buf, size_t len) { 27 | // FIXME: append partially 28 | if (implicit_cast(avail()) > len) { 29 | memcpy(cur_, buf, len); 30 | cur_ += len; 31 | } else { 32 | memcpy(cur_, buf, avail()); 33 | cur_ += avail(); 34 | } 35 | } 36 | 37 | const char* data() const { return data_; } 38 | int length() const { return static_cast(cur_ - data_); } 39 | 40 | // write to data_ directly 41 | char* current() { return cur_; } 42 | int avail() const { return static_cast(end() - cur_); } 43 | void add(size_t len) { cur_ += len; } 44 | 45 | void reset() { cur_ = data_; } 46 | void bzero() { ::bzero(data_, sizeof 
data_); } 47 | 48 | // for used by GDB 49 | const char* debugString(); 50 | void setCookie(void (*cookie)()) { cookie_ = cookie; } 51 | // for used by unit test 52 | std::string toString() const { return std::string(data_, length()); } 53 | StringPiece toStringPiece() const { return StringPiece(data_, length()); } 54 | 55 | private: 56 | const char* end() const { return data_ + sizeof data_; } 57 | // Must be outline function for cookies. 58 | static void cookieStart(); 59 | static void cookieEnd(); 60 | 61 | void (*cookie_)(); 62 | char data_[SIZE]; 63 | char* cur_; 64 | }; 65 | 66 | } // namespace detail 67 | 68 | class LogStream : noncopyable { 69 | typedef LogStream self; 70 | 71 | public: 72 | typedef detail::FixedBuffer Buffer; 73 | 74 | self& operator<<(bool v) { 75 | buffer_.append(v ? "1" : "0", 1); 76 | return *this; 77 | } 78 | 79 | self& operator<<(short); 80 | self& operator<<(unsigned short); 81 | self& operator<<(int); 82 | self& operator<<(unsigned int); 83 | self& operator<<(long); 84 | self& operator<<(unsigned long); 85 | self& operator<<(long long); 86 | self& operator<<(unsigned long long); 87 | 88 | self& operator<<(const void*); 89 | 90 | self& operator<<(float v) { 91 | *this << static_cast(v); 92 | return *this; 93 | } 94 | self& operator<<(double); 95 | // self& operator<<(long double); 96 | 97 | self& operator<<(char v) { 98 | buffer_.append(&v, 1); 99 | return *this; 100 | } 101 | 102 | // self& operator<<(signed char); 103 | // self& operator<<(unsigned char); 104 | 105 | self& operator<<(const char* str) { 106 | if (str) { 107 | buffer_.append(str, strlen(str)); 108 | } else { 109 | buffer_.append("(null)", 6); 110 | } 111 | return *this; 112 | } 113 | 114 | self& operator<<(const unsigned char* str) { 115 | return operator<<(reinterpret_cast(str)); 116 | } 117 | 118 | self& operator<<(const std::string& v) { 119 | buffer_.append(v.c_str(), v.size()); 120 | return *this; 121 | } 122 | 123 | self& operator<<(const StringPiece& v) { 124 
| buffer_.append(v.data(), v.size()); 125 | return *this; 126 | } 127 | 128 | self& operator<<(const Buffer& v) { 129 | *this << v.toStringPiece(); 130 | return *this; 131 | } 132 | 133 | void append(const char* data, int len) { buffer_.append(data, len); } 134 | const Buffer& buffer() const { return buffer_; } 135 | void resetBuffer() { buffer_.reset(); } 136 | 137 | private: 138 | void staticCheck(); 139 | 140 | template 141 | void formatInteger(T); 142 | 143 | Buffer buffer_; 144 | 145 | static const int kMaxNumericSize = 32; 146 | }; 147 | 148 | class Fmt // : noncopyable 149 | { 150 | public: 151 | template 152 | Fmt(const char* fmt, T val); 153 | 154 | const char* data() const { return buf_; } 155 | int length() const { return length_; } 156 | 157 | private: 158 | char buf_[32]; 159 | int length_; 160 | }; 161 | 162 | inline LogStream& operator<<(LogStream& s, const Fmt& fmt) { 163 | s.append(fmt.data(), fmt.length()); 164 | return s; 165 | } 166 | 167 | } // namespace base 168 | #endif // MUDUO_BASE_LOGSTREAM_H 169 | -------------------------------------------------------------------------------- /core/base/file_util.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2010, Shuo Chen. All rights reserved. 2 | // http://code.google.com/p/muduo/ 3 | // 4 | // Use of this source code is governed by a BSD-style license 5 | // that can be found in the License file. 6 | 7 | // Author: Shuo Chen (chenshuo at chenshuo dot com) 8 | // 9 | 10 | #include "file_util.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "logging.h" // strerror_tl 20 | 21 | using namespace base; 22 | 23 | FileUtil::AppendFile::AppendFile(StringArg filename) 24 | : fp_(::fopen(filename.c_str(), "ae")), // 'e' for O_CLOEXEC 25 | writtenBytes_(0) { 26 | assert(fp_); 27 | ::setbuffer(fp_, buffer_, sizeof buffer_); 28 | // posix_fadvise POSIX_FADV_DONTNEED ? 
29 | } 30 | 31 | FileUtil::AppendFile::~AppendFile() { ::fclose(fp_); } 32 | 33 | void FileUtil::AppendFile::append(const char* logline, const size_t len) { 34 | size_t n = write(logline, len); 35 | size_t remain = len - n; 36 | while (remain > 0) { 37 | size_t x = write(logline + n, remain); 38 | if (x == 0) { 39 | int err = ferror(fp_); 40 | if (err) { 41 | fprintf(stderr, "AppendFile::append() failed %s\n", strerror_tl(err)); 42 | } 43 | break; 44 | } 45 | n += x; 46 | remain = len - n; // remain -= x 47 | } 48 | 49 | writtenBytes_ += len; 50 | } 51 | 52 | void FileUtil::AppendFile::flush() { ::fflush(fp_); } 53 | 54 | size_t FileUtil::AppendFile::write(const char* logline, size_t len) { 55 | // #undef fwrite_unlocked 56 | return ::fwrite_unlocked(logline, 1, len, fp_); 57 | } 58 | 59 | FileUtil::ReadSmallFile::ReadSmallFile(StringArg filename) 60 | : fd_(::open(filename.c_str(), O_RDONLY | O_CLOEXEC)), err_(0) { 61 | buf_[0] = '\0'; 62 | if (fd_ < 0) { 63 | err_ = errno; 64 | } 65 | } 66 | 67 | FileUtil::ReadSmallFile::~ReadSmallFile() { 68 | if (fd_ >= 0) { 69 | ::close(fd_); // FIXME: check EINTR 70 | } 71 | } 72 | 73 | // return errno 74 | template 75 | int FileUtil::ReadSmallFile::readToString(int maxSize, String* content, 76 | int64_t* fileSize, 77 | int64_t* modifyTime, 78 | int64_t* createTime) { 79 | static_assert(sizeof(off_t) == 8, "_FILE_OFFSET_BITS = 64"); 80 | assert(content != NULL); 81 | int err = err_; 82 | if (fd_ >= 0) { 83 | content->clear(); 84 | 85 | if (fileSize) { 86 | struct stat statbuf; 87 | if (::fstat(fd_, &statbuf) == 0) { 88 | if (S_ISREG(statbuf.st_mode)) { 89 | *fileSize = statbuf.st_size; 90 | content->reserve(static_cast( 91 | std::min(implicit_cast(maxSize), *fileSize))); 92 | } else if (S_ISDIR(statbuf.st_mode)) { 93 | err = EISDIR; 94 | } 95 | if (modifyTime) { 96 | *modifyTime = statbuf.st_mtime; 97 | } 98 | if (createTime) { 99 | *createTime = statbuf.st_ctime; 100 | } 101 | } else { 102 | err = errno; 103 | } 104 | } 105 
| 106 | while (content->size() < implicit_cast(maxSize)) { 107 | size_t toRead = std::min(implicit_cast(maxSize) - content->size(), 108 | sizeof(buf_)); 109 | ssize_t n = ::read(fd_, buf_, toRead); 110 | if (n > 0) { 111 | content->append(buf_, n); 112 | } else { 113 | if (n < 0) { 114 | err = errno; 115 | } 116 | break; 117 | } 118 | } 119 | } 120 | return err; 121 | } 122 | 123 | int FileUtil::ReadSmallFile::readToBuffer(int* size) { 124 | int err = err_; 125 | if (fd_ >= 0) { 126 | ssize_t n = ::pread(fd_, buf_, sizeof(buf_) - 1, 0); 127 | if (n >= 0) { 128 | if (size) { 129 | *size = static_cast(n); 130 | } 131 | buf_[n] = '\0'; 132 | } else { 133 | err = errno; 134 | } 135 | } 136 | return err; 137 | } 138 | 139 | template int FileUtil::readFile(StringArg filename, int maxSize, 140 | std::string* content, int64_t*, int64_t*, 141 | int64_t*); 142 | 143 | template int FileUtil::ReadSmallFile::readToString(int maxSize, 144 | std::string* content, 145 | int64_t*, int64_t*, 146 | int64_t*); 147 | -------------------------------------------------------------------------------- /core/memory/host_caching_allocator.h: -------------------------------------------------------------------------------- 1 | // c10/mobile/CPUCachingAllocator.h 2 | // 3 | // Part of the Pytorch Project, under the BSD-3-Clause License. 4 | // See https://github.com/pytorch/pytorch/blob/main/LICENSE for license 5 | // information. SPDX-License-Identifier: BSD-3-Clause 6 | 7 | // MoE-Infinity: modified from c10::CPUCachingAllocator. 8 | // replaced CPUCachingAllocator with HostCachingAllocator 9 | 10 | #pragma once 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | /* 18 | * HostCachingAllocator: 19 | * DISCLAIMER: 20 | * This is subject to change (beta) and only supported on mobile builds. 21 | * If code snippet such as in 'Usage pattern' is used outside of mobile 22 | * build you will not observe the intended behavior. 23 | * See below for more information. 24 | * Why? 
25 | * It has been observed that some mobile platforms, such as pixel 3, return 26 | * memory aggressively to the system. This results in page faults in some 27 | * cases and ends up hurting performance. This caching allocator aims to address 28 | * that. Furthermore it also allows users to specify their own allocator by 29 | * implementing allocate/free virtual interfaces. What are the cons? There are 30 | * some cons that were observed where use of caching allocator led to worse 31 | * performance on some platforms. Reason being that the caching mechanism used 32 | * by this allocator left us worse off compared to the corresponding platform's 33 | * tuned memory allocator. In that case it seemed better to not use this 34 | * allocator. Note there are some ideas to fix this in the works. 35 | * 36 | * Usage: 37 | * Usage pattern: 38 | * Instantiate and own the caching allocator. 39 | * std::unique_ptr caching_allocator = 40 | * std::make_unique(); 41 | * Use caching allocator with a scoped guard at inference time. 42 | * { 43 | * WithHostCachingAllocatorGuard(caching_allocator.get()); 44 | * ... model.forward(...); 45 | * } 46 | */ 47 | 48 | namespace c10 { 49 | 50 | namespace HostCachingAllocator { 51 | 52 | class HostCachingAllocator { 53 | /* 54 | * What it does: 55 | * Caches all the allocations carried out by this allocator. 56 | * Cache key is the size of the allocation. 57 | * If requested size is found in the cache returns the cached pointer. 58 | * What it does not do: 59 | * No speculative allocation for any future allocations. 60 | */ 61 | private: 62 | inline void* allocate_and_cache(const size_t bytes); 63 | void free_cached(); 64 | 65 | protected: 66 | // Invariants. 67 | // 1. If memory is ever allocated via this allocator then 68 | // the pointer will exist in allocation_map_, unless the allocator 69 | // returned the memory to OS via free_cached. 70 | // 1.1. 
Therefore even when the said memory is "freed" via this 71 | // allocator (and thus cached), it will continue to stay 72 | // in allocation_map_. Furthermore it will also exist in 73 | // available_map_. Thus an allocated memory pointer can be in both 74 | // allocation_map_ and available_map_ simultaneously. 75 | // 2. Memory pointer maybe removed from allocation_map_, when it 76 | // is freed outside of the scope of this allocator, but was allocated 77 | // by this allocator. 78 | // 3. Available map only contains that memory which was allocated 79 | // by this allocator and subsequently freed by this allocator. 80 | // As a result of above invariants, allocated memory ptr cannot be in 81 | // available_map_ unless it is in allocation_map_ as well. 82 | ska::flat_hash_map> available_map_; 83 | static ska::flat_hash_map allocation_map_; 84 | // Since allocation_map, which is a global instance, is mutated/read via 85 | // all public APIs we need a global mutex. 86 | static std::mutex mutex_; 87 | 88 | public: 89 | static void record_free(void* ptr); 90 | virtual ~HostCachingAllocator(); 91 | // Checks the cache to see if allocation of size bytes can be found. 92 | // If so return cached memory, else 93 | // allocates memory, records it for caching and returns. 94 | virtual void* allocate(const size_t bytes); 95 | // Checks if the memory being freed is was marked for allocation by 96 | // an earlier call to allocate. If so cache the allocation. 97 | // Otherwise free. 98 | virtual void free(void* ptr); 99 | }; 100 | 101 | extern HostCachingAllocator* caching_allocator; 102 | 103 | extern HostCachingAllocator* get(); 104 | } // namespace HostCachingAllocator 105 | 106 | } // namespace c10 107 | --------------------------------------------------------------------------------