├── .gitignore ├── PlanMoE ├── __init__.py ├── custom │ ├── __init__.py │ ├── comm │ │ ├── abstract.cpp │ │ ├── abstract.h │ │ ├── hetu.cpp │ │ ├── hetu.h │ │ ├── layout_transform.cu │ │ ├── layout_transform.h │ │ ├── naive.cpp │ │ ├── naive.h │ │ ├── pipe.cpp │ │ └── pipe.h │ ├── compressor │ │ ├── abstract.cpp │ │ ├── abstract.h │ │ ├── gpulz.cu │ │ ├── gpulz.h │ │ ├── int8.cpp │ │ ├── int8.h │ │ ├── lz.cpp │ │ ├── lz.h │ │ ├── no.cpp │ │ ├── no.h │ │ ├── zfpc.cpp │ │ └── zfpc.h │ ├── custom_kernel.cpp │ ├── dd_comm.cpp │ ├── dd_comm.h │ ├── jit.cpp │ └── jit.h ├── examples │ ├── fairseq_moe │ │ ├── README.md │ │ ├── fairseq_patch.diff │ │ └── run_fairseq.sh │ ├── launch.py │ ├── megatron │ │ ├── README.md │ │ ├── schemoe_megatron.diff │ │ └── train_schemoe.sh │ ├── pre_test.py │ └── run_mpi.sh ├── experts │ ├── __init__.py │ └── ffn.py ├── gates │ ├── __init__.py │ ├── cosine_top.py │ └── top.py ├── impls │ ├── __init__.py │ ├── communicate.py │ ├── fast_dispatch.py │ ├── jit_compiler.py │ ├── losses.py │ ├── moe_layer.py │ └── overlap.py ├── jit.py ├── jit_kernels │ ├── __init__.py │ ├── gating.py │ └── sparse.py ├── launcher │ ├── __init__.py │ ├── execl.py │ └── run.py ├── moe.py ├── net.py ├── parted │ ├── __init__.py │ ├── backend │ │ ├── __init__.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── executor.py │ ├── patterns.py │ ├── solver.py │ └── spmdx.py └── system.py ├── README.md └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .DS_Store 163 | **/.DS_Store 164 | -------------------------------------------------------------------------------- /PlanMoE/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from . import system as system_init 5 | -------------------------------------------------------------------------------- /PlanMoE/custom/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
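"""C++/CUDA extension sources for PlanMoE's communication stack.

comm/ holds the all-to-all backends (naive, pipe, hetu; dd_comm.cpp adds a
stride-copy hierarchical variant), compressor/ holds the payload codecs
(no-op, int8, zfp, gpu-lz), and custom_kernel.cpp exposes them to Python via
pybind11 as compress_operation / comm_operation / decompress_operation.
"""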
3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/custom/comm/abstract.cpp: -------------------------------------------------------------------------------- 1 | #include "abstract.h" 2 | 3 | AbstractComm::AbstractComm(std::vector *stream, 4 | std::vector g_nccl_comm, 5 | const int &g_world_size, 6 | const int &g_world_rank, 7 | const int &g_local_size, 8 | const int &g_local_rank) : 9 | stream(stream), 10 | g_nccl_comm(g_nccl_comm), 11 | g_world_size(g_world_size), 12 | g_world_rank(g_world_rank), 13 | g_local_size(g_local_size), 14 | g_local_rank(g_local_rank) { 15 | } 16 | 17 | void AbstractComm::pre_comm(const torch::Tensor &input) { 18 | } 19 | 20 | AbstractComm::~AbstractComm() { 21 | } -------------------------------------------------------------------------------- /PlanMoE/custom/comm/abstract.h: -------------------------------------------------------------------------------- 1 | #ifndef ABSTRACT_COMM 2 | #define ABSTRACT_COMM 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | class AbstractComm { 11 | public: 12 | // Declare all public members here 13 | virtual void all_to_all(const torch::Tensor &, const torch::Tensor &, size_t) = 0; 14 | AbstractComm(std::vector *, std::vector, const int &, const int &, const int &, const int &); 15 | virtual void pre_comm(const torch::Tensor &); 16 | virtual ~AbstractComm(); 17 | 18 | std::vector *stream; 19 | std::vector g_nccl_comm; 20 | int g_world_size; 21 | int g_world_rank; 22 | int g_local_size; 23 | int g_local_rank; 24 | }; 25 | 26 | #endif // ABSTRACT_COMM -------------------------------------------------------------------------------- /PlanMoE/custom/comm/hetu.cpp: -------------------------------------------------------------------------------- 1 | #include "hetu.h" 2 | 3 | HeTuComm::HeTuComm(std::vector *stream, 4 | std::vector g_nccl_comm, 5 | const int &g_world_size, 6 | const int &g_world_rank, 7 | const int &g_local_size, 8 | const int &g_local_rank) : 9 | AbstractComm(stream, g_nccl_comm, g_world_size, g_world_rank, g_local_size, g_local_rank) { 10 | } 11 | 12 | void HeTuComm::pre_comm(const torch::Tensor &input) { 13 | } 14 | 15 | void HeTuComm::all_to_all(const torch::Tensor &input, const torch::Tensor &output, size_t length) { 16 | 17 | group_input = torch::empty( 18 | {input.size(0) * g_local_size, input.size(1)}, 19 | torch::TensorOptions().device(input.device()).dtype(at::kFloat), 20 | torch::MemoryFormat::Contiguous); 21 | for (const auto &cuda_stream : *stream) { 22 | c10::cuda::CUDACachingAllocator::recordStream(group_input.storage().data_ptr(), cuda_stream); 23 | } 24 | 25 | group_output = torch::empty( 26 | {input.size(0) * g_local_size, input.size(1)}, 27 | torch::TensorOptions().device(input.device()).dtype(at::kFloat), 28 | torch::MemoryFormat::Contiguous); 29 | for (const auto &cuda_stream : *stream) { 30 | c10::cuda::CUDACachingAllocator::recordStream(group_output.storage().data_ptr(), cuda_stream); 31 | } 32 | 33 | ncclGroupStart(); 34 | if (g_local_rank == 0) { 35 | for (int i = 0; i < g_local_size; i++) { 36 | ncclRecv(((char *)group_input.data_ptr()) + i * input.nbytes(), input.nbytes(), ncclInt8, g_world_rank + i, g_nccl_comm[0], stream->at(0).stream()); 37 | } 38 | } 39 | ncclSend((char *)input.data_ptr(), input.nbytes(), ncclInt8, g_world_rank - g_local_rank, g_nccl_comm[0], stream->at(0).stream()); 40 | ncclGroupEnd(); 41 | if (g_local_rank == 0) { 42 | //std::cout << group_input << std::endl; 43 | _layout_transform( 44 | 
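// HeTu-style hierarchical all-to-all: every local rank has just funnelled its
// whole buffer to its node leader (local rank 0). The leader now regroups the
// gathered data so that bytes bound for the same remote node sit contiguously,
// trades node-sized chunks with the other node leaders, undoes the regrouping,
// and finally scatters the per-rank slices back to the local ranks below.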
group_input, 45 | group_output, 46 | g_world_size, 47 | g_local_size, 48 | stream->at(0).stream()); 49 | 50 | int num_nodes = g_world_size / g_local_size; 51 | ncclGroupStart(); 52 | for (int i = 0; i < g_world_size; i += g_local_size) { 53 | ncclRecv(((char *)group_input.data_ptr()) + i * input.nbytes() / num_nodes, group_input.nbytes() / num_nodes, ncclInt8, i, g_nccl_comm[0], stream->at(0).stream()); 54 | ncclSend(((char *)group_output.data_ptr()) + i * input.nbytes() / num_nodes, group_input.nbytes() / num_nodes, ncclInt8, i, g_nccl_comm[0], stream->at(0).stream()); 55 | } 56 | ncclGroupEnd(); 57 | 58 | _reverse_layout_transform( 59 | group_input, 60 | group_output, 61 | g_world_size, 62 | g_local_size, 63 | stream->at(0).stream()); 64 | 65 | cudaStreamSynchronize(stream->at(0).stream()); 66 | } 67 | 68 | ncclGroupStart(); 69 | if (g_local_rank == 0) { 70 | for (int i = 0; i < g_local_size; i++) { 71 | ncclSend(((char *)group_output.data_ptr()) + i * input.nbytes(), input.nbytes(), ncclInt8, g_world_rank + i, g_nccl_comm[0], stream->at(0).stream()); 72 | } 73 | } 74 | ncclRecv((char *)output.data_ptr(), input.nbytes(), ncclInt8, g_world_rank - g_local_rank, g_nccl_comm[0], stream->at(0).stream()); 75 | ncclGroupEnd(); 76 | } -------------------------------------------------------------------------------- /PlanMoE/custom/comm/hetu.h: -------------------------------------------------------------------------------- 1 | #ifndef HETU_COMM 2 | #define HETU_COMM 3 | 4 | #include "abstract.h" 5 | #include "layout_transform.h" 6 | #include 7 | 8 | class HeTuComm : public AbstractComm { 9 | public: 10 | // Declare all public members here 11 | void all_to_all(const torch::Tensor &, const torch::Tensor &, size_t); 12 | HeTuComm(std::vector *, std::vector, const int &, const int &, const int &, const int &); 13 | void pre_comm(const torch::Tensor &); 14 | ~HeTuComm() override = default; 15 | 16 | torch::Tensor group_input, group_output; 17 | }; 18 | 19 | #endif // HETU_COMM -------------------------------------------------------------------------------- /PlanMoE/custom/comm/layout_transform.cu: -------------------------------------------------------------------------------- 1 | #include "layout_transform.h" 2 | 3 | __global__ void layout_transform_kernel(const float *input_data, float *output_data, int samples, int hidden, int g_world_size, int g_local_size) { 4 | int num_nodes = g_world_size / g_local_size; 5 | int data_size_per_gpu = samples / g_local_size; 6 | int data_size_per_gpu_per_node = data_size_per_gpu / (num_nodes); 7 | int data_size_per_gpu_per_gpu = data_size_per_gpu / (g_world_size); 8 | int data_size_per_node = samples / num_nodes; 9 | int gpu_id = 0; 10 | int target_node_id = 0; 11 | int target_gpu_id = 0; 12 | int tmp = 0; 13 | int offset = 0; 14 | for (int i = blockIdx.x; i < samples; i += gridDim.x) { 15 | gpu_id = i / data_size_per_gpu; 16 | tmp = i % data_size_per_gpu; 17 | target_node_id = tmp / data_size_per_gpu_per_node; 18 | tmp = tmp % data_size_per_gpu_per_node; 19 | target_gpu_id = tmp / data_size_per_gpu_per_gpu; 20 | offset = tmp % data_size_per_gpu_per_gpu; 21 | for (int j = threadIdx.x; j < hidden; j += 1024) { 22 | output_data[(target_node_id * data_size_per_node + target_gpu_id * data_size_per_gpu_per_node + gpu_id * data_size_per_gpu_per_gpu + offset) * (hidden) + j] = input_data[i * (hidden) + j]; 23 | } 24 | } 25 | } 26 | 27 | __global__ void reverse_layout_transform_kernel(const float *input_data, float *output_data, int samples, int hidden, int g_world_size, 
int g_local_size) { 28 | int num_nodes = g_world_size / g_local_size; 29 | int data_size_per_gpu = samples / g_local_size; 30 | int data_size_per_gpu_per_node = data_size_per_gpu / (num_nodes); 31 | int data_size_per_gpu_per_gpu = data_size_per_gpu / (g_world_size); 32 | int data_size_per_node = samples / num_nodes; 33 | int gpu_id = 0; 34 | int target_node_id = 0; 35 | int target_gpu_id = 0; 36 | int tmp = 0; 37 | int offset = 0; 38 | for (int i = blockIdx.x; i < samples; i += gridDim.x) { 39 | target_node_id = i / data_size_per_node; 40 | tmp = i % data_size_per_node; 41 | target_gpu_id = tmp / data_size_per_gpu_per_node; 42 | tmp = tmp % data_size_per_gpu_per_node; 43 | gpu_id = tmp / data_size_per_gpu_per_gpu; 44 | offset = tmp % data_size_per_gpu_per_gpu; 45 | for (int j = threadIdx.x; j < hidden; j += 1024) { 46 | output_data[(target_gpu_id * data_size_per_gpu + target_node_id * data_size_per_gpu_per_node + gpu_id * data_size_per_gpu_per_gpu + offset) * (hidden) + j] = input_data[i * (hidden) + j]; 47 | } 48 | } 49 | } 50 | 51 | void _layout_transform( 52 | torch::Tensor input, 53 | torch::Tensor output, 54 | int g_world_size, 55 | int g_local_size, 56 | cudaStream_t stream) { 57 | int samples = input.size(0); 58 | int hidden = input.size(1); 59 | const float *input_data = (const float *)(input.data_ptr()); 60 | float *output_data = (float *)(output.data_ptr()); 61 | dim3 blocks; 62 | dim3 threads; 63 | blocks.x = 128; 64 | threads.x = 1024; 65 | layout_transform_kernel<<>>(input_data, output_data, samples, hidden, g_world_size, g_local_size); 66 | } 67 | 68 | void _reverse_layout_transform( 69 | torch::Tensor input, 70 | torch::Tensor output, 71 | int g_world_size, 72 | int g_local_size, 73 | cudaStream_t stream) { 74 | int samples = input.size(0); 75 | int hidden = input.size(1); 76 | const float *input_data = (const float *)(input.data_ptr()); 77 | float *output_data = (float *)(output.data_ptr()); 78 | dim3 blocks; 79 | dim3 threads; 80 | blocks.x = 128; 81 | threads.x = 1024; 82 | reverse_layout_transform_kernel<<>>(input_data, output_data, samples, hidden, g_world_size, g_local_size); 83 | } 84 | -------------------------------------------------------------------------------- /PlanMoE/custom/comm/layout_transform.h: -------------------------------------------------------------------------------- 1 | #ifndef LAYOUT_TRANSFORM 2 | #define LAYOUT_TRANSFORM 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | extern "C" void _layout_transform( 9 | torch::Tensor input, 10 | torch::Tensor output, 11 | int g_world_size, 12 | int g_local_size, 13 | cudaStream_t stream); 14 | 15 | extern "C" void _reverse_layout_transform( 16 | torch::Tensor input, 17 | torch::Tensor output, 18 | int g_world_size, 19 | int g_local_size, 20 | cudaStream_t stream); 21 | 22 | #endif // LAYOUT_TRANSFORM -------------------------------------------------------------------------------- /PlanMoE/custom/comm/naive.cpp: -------------------------------------------------------------------------------- 1 | #include "naive.h" 2 | 3 | NaiveComm::NaiveComm(std::vector *stream, 4 | std::vector g_nccl_comm, 5 | const int &g_world_size, const int &g_world_rank, 6 | const int &g_local_size, const int &g_local_rank) : 7 | AbstractComm(stream, g_nccl_comm, g_world_size, g_world_rank, 8 | g_local_size, g_local_rank) { 9 | } 10 | 11 | void NaiveComm::all_to_all(const torch::Tensor &input, 12 | const torch::Tensor &output, size_t length) { 13 | // std::cout << length << " "; 14 | length = length / g_world_size; 15 | // std::cout 
<< length << " " << g_world_size << " " << g_world_rank << " " 16 | // << g_local_size << " " << g_local_rank << std::endl; 17 | CHECK_EQ(0, ncclGroupStart()); 18 | for (int i = 0; i < g_world_size; ++i) { 19 | CHECK_EQ(0, 20 | ncclSend(((char *)input.data_ptr()) + i * length, length, 21 | ncclInt8, i, g_nccl_comm[0], stream->at(0).stream())); 22 | CHECK_EQ(0, 23 | ncclRecv(((char *)output.data_ptr()) + i * length, length, 24 | ncclInt8, i, g_nccl_comm[0], stream->at(0).stream())); 25 | } 26 | CHECK_EQ(0, ncclGroupEnd()); 27 | } -------------------------------------------------------------------------------- /PlanMoE/custom/comm/naive.h: -------------------------------------------------------------------------------- 1 | #ifndef NAIVE_COMM 2 | #define NAIVE_COMM 3 | 4 | #include "abstract.h" 5 | 6 | #undef CHECK_EQ 7 | 8 | #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.") 9 | 10 | 11 | class NaiveComm : public AbstractComm { 12 | public: 13 | // Declare all public members here 14 | void all_to_all(const torch::Tensor &, const torch::Tensor &, size_t); 15 | NaiveComm(std::vector *, std::vector, const int &, const int &, const int &, const int &); 16 | }; 17 | 18 | #endif // NAIVE_COMM -------------------------------------------------------------------------------- /PlanMoE/custom/comm/pipe.cpp: -------------------------------------------------------------------------------- 1 | #include "pipe.h" 2 | 3 | PipeComm::PipeComm(std::vector *stream, 4 | std::vector g_nccl_comm, 5 | const int &g_world_size, 6 | const int &g_world_rank, 7 | const int &g_local_size, 8 | const int &g_local_rank) : 9 | AbstractComm(stream, g_nccl_comm, g_world_size, g_world_rank, g_local_size, g_local_rank) { 10 | } 11 | 12 | void PipeComm::all_to_all(const torch::Tensor &input, const torch::Tensor &output, size_t length) { 13 | length = length / g_world_size; 14 | CHECK_EQ(0, ncclGroupStart()); 15 | for (int i = 0; i < g_world_size; ++i) { 16 | bool is_intra = (g_world_rank / g_local_size) == (i / g_local_size); 17 | CHECK_EQ(0, ncclSend(((char *)input.data_ptr()) + i * length, 18 | length, 19 | ncclInt8, 20 | i, 21 | g_nccl_comm[is_intra], 22 | stream->at(is_intra).stream())); 23 | CHECK_EQ(0, ncclRecv(((char *)output.data_ptr()) + i * length, 24 | length, 25 | ncclInt8, 26 | i, 27 | g_nccl_comm[is_intra], 28 | stream->at(is_intra).stream())); 29 | } 30 | CHECK_EQ(0, ncclGroupEnd()); 31 | } 32 | -------------------------------------------------------------------------------- /PlanMoE/custom/comm/pipe.h: -------------------------------------------------------------------------------- 1 | #ifndef PIPE_COMM 2 | #define PIPE_COMM 3 | 4 | #include "abstract.h" 5 | 6 | #undef CHECK_EQ 7 | 8 | #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.") 9 | 10 | 11 | class PipeComm : public AbstractComm { 12 | public: 13 | // Declare all public members here 14 | void all_to_all(const torch::Tensor &, const torch::Tensor &, size_t); 15 | PipeComm(std::vector *, std::vector, const int &, const int &, const int &, const int &); 16 | }; 17 | 18 | #endif // PIPE_COMM 19 | -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/abstract.cpp: -------------------------------------------------------------------------------- 1 | #include "abstract.h" 2 | 3 | AbstractCompressor::AbstractCompressor(std::shared_ptr comm_ptr) { 4 | this->comm_ptr = comm_ptr; 5 | } 6 | 7 | AbstractCompressor::~AbstractCompressor() { 8 | comm_ptr.reset(); 9 | } 10 | 11 | void 
AbstractCompressor::pre_comm(const at::cuda::CUDAStream *cal_stream) { 12 | } -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/abstract.h: -------------------------------------------------------------------------------- 1 | #ifndef ABSTRACT_COMPRESSOR 2 | #define ABSTRACT_COMPRESSOR 3 | 4 | #include "../comm/abstract.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | class AbstractCompressor { 12 | public: 13 | virtual torch::Tensor compress(const torch::Tensor &) = 0; 14 | virtual torch::Tensor decompress(const torch::Tensor &) = 0; 15 | virtual void all_to_all(const torch::Tensor &, const torch::Tensor &) = 0; 16 | virtual void pre_comm(const at::cuda::CUDAStream *); 17 | AbstractCompressor(std::shared_ptr); 18 | virtual ~AbstractCompressor(); 19 | 20 | torch::Tensor g_output; 21 | std::shared_ptr comm_ptr; 22 | }; 23 | 24 | #endif // ABSTRACT_COMPRESSOR -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/gpulz.h: -------------------------------------------------------------------------------- 1 | #ifndef GPU_LZ 2 | #define GPU_LZ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | std::vector lz_compress(torch::Tensor input, cudaStream_t stream); 10 | void lz_decompress( 11 | torch::Tensor output, 12 | int numOfBlocks, 13 | torch::Tensor flagArrOffsetGlobal, 14 | torch::Tensor compressedDataOffsetGlobal, 15 | torch::Tensor flagArrGlobal, 16 | torch::Tensor compressedDataGlobal, 17 | cudaStream_t stream); 18 | #endif // GPU_LZ -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/int8.cpp: -------------------------------------------------------------------------------- 1 | #include "int8.h" 2 | 3 | Int8Compressor::Int8Compressor(std::shared_ptr comm_ptr) : 4 | AbstractCompressor(comm_ptr) { 5 | } 6 | 7 | torch::Tensor Int8Compressor::compress(const torch::Tensor &input) { 8 | sizes = input.sizes().vec(); 9 | dtype = input.dtype(); 10 | bias = std::get<0>(torch::min(input, -1, true)); 11 | scale = std::get<0>(torch::max(input, -1, true)); 12 | scale = at::sub(scale, bias); 13 | torch::Tensor output = at::sub(input, bias); 14 | output = at::div(output, scale); 15 | output = at::mul(output, 255.0); 16 | output = output.to(torch::kUInt8); 17 | length = output.nbytes(); 18 | 19 | torch::Tensor fp_output = torch::empty( 20 | {(length + 3) / 4}, 21 | torch::TensorOptions().device(input.device()).dtype(at::kFloat), 22 | torch::MemoryFormat::Contiguous); 23 | cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream().stream(); 24 | cudaMemcpyAsync(fp_output.data_ptr(), output.data_ptr(), length, cudaMemcpyDeviceToDevice, cuda_stream); 25 | for (auto &nccl_stream : *comm_ptr->stream) { 26 | c10::cuda::CUDACachingAllocator::recordStream(bias.storage().data_ptr(), nccl_stream); 27 | } 28 | for (auto &nccl_stream : *comm_ptr->stream) { 29 | c10::cuda::CUDACachingAllocator::recordStream(scale.storage().data_ptr(), nccl_stream); 30 | } 31 | g_bias = at::empty_like(bias); 32 | g_scale = at::empty_like(scale); 33 | for (auto &nccl_stream : *comm_ptr->stream) { 34 | c10::cuda::CUDACachingAllocator::recordStream(g_bias.storage().data_ptr(), nccl_stream); 35 | } 36 | for (auto &nccl_stream : *comm_ptr->stream) { 37 | c10::cuda::CUDACachingAllocator::recordStream(g_scale.storage().data_ptr(), nccl_stream); 38 | } 39 | this->g_output = at::empty_like(fp_output); 40 | for (auto 
&nccl_stream : *comm_ptr->stream) { 41 | c10::cuda::CUDACachingAllocator::recordStream(g_output.storage().data_ptr(), nccl_stream); 42 | } 43 | return fp_output; 44 | } 45 | 46 | torch::Tensor Int8Compressor::decompress(const torch::Tensor &input) { 47 | torch::Tensor output = torch::empty( 48 | sizes, 49 | torch::TensorOptions().device(input.device()).dtype(torch::kUInt8), 50 | torch::MemoryFormat::Contiguous); 51 | // std::cout << g_bias; 52 | cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream().stream(); 53 | cudaMemcpyAsync(output.data_ptr(), input.data_ptr(), length, cudaMemcpyDeviceToDevice, cuda_stream); 54 | 55 | output = output.to(dtype); 56 | output = at::div(output, 255.0); 57 | output = at::mul(output, g_scale); 58 | output = at::add(output, g_bias); 59 | return output; 60 | } 61 | 62 | void Int8Compressor::pre_comm(const at::cuda::CUDAStream *cal_stream) { 63 | // g_bias = at::empty_like(bias); 64 | // g_scale = at::empty_like(scale); 65 | // c10::cuda::CUDACachingAllocator::recordStream(g_scale.storage().data_ptr(), *cal_stream); 66 | // c10::cuda::CUDACachingAllocator::recordStream(g_bias.storage().data_ptr(), *cal_stream); 67 | } 68 | 69 | void Int8Compressor::all_to_all(const torch::Tensor &input, const torch::Tensor &output) { 70 | // std::cout << (output).data_ptr() << std::endl; 71 | comm_ptr->all_to_all(input, output, length); 72 | comm_ptr->all_to_all(bias, g_bias, bias.nbytes()); 73 | comm_ptr->all_to_all(scale, g_scale, scale.nbytes()); 74 | } -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/int8.h: -------------------------------------------------------------------------------- 1 | #ifndef INT8_COMPRESSOR 2 | #define INT8_COMPRESSOR 3 | 4 | #include "abstract.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | class Int8Compressor : public AbstractCompressor { 14 | public: 15 | // Declare all public members here 16 | torch::Tensor compress(const torch::Tensor &); 17 | torch::Tensor decompress(const torch::Tensor &); 18 | void all_to_all(const torch::Tensor &, const torch::Tensor &); 19 | void pre_comm(const at::cuda::CUDAStream *); 20 | Int8Compressor(std::shared_ptr); 21 | ~Int8Compressor() = default; 22 | 23 | private: 24 | torch::Tensor bias, scale, g_bias, g_scale; 25 | caffe2::TypeMeta dtype; 26 | size_t length; 27 | std::vector sizes; 28 | }; 29 | 30 | #endif // INT8_COMPRESSOR -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/lz.cpp: -------------------------------------------------------------------------------- 1 | #include "lz.h" 2 | 3 | torch::Tensor LzCompressor::compress(const torch::Tensor &input) { 4 | std::vector input_list = at::split(input, input.size(0) / comm_ptr->g_world_size, 0); 5 | std::vector flagArrOffsetGlobalList; 6 | std::vector compressedDataOffsetGlobalList; 7 | std::vector flagArrGlobalList; 8 | std::vector compressedDataGlobalList; 9 | for (auto &slice : input_list) { 10 | std::vector compress_ret = lz_compress(slice, at::cuda::getCurrentCUDAStream()); 11 | // placeholderList.push_back(compress_ret[0]); 12 | flagArrOffsetGlobalList.push_back(compress_ret[0]); 13 | compressedDataOffsetGlobalList.push_back(compress_ret[1]); 14 | flagArrGlobalList.push_back(compress_ret[2]); 15 | compressedDataGlobalList.push_back(compress_ret[3]); 16 | } 17 | flagArrOffsetGlobal = torch::cat(flagArrOffsetGlobalList, 0).contiguous(); 18 | compressedDataOffsetGlobal = 
torch::cat(compressedDataOffsetGlobalList, 0).contiguous(); 19 | flagArrGlobal = torch::cat(flagArrGlobalList, 0).contiguous(); 20 | compressedDataGlobal = torch::cat(compressedDataGlobalList, 0).contiguous(); 21 | g_flagArrOffsetGlobal = at::empty_like(flagArrOffsetGlobal); 22 | g_compressedDataOffsetGlobal = at::empty_like(compressedDataOffsetGlobal); 23 | g_flagArrGlobal = at::empty_like(flagArrGlobal); 24 | g_compressedDataGlobal = at::empty_like(compressedDataGlobal); 25 | g_output = at::zeros_like(input); 26 | 27 | for (auto &nccl_stream : *comm_ptr->stream) { 28 | c10::cuda::CUDACachingAllocator::recordStream(flagArrOffsetGlobal.storage().data_ptr(), nccl_stream); 29 | c10::cuda::CUDACachingAllocator::recordStream(compressedDataOffsetGlobal.storage().data_ptr(), nccl_stream); 30 | c10::cuda::CUDACachingAllocator::recordStream(flagArrGlobal.storage().data_ptr(), nccl_stream); 31 | c10::cuda::CUDACachingAllocator::recordStream(compressedDataGlobal.storage().data_ptr(), nccl_stream); 32 | c10::cuda::CUDACachingAllocator::recordStream(g_flagArrOffsetGlobal.storage().data_ptr(), nccl_stream); 33 | c10::cuda::CUDACachingAllocator::recordStream(g_compressedDataOffsetGlobal.storage().data_ptr(), nccl_stream); 34 | c10::cuda::CUDACachingAllocator::recordStream(g_flagArrGlobal.storage().data_ptr(), nccl_stream); 35 | c10::cuda::CUDACachingAllocator::recordStream(g_compressedDataGlobal.storage().data_ptr(), nccl_stream); 36 | } 37 | 38 | // for (auto &nccl_stream : *comm_ptr->stream) { 39 | // c10::cuda::CUDACachingAllocator::recordStream(placeholder.storage().data_ptr(), nccl_stream); 40 | // } 41 | return input; 42 | } 43 | torch::Tensor LzCompressor::decompress(const torch::Tensor &input) { 44 | int g_world_size = comm_ptr->g_world_size; 45 | std::vector input_list = at::split(input, input.size(0) / g_world_size, 0); 46 | std::vector flagArrOffsetGlobalList = at::split(g_flagArrOffsetGlobal, g_flagArrOffsetGlobal.size(0) / g_world_size, 0); 47 | std::vector compressedDataOffsetGlobalList = at::split(g_compressedDataOffsetGlobal, g_compressedDataOffsetGlobal.size(0) / g_world_size, 0); 48 | std::vector flagArrGlobalList = at::split(g_flagArrGlobal, g_flagArrGlobal.size(0) / g_world_size, 0); 49 | std::vector compressedDataGlobalList = at::split(g_compressedDataGlobal, g_compressedDataGlobal.size(0) / g_world_size, 0); 50 | int numOfBlocks = flagArrOffsetGlobalList[0].size(0) - 1; 51 | // std::cout << numOfBlocks << std::endl; 52 | for (int i = 0; i < g_world_size; ++i) { 53 | lz_decompress(input_list[i], 54 | numOfBlocks, 55 | flagArrOffsetGlobalList[i], 56 | compressedDataOffsetGlobalList[i], 57 | flagArrGlobalList[i], 58 | compressedDataGlobalList[i], 59 | at::cuda::getCurrentCUDAStream()); 60 | } 61 | return input; 62 | } 63 | 64 | LzCompressor::LzCompressor(std::shared_ptr comm_ptr) : 65 | AbstractCompressor(comm_ptr) { 66 | } 67 | 68 | void LzCompressor::all_to_all(const torch::Tensor &input, const torch::Tensor &output) { 69 | // comm_ptr->all_to_all(input, output, input.nbytes()); 70 | comm_ptr->all_to_all(flagArrOffsetGlobal, g_flagArrOffsetGlobal, flagArrOffsetGlobal.nbytes()); 71 | comm_ptr->all_to_all(compressedDataOffsetGlobal, g_compressedDataOffsetGlobal, compressedDataOffsetGlobal.nbytes()); 72 | comm_ptr->all_to_all(flagArrGlobal, g_flagArrGlobal, flagArrGlobal.nbytes()); 73 | comm_ptr->all_to_all(compressedDataGlobal, g_compressedDataGlobal, compressedDataGlobal.nbytes()); 74 | } -------------------------------------------------------------------------------- 
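Note that only the four metadata/payload tensors built above travel through the all-to-all; the receiving side reconstructs the activations by decompressing in place into g_output. A minimal sketch of how the compression pipeline exported by custom_kernel.cpp is meant to be driven from Python follows; the extension module name and the wrapper function are illustrative assumptions, not code from this repository (the real wrappers live under PlanMoE/impls/).

import torch
import planmoe_custom as ops  # assumed name of the built pybind11 extension

def compressed_all_to_all(x: torch.Tensor,
                          compressor: str = "int8",  # "int8", "zfp", "lz"; anything else -> no-op codec
                          comm: str = "naive") -> torch.Tensor:  # "dd", "pipe", "hetu"; else naive backend
    packed = ops.compress_operation(x, compressor, comm)  # encode on the compute stream, record event
    received = ops.comm_operation(packed)                 # all-to-all on the NCCL stream(s)
    return ops.decompress_operation(received)             # wait for the transfer, then decode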
/PlanMoE/custom/compressor/lz.h: -------------------------------------------------------------------------------- 1 | #ifndef LZ 2 | #define LZ 3 | #include "abstract.h" 4 | #include "gpulz.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | class LzCompressor : public AbstractCompressor { 14 | public: 15 | // Declare all public members here 16 | torch::Tensor compress(const torch::Tensor &); 17 | torch::Tensor decompress(const torch::Tensor &); 18 | void all_to_all(const torch::Tensor &, const torch::Tensor &); 19 | LzCompressor(std::shared_ptr comm_ptr); 20 | 21 | private: 22 | torch::Tensor flagArrOffsetGlobal, compressedDataOffsetGlobal, flagArrGlobal, compressedDataGlobal; 23 | torch::Tensor g_flagArrOffsetGlobal, g_compressedDataOffsetGlobal, g_flagArrGlobal, g_compressedDataGlobal; 24 | }; 25 | #endif // LZ -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/no.cpp: -------------------------------------------------------------------------------- 1 | #include "no.h" 2 | 3 | torch::Tensor NoCompressor::compress(const torch::Tensor &input) { 4 | g_output = at::empty_like(input); 5 | for (auto &nccl_stream : *comm_ptr->stream) { 6 | c10::cuda::CUDACachingAllocator::recordStream(g_output.storage().data_ptr(), nccl_stream); 7 | } 8 | return input; 9 | } 10 | torch::Tensor NoCompressor::decompress(const torch::Tensor &input) { 11 | return input; 12 | } 13 | 14 | NoCompressor::NoCompressor(std::shared_ptr comm_ptr) : 15 | AbstractCompressor(comm_ptr) { 16 | } 17 | 18 | void NoCompressor::all_to_all(const torch::Tensor &input, const torch::Tensor &output) { 19 | comm_ptr->all_to_all(input, output, input.nbytes()); 20 | } -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/no.h: -------------------------------------------------------------------------------- 1 | #ifndef NO_COMPRESSOR 2 | #define NO_COMPRESSOR 3 | 4 | #include "abstract.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | class NoCompressor : public AbstractCompressor { 13 | public: 14 | // Declare all public members here 15 | torch::Tensor compress(const torch::Tensor &); 16 | torch::Tensor decompress(const torch::Tensor &); 17 | void all_to_all(const torch::Tensor &, const torch::Tensor &); 18 | NoCompressor(std::shared_ptr comm_ptr); 19 | }; 20 | 21 | #endif // NO_COMPRESSOR -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/zfpc.cpp: -------------------------------------------------------------------------------- 1 | #include "zfpc.h" 2 | 3 | ZfpCompressor::ZfpCompressor(std::shared_ptr comm_ptr, 4 | double compress_rate) noexcept 5 | : 6 | AbstractCompressor(comm_ptr) { 7 | this->compress_rate = compress_rate; 8 | this->buffer = nullptr; 9 | this->last_bufsize = 0; 10 | } 11 | 12 | ZfpCompressor::~ZfpCompressor() { 13 | if (buffer) { 14 | // cudaFreeAsync(buffer, at::cuda::getCurrentCUDAStream().stream()); 15 | buffer = nullptr; 16 | } 17 | }; 18 | 19 | void ZfpCompressor::set_compress_rate(const double &compress_rate) { 20 | this->compress_rate = compress_rate; 21 | } 22 | 23 | torch::Tensor ZfpCompressor::compress(const torch::Tensor &input) { 24 | sizes = input.sizes().vec(); 25 | // std::cout << sizes[0] << " " << sizes[1] << " " << sizes[2] << std::endl; 26 | AT_ASSERTM(sizes[0] % 4 == 0, "zfp fails."); 27 | AT_ASSERTM(sizes[1] % 4 == 0, "zfp fails."); 28 | 
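// zfp encodes fixed-rate blocks of 4x4 values, so every dimension folded into
// the 2D field below must be a multiple of 4; refuse to compress rather than
// silently pad.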
AT_ASSERTM(sizes[2] % 4 == 0, "zfp fails."); 29 | // std::cout << sizes[0] % 4 << " " << sizes[1] % 4 << " " << sizes[2] % 4 30 | // << std::endl; 31 | zfp_type type = zfp_type_float; /* array scalar type */ 32 | zfp_field *field = zfp_field_2d((char *)input.data_ptr(), type, sizes[2], 33 | sizes[1] * sizes[0]); 34 | cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream().stream(); 35 | field->cuda_stream = (void *)&cuda_stream; 36 | zfp_stream *zfp = zfp_stream_open(NULL); 37 | zfp_stream_set_rate(zfp, compress_rate, type, 2, 0); 38 | zfp_stream_set_execution(zfp, zfp_exec_cuda); 39 | size_t bufsize = zfp_stream_maximum_size(zfp, field); 40 | 41 | torch::Tensor output = torch::empty( 42 | {(bufsize + 1) / 2}, 43 | torch::TensorOptions().device(input.device()).dtype(at::kHalf), 44 | torch::MemoryFormat::Contiguous); 45 | 46 | bitstream *stream = stream_open(output.data_ptr(), bufsize); 47 | zfp_stream_set_bit_stream(zfp, stream); 48 | length = zfp_compress(zfp, field); 49 | // std::cout << length << std::endl; 50 | zfp_field_free(field); 51 | zfp_stream_close(zfp); 52 | 53 | g_output = torch::empty_like(output); 54 | for (auto &nccl_stream : *comm_ptr->stream) { 55 | c10::cuda::CUDACachingAllocator::recordStream( 56 | g_output.storage().data_ptr(), nccl_stream); 57 | } 58 | /*output = at::reshape(output, {2, 3, 16}); 59 | output = at::permute(output, {1, 0, 2}); 60 | output = output.contiguous();*/ 61 | 62 | return output; 63 | } 64 | 65 | torch::Tensor ZfpCompressor::decompress(const torch::Tensor &input) { 66 | zfp_type type = zfp_type_float; /* array scalar type */ 67 | // std::cout << sizes; 68 | torch::Tensor output = torch::empty( 69 | sizes, torch::TensorOptions().device(input.device()).dtype(at::kFloat), 70 | torch::MemoryFormat::Contiguous); 71 | // std::cout << sizes; 72 | zfp_field *field = zfp_field_2d((char *)output.data_ptr(), type, sizes[2], 73 | sizes[1] * sizes[0]); 74 | cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream().stream(); 75 | field->cuda_stream = (void *)&cuda_stream; 76 | zfp_stream *zfp = zfp_stream_open(NULL); 77 | zfp_stream_set_rate(zfp, compress_rate, type, 2, 0); 78 | zfp_stream_set_execution(zfp, zfp_exec_cuda); 79 | size_t bufsize = zfp_stream_maximum_size(zfp, field); 80 | bitstream *stream = stream_open(input.data_ptr(), bufsize); 81 | zfp_stream_set_bit_stream(zfp, stream); 82 | zfp_decompress(zfp, field); 83 | zfp_field_free(field); 84 | zfp_stream_close(zfp); 85 | return output; 86 | } 87 | 88 | void ZfpCompressor::all_to_all(const torch::Tensor &input, 89 | const torch::Tensor &output) { 90 | // cudaMemcpyAsync(output.data_ptr(), input.data_ptr(), input.nbytes(), 91 | // cudaMemcpyDeviceToDevice, stream->stream()); 92 | 93 | // std::cout <all_to_all(input, output, length); 95 | } -------------------------------------------------------------------------------- /PlanMoE/custom/compressor/zfpc.h: -------------------------------------------------------------------------------- 1 | #ifndef ZFP_COMPRESSOR 2 | #define ZFP_COMPRESSOR 3 | 4 | #include "abstract.h" 5 | #include "assert.h" 6 | #include "zfp.h" 7 | #include 8 | 9 | class ZfpCompressor : public AbstractCompressor { 10 | public: 11 | // Declare all public members here 12 | torch::Tensor compress(const torch::Tensor &); 13 | torch::Tensor decompress(const torch::Tensor &); 14 | void all_to_all(const torch::Tensor &, const torch::Tensor &); 15 | void set_cuda_stream(const cudaStream_t &); 16 | void set_compress_rate(const double &); 17 | ZfpCompressor(std::shared_ptr, double = 8.0) 
noexcept; 18 | ~ZfpCompressor(); 19 | 20 | public: 21 | double compress_rate; 22 | void *buffer; 23 | size_t last_bufsize; 24 | std::vector sizes; 25 | size_t length; 26 | }; 27 | 28 | #endif // ZFP_COMPRESSOR -------------------------------------------------------------------------------- /PlanMoE/custom/custom_kernel.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | 6 | #include "zfp.h" 7 | 8 | #if defined(USE_GPU) 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "comm/hetu.h" 18 | #include "comm/naive.h" 19 | #include "comm/pipe.h" 20 | #include "compressor/abstract.h" 21 | #include "compressor/int8.h" 22 | #include "compressor/lz.h" 23 | #include "compressor/no.h" 24 | #include "compressor/zfpc.h" 25 | #include "dd_comm.h" 26 | #include "jit.h" 27 | #else 28 | #undef USE_NCCL 29 | #endif 30 | 31 | #if defined(USE_NCCL) 32 | #include 33 | #endif 34 | 35 | #include 36 | #include 37 | 38 | #if defined(__linux__) 39 | #include 40 | #endif 41 | 42 | #undef CHECK_EQ 43 | #undef CHECK_NE 44 | #undef CHECK_LE 45 | #undef CHECK_CPU 46 | #undef CHECK_CUDA 47 | #undef CHECK_CONTIGUOUS 48 | 49 | #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.") 50 | #define CHECK_NE(x, y) AT_ASSERTM((x) != (y), "CHECK_NE fails.") 51 | #define CHECK_LE(x, y) AT_ASSERTM((x) <= (y), "CHECK_LE fails.") 52 | #define CHECK_CPU(x) AT_ASSERTM(!x.is_cuda(), #x " must be a CPU tensor") 53 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 54 | #define CHECK_CONTIGUOUS(x) \ 55 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 56 | 57 | template 58 | static void invoke_cpu(const std::vector &ts, 59 | const std::vector &extra, int kernel_type) { 60 | int samples = extra[0]; 61 | int hidden = extra[1]; 62 | int capacity = extra[2]; 63 | dtype *gates1_s = static_cast(ts[0].data_ptr()); 64 | int *indices1_s = static_cast(ts[1].data_ptr()); 65 | int *locations1_s = static_cast(ts[2].data_ptr()); 66 | dtype *reshaped_input = static_cast(ts[3].data_ptr()); 67 | dtype *dispatched_input = static_cast(ts[4].data_ptr()); 68 | 69 | for (int i = 0; i < (int)ts.size(); ++i) 70 | CHECK_CONTIGUOUS(ts[i]); 71 | 72 | if (kernel_type == 0) { // forward 73 | for (int i = 0; i < samples; ++i) { 74 | if (locations1_s[i] < capacity && indices1_s[i] >= 0) { 75 | for (int j = 0; j < hidden; ++j) { 76 | dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * 77 | (hidden) + 78 | j] += gates1_s[i] * reshaped_input[i * (hidden) + j]; 79 | } 80 | } 81 | } 82 | } else if (kernel_type == 1) { // backward_data 83 | for (int i = 0; i < samples; ++i) { 84 | if (locations1_s[i] < capacity && indices1_s[i] >= 0) { 85 | for (int j = 0; j < hidden; ++j) { 86 | reshaped_input[i * hidden + j] = 87 | gates1_s[i] * 88 | dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * 89 | (hidden) + 90 | j]; 91 | } 92 | } else { 93 | for (int j = 0; j < hidden; ++j) { 94 | reshaped_input[i * hidden + j] = 0; 95 | } 96 | } 97 | } 98 | } else { // backward_gate 99 | for (int i = 0; i < samples; ++i) { 100 | gates1_s[i] = 0; 101 | if (locations1_s[i] >= capacity || indices1_s[i] < 0) 102 | continue; 103 | for (int j = 0; j < hidden; ++j) { 104 | gates1_s[i] += 105 | dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * 106 | (hidden) + 107 | j] * 108 | reshaped_input[i * hidden + j]; 109 | } 110 | } 
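// CPU reference path: kernel_type 0 scatters each token's hidden vector into
// its (expert, capacity-slot) row weighted by its gate value, kernel_type 1
// is the gradient w.r.t. the tokens (gather + scale, zero for dropped tokens),
// and kernel_type 2 (this branch) is the gradient w.r.t. the gates, i.e. a
// per-token dot product between the dispatched gradient and the input.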
111 | } 112 | } 113 | 114 | #if defined(USE_NCCL) 115 | 116 | static std::vector g_nccl_comm; 117 | static std::vector> g_cuda_events; 118 | static std::vector g_nccl_stream; 119 | static int g_world_size = 0; 120 | static int g_world_rank = 0; 121 | static int g_local_size = 0; 122 | static int g_local_rank = 0; 123 | 124 | // jit 125 | extern int mem_stride_copy_char_fd; 126 | extern int mem_stride_copy_uint4_fd; 127 | extern int mem_stride_copy_gridsize; 128 | extern int mem_stride_copy_blocksize; 129 | 130 | static size_t get_nccl_unique_id_size() { return sizeof(ncclUniqueId); } 131 | 132 | static void get_nccl_unique_id(torch::Tensor &nccl_unique_id_tensor) { 133 | ncclUniqueId nccl_unique_id; 134 | int num_stream = nccl_unique_id_tensor.size(0); 135 | for (int i = 0; i < num_stream; ++i) { 136 | CHECK_EQ(0, ncclGetUniqueId(&nccl_unique_id)); 137 | // CHECK_CPU(nccl_unique_id_tensor); 138 | // CHECK_EQ(nccl_unique_id_tensor.nbytes(), sizeof(ncclUniqueId)); 139 | memcpy((void *)nccl_unique_id_tensor.data_ptr() + i * sizeof(ncclUniqueId), 140 | &nccl_unique_id, sizeof(ncclUniqueId)); 141 | } 142 | } 143 | 144 | static void init_nccl(const torch::Tensor &nccl_unique_id_tensor, 145 | int world_size, int world_rank, int max_num_split) { 146 | int num_stream = nccl_unique_id_tensor.size(0); 147 | ncclUniqueId nccl_unique_id; 148 | g_nccl_comm.resize(num_stream); 149 | g_cuda_events.resize(num_stream); 150 | CHECK_CPU(nccl_unique_id_tensor); 151 | CHECK_EQ(nccl_unique_id_tensor.nbytes(), num_stream * sizeof(ncclUniqueId)); 152 | for (int i = 0; i < num_stream; ++i) { 153 | memcpy(&nccl_unique_id, 154 | ((void *)nccl_unique_id_tensor.data_ptr()) + 155 | i * sizeof(ncclUniqueId), 156 | sizeof(ncclUniqueId)); 157 | CHECK_EQ(0, ncclGroupStart()); 158 | CHECK_EQ(0, ncclCommInitRank(&g_nccl_comm[i], world_size, nccl_unique_id, 159 | world_rank)); 160 | CHECK_EQ(0, ncclGroupEnd()); 161 | g_nccl_stream.emplace_back(at::cuda::getStreamFromPool()); 162 | g_cuda_events[i].resize(max_num_split); 163 | } 164 | 165 | g_world_size = world_size; 166 | g_world_rank = world_rank; 167 | 168 | if (const char *local_size = std::getenv("LOCAL_SIZE")) { 169 | g_local_size = std::atoi(local_size); 170 | } else { 171 | CHECK_EQ(0, cudaGetDeviceCount(&g_local_size)); 172 | } 173 | CHECK_EQ(0, ncclCommCuDevice(g_nccl_comm[0], &g_local_rank)); 174 | // jit for nccl 175 | jit::jit_init(g_local_rank); 176 | } 177 | 178 | static torch::Tensor ¤t_stream_release(torch::Tensor &tensor, int idx) { 179 | return tensor; 180 | } 181 | 182 | static torch::Tensor ¤t_stream_acquire(torch::Tensor &tensor, int idx) { 183 | return tensor; 184 | } 185 | 186 | static torch::Tensor &nccl_stream_release(torch::Tensor &tensor, int idx) { 187 | return tensor; 188 | } 189 | 190 | static torch::Tensor &nccl_stream_acquire(torch::Tensor &tensor, int idx) { 191 | return tensor; 192 | } 193 | 194 | static AbstractCompressor * 195 | get_compressor(const std::string &name, 196 | std::shared_ptr comm_ptr) { 197 | if (name == "int8") { 198 | return new Int8Compressor(comm_ptr); 199 | } 200 | if (name == "zfp") { 201 | return new ZfpCompressor(comm_ptr); 202 | } 203 | if (name == "lz") { 204 | return new LzCompressor(comm_ptr); 205 | } 206 | return new NoCompressor(comm_ptr); 207 | } 208 | 209 | static std::shared_ptr get_comm(const std::string &name) { 210 | if (name == "dd") { 211 | return std::make_shared(&g_nccl_stream, g_nccl_comm, g_world_size, 212 | g_world_rank, g_local_size, g_local_rank); 213 | } 214 | if (name == "pipe") { 215 | return 
std::make_shared(&g_nccl_stream, g_nccl_comm, g_world_size, 216 | g_world_rank, g_local_size, g_local_rank); 217 | } 218 | if (name == "hetu") { 219 | return std::make_shared(&g_nccl_stream, g_nccl_comm, g_world_size, 220 | g_world_rank, g_local_size, g_local_rank); 221 | } 222 | return std::make_shared(&g_nccl_stream, g_nccl_comm, g_world_size, 223 | g_world_rank, g_local_size, g_local_rank); 224 | } 225 | 226 | static std::vector compress_ptr_lst; 227 | static int comm_cnt = 0; 228 | static int decompress_cnt = 0; 229 | 230 | static void clear_ptr_lst() { 231 | for (auto &compress_ptr : compress_ptr_lst) { 232 | if (compress_ptr) { 233 | delete compress_ptr; 234 | compress_ptr = nullptr; 235 | } 236 | } 237 | compress_ptr_lst.resize(0); 238 | comm_cnt = 0; 239 | decompress_cnt = 0; 240 | } 241 | 242 | static torch::Tensor compress_operation(const torch::Tensor &input, 243 | const std::string &str, 244 | const std::string &comm_name) { 245 | size_t idx = compress_ptr_lst.size(); 246 | std::shared_ptr comm_ptr = get_comm(comm_name); 247 | // std::make_shared(&g_nccl_stream, g_nccl_comm, g_world_size, 248 | // g_world_rank, g_local_size, g_local_rank); 249 | compress_ptr_lst.emplace_back(get_compressor(str, comm_ptr)); 250 | torch::Tensor after_compress = compress_ptr_lst.back()->compress(input); 251 | for (auto &events : g_cuda_events) { 252 | events[idx].record(at::cuda::getCurrentCUDAStream()); 253 | } 254 | for (auto &nccl_stream : g_nccl_stream) { 255 | c10::cuda::CUDACachingAllocator::recordStream( 256 | after_compress.storage().data_ptr(), nccl_stream); 257 | } 258 | return after_compress; 259 | } 260 | 261 | static torch::Tensor comm_operation(const torch::Tensor &input) { 262 | const int idx = comm_cnt++; 263 | const at::cuda::CUDAStream &original_stream = 264 | at::cuda::getCurrentCUDAStream(); 265 | // std::cout << "????" << std::endl; 266 | compress_ptr_lst[idx]->pre_comm(&original_stream); 267 | 268 | // std::cout << "!!!!" 
<< std::endl; 269 | for (int i = 0; i < g_nccl_stream.size(); ++i) { 270 | g_cuda_events[i][idx].block(g_nccl_stream[i]); 271 | } 272 | compress_ptr_lst[idx]->all_to_all(input, compress_ptr_lst[idx]->g_output); 273 | for (int i = 0; i < g_nccl_stream.size(); ++i) { 274 | g_cuda_events[i][idx].record(g_nccl_stream[i]); 275 | } 276 | return compress_ptr_lst[idx]->g_output; 277 | } 278 | 279 | static torch::Tensor decompress_operation(const torch::Tensor &input) { 280 | const int idx = decompress_cnt++; 281 | for (auto &event : g_cuda_events) { 282 | event[idx].block(at::cuda::getCurrentCUDAStream()); 283 | } 284 | torch::Tensor output = compress_ptr_lst[idx]->decompress(input); 285 | delete compress_ptr_lst[idx]; 286 | compress_ptr_lst[idx] = nullptr; 287 | return output; 288 | } 289 | 290 | #endif 291 | 292 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 293 | #if defined(USE_GPU) 294 | 295 | m.def("update_sdk_home", &jit::update_sdk_home, 296 | "Configure SDK HOME Path for GPU (CUDA)"); 297 | m.def("invoke", &jit::invoke, "Generic Invoke for GPU (CUDA)"); 298 | m.def("inject_source", &jit::inject_source, "Inject Source for GPU (CUDA)"); 299 | #endif 300 | m.def("invoke_cpu_fp32", &invoke_cpu, "Invoke for Sparse Ops (CPU)"); 301 | m.def("invoke_cpu_fp64", &invoke_cpu, "Invoke for Sparse Ops (CPU)"); 302 | #if defined(USE_NCCL) 303 | m.def("get_nccl_unique_id_size", &get_nccl_unique_id_size, 304 | "Get size of ncclUniqueId in bytes"); 305 | m.def("get_nccl_unique_id", &get_nccl_unique_id, 306 | "Get ncclUniqueId for NCCL initialization"); 307 | m.def("init_nccl", &init_nccl, "NCCL initialization"); 308 | m.def("current_stream_release", ¤t_stream_release, 309 | "Record CUDA event on current stream to i-th event slot"); 310 | m.def("current_stream_acquire", ¤t_stream_acquire, 311 | "Let current stream wait CUDA event in i-th event slot"); 312 | m.def("nccl_stream_release", &nccl_stream_release, 313 | "Record CUDA event on NCCL stream to i-th event slot"); 314 | m.def("nccl_stream_acquire", &nccl_stream_acquire, 315 | "Let NCCL stream wait CUDA event in i-th event slot"); 316 | m.def("compress_operation", &compress_operation, "Compress Operation"); 317 | m.def("comm_operation", &comm_operation, "Comm Operation"); 318 | m.def("decompress_operation", &decompress_operation, "Decompress Operation"); 319 | m.def("clear_ptr_lst", &clear_ptr_lst, "Clear Ptr Lst"); 320 | #endif 321 | } 322 | 323 | #if defined(USE_GPU) 324 | #include 325 | #define DEFINE_KERNEL(x, y) \ 326 | static int x = -1; \ 327 | if (x == -1) { \ 328 | x = y; \ 329 | } 330 | 331 | torch::Tensor warp_cumsum(torch::Tensor x) { 332 | CHECK_CUDA(x); 333 | CHECK_EQ(x.dim(), 2); 334 | x = x.to(torch::kInt32).contiguous(); 335 | 336 | auto y = torch::empty_like(x); 337 | 338 | DEFINE_KERNEL(cumsum_fn, jit::inject_source(R"( 339 | extern "C" __global__ void cumsum_fn(int* input0 /* (num_samples, batch_num) */, int* output0 /* (num_samples, batch_num) */, int num_samples) { 340 | #define thread_num 1024 341 | #define batch_num ((int)gridDim.x) 342 | 343 | __shared__ int temp[thread_num + 1]; 344 | int thid = threadIdx.x, bid = blockIdx.x; 345 | int last_sum = -1; 346 | 347 | for (int S = 0; S < num_samples; S += thread_num, output0 += thread_num * batch_num, input0 += thread_num * batch_num) { 348 | int offset = 1; 349 | if (S + thid < num_samples) 350 | temp[thid] = input0[thid * batch_num + bid]; 351 | for (int d = thread_num >> 1; d > 0; d >>= 1) { 352 | __syncthreads(); 353 | if (thid < d) 354 | temp[offset * (2 * thid + 2) - 1] += 
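// up-sweep (reduce) half of a shared-memory Blelloch scan; with the
// down-sweep below it yields a column-wise prefix sum over the num_samples
// rows (one block per column), last_sum carrying the total across chunks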
temp[offset * (2 * thid + 1) - 1]; 355 | offset *= 2; 356 | } 357 | if (thid == 0) 358 | temp[thread_num] = temp[thread_num - 1], temp[thread_num - 1] = 0; 359 | for (int d = 1; d < thread_num; d *= 2) { 360 | offset >>= 1; 361 | __syncthreads(); 362 | if (thid < d) { 363 | int ai = offset * (2 * thid + 1) - 1; 364 | int bi = offset * (2 * thid + 2) - 1; 365 | int t = temp[ai]; 366 | temp[ai] = temp[bi]; 367 | temp[bi] += t; 368 | } 369 | } 370 | __syncthreads(); 371 | if (S + thid < num_samples) 372 | output0[thid * batch_num + bid] = temp[thid + 1] + last_sum; 373 | __syncthreads(); 374 | last_sum += temp[thread_num]; 375 | } 376 | } 377 | )")); 378 | 379 | jit::jit_execute_with_values({x.data_ptr(), y.data_ptr(), (void *)x.size(0)}, 380 | cumsum_fn, x.device().index(), x.size(1), 1024, 381 | nullptr); 382 | return y; 383 | } 384 | 385 | TORCH_LIBRARY(tutel_ops, m) { m.def("cumsum", warp_cumsum); } 386 | #endif 387 | -------------------------------------------------------------------------------- /PlanMoE/custom/dd_comm.cpp: -------------------------------------------------------------------------------- 1 | #include "dd_comm.h" 2 | 3 | DdComm::DdComm(std::vector *stream, 4 | std::vector g_nccl_comm, 5 | const int &g_world_size, const int &g_world_rank, 6 | const int &g_local_size, const int &g_local_rank) 7 | : AbstractComm(stream, g_nccl_comm, g_world_size, g_world_rank, 8 | g_local_size, g_local_rank) {} 9 | 10 | extern int jit::mem_stride_copy_char_fd; 11 | extern int jit::mem_stride_copy_uint4_fd; 12 | extern int jit::mem_stride_copy_gridsize; 13 | extern int jit::mem_stride_copy_blocksize; 14 | 15 | void DdComm::all_to_all(const torch::Tensor &input, const torch::Tensor &output, size_t length) { 16 | size_t slice_size = length / g_world_size; 17 | size_t slice_size_uint4 = slice_size / sizeof(uint4); 18 | 19 | // Save original stream and switch to NCCL stream 20 | // Output tensors must be allocated in NCCL stream context to prevent PyTorch Caching Allocator from recycling it 21 | // const at::cuda::CUDAStream &original_stream = at::cuda::getCurrentCUDAStream(); 22 | // at::cuda::setCurrentCUDAStream(get_nccl_stream()); 23 | 24 | // // Computation stream allocator will add blocking event to nccl stream after nccl kernels 25 | // c10::cuda::CUDACachingAllocator::recordStream(input.storage().data_ptr(), get_nccl_stream()); 26 | 27 | int nranks = g_world_size, ngpus = 4; 28 | CHECK_EQ(0, nranks % ngpus); 29 | int nnodes = nranks / ngpus; 30 | 31 | // torch::Tensor tmp_output = torch::empty_like(input, torch::MemoryFormat::Contiguous); 32 | void *input_buff = (void *)input.data_ptr(); 33 | void *output_buff = (void *)output.data_ptr(); 34 | 35 | if (!(ngpus == 1 || nnodes == 1)) { 36 | int node_rank = g_world_rank / ngpus, local_rank = g_local_rank; 37 | // phase 0. per-gpu (ngpus) stride copy 38 | //std::cout << jit::mem_stride_copy_char_fd << " " << jit::mem_stride_copy_gridsize << " " << jit::mem_stride_copy_blocksize << " " << jit::mem_stride_copy_uint4_fd << std::endl; 39 | slice_size < sizeof(uint4) ? 
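// sub-16-byte slices use the byte-wise stride-copy kernel, larger ones the
// uint4-vectorized variant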
jit::jit_execute( 40 | {&output_buff, &input_buff, &slice_size, &ngpus, &nnodes}, jit::mem_stride_copy_char_fd, 41 | input.device().index(), jit::mem_stride_copy_gridsize, jit::mem_stride_copy_blocksize, stream->at(0).stream()) : 42 | jit::jit_execute( 43 | {&output_buff, &input_buff, &slice_size_uint4, &ngpus, &nnodes}, jit::mem_stride_copy_uint4_fd, 44 | input.device().index(), jit::mem_stride_copy_gridsize, jit::mem_stride_copy_blocksize, stream->at(0).stream()); 45 | 46 | // phase 1. intra-node alltoall 47 | CHECK_EQ(0, ncclGroupStart()); 48 | for (int g = 0; g < ngpus; g++) { 49 | CHECK_EQ(0, ncclSend(((char *)output_buff) + g * nnodes * slice_size, nnodes * slice_size, ncclInt8, g + node_rank * ngpus, g_nccl_comm[0], stream->at(0).stream())); 50 | CHECK_EQ(0, ncclRecv(((char *)input_buff) + g * nnodes * slice_size, nnodes * slice_size, ncclInt8, g + node_rank * ngpus, g_nccl_comm[0], stream->at(0).stream())); 51 | } 52 | CHECK_EQ(0, ncclGroupEnd()); 53 | 54 | // phase 2. per-gpu (nnodes) stride copy 55 | slice_size < sizeof(uint4) ? jit::jit_execute({&output_buff, &input_buff, &slice_size, &nnodes, &ngpus}, jit::mem_stride_copy_char_fd, 56 | input.device().index(), jit::mem_stride_copy_gridsize, jit::mem_stride_copy_blocksize, stream->at(0).stream()) : 57 | jit::jit_execute({&output_buff, &input_buff, &slice_size_uint4, &nnodes, &ngpus}, jit::mem_stride_copy_uint4_fd, 58 | input.device().index(), jit::mem_stride_copy_gridsize, jit::mem_stride_copy_blocksize, stream->at(0).stream()); 59 | 60 | 61 | // phase 3. inter-node alltoall 62 | CHECK_EQ(0, ncclGroupStart()); 63 | for (int n = 0; n < nnodes; n++) { 64 | CHECK_EQ(0, ncclSend(((char *)output_buff) + n * ngpus * slice_size, ngpus * slice_size, ncclInt8, n * ngpus + local_rank, g_nccl_comm[0], stream->at(0).stream())); 65 | CHECK_EQ(0, ncclRecv(((char *)input_buff) + n * ngpus * slice_size, ngpus * slice_size, ncclInt8, n * ngpus + local_rank, g_nccl_comm[0], stream->at(0).stream())); 66 | } 67 | CHECK_EQ(0, ncclGroupEnd()); 68 | 69 | // // Switch to original stream 70 | // at::cuda::setCurrentCUDAStream(original_stream); 71 | 72 | cudaMemcpyAsync(output.data_ptr(), input.data_ptr(), length, cudaMemcpyDeviceToDevice, stream->at(0).stream()); 73 | 74 | 75 | // return input; 76 | } else { 77 | CHECK_EQ(0, ncclGroupStart()); 78 | for (int r = 0; r < nranks; r++) { 79 | CHECK_EQ(0, ncclSend(((char *)input_buff) + r * slice_size, slice_size, ncclInt8, r, g_nccl_comm[0], stream->at(0).stream())); 80 | CHECK_EQ(0, ncclRecv(((char *)output_buff) + r * slice_size, slice_size, ncclInt8, r, g_nccl_comm[0], stream->at(0).stream())); 81 | } 82 | CHECK_EQ(0, ncclGroupEnd()); 83 | 84 | // NCCL stream allocator will add blocking event to computation stream after computation kernels 85 | // c10::cuda::CUDACachingAllocator::recordStream(tmp_output.storage().data_ptr(), original_stream); 86 | 87 | // Switch to original stream 88 | // at::cuda::setCurrentCUDAStream(original_stream); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /PlanMoE/custom/dd_comm.h: -------------------------------------------------------------------------------- 1 | #ifndef DD_COMM 2 | #define DD_COMM 3 | 4 | #include "comm/abstract.h" 5 | #include "jit.h" 6 | 7 | #undef CHECK_EQ 8 | 9 | #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.") 10 | 11 | 12 | class DdComm : public AbstractComm { 13 | public: 14 | // Declare all public members here 15 | void all_to_all(const torch::Tensor &, const torch::Tensor &, 
size_t); 16 | DdComm(std::vector *, std::vector, const int &, const int &, const int &, const int &); 17 | }; 18 | 19 | #endif // DD_COMM -------------------------------------------------------------------------------- /PlanMoE/custom/jit.cpp: -------------------------------------------------------------------------------- 1 | #include "jit.h" 2 | 3 | namespace jit { 4 | 5 | int mem_stride_copy_char_fd = -1; 6 | int mem_stride_copy_uint4_fd = -1; 7 | int mem_stride_copy_gridsize = 1; 8 | int mem_stride_copy_blocksize = 1; 9 | std::string __sdk_home__; 10 | std::vector _gms; 11 | 12 | } -------------------------------------------------------------------------------- /PlanMoE/custom/jit.h: -------------------------------------------------------------------------------- 1 | #ifndef JIT 2 | #define JIT 3 | 4 | #include 5 | #if defined(USE_GPU) 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #if defined(__linux__) 16 | #include 17 | #endif 18 | 19 | #undef CHECK_EQ 20 | #undef CHECK_NE 21 | #undef CHECK_LE 22 | #undef CHECK_CPU 23 | #undef CHECK_CUDA 24 | #undef CHECK_CONTIGUOUS 25 | 26 | #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.") 27 | #define CHECK_NE(x, y) AT_ASSERTM((x) != (y), "CHECK_NE fails.") 28 | #define CHECK_LE(x, y) AT_ASSERTM((x) <= (y), "CHECK_LE fails.") 29 | #define CHECK_CPU(x) AT_ASSERTM(!x.is_cuda(), #x " must be a CPU tensor") 30 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 31 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 32 | 33 | namespace jit { 34 | 35 | extern int mem_stride_copy_char_fd; 36 | extern int mem_stride_copy_uint4_fd; 37 | extern int mem_stride_copy_gridsize; 38 | extern int mem_stride_copy_blocksize; 39 | 40 | inline static std::string file_read(const char *path) { 41 | FILE *fp = fopen(path, "rb"); 42 | CHECK_EQ(true, fp != nullptr); 43 | fseek(fp, 0, SEEK_END); 44 | size_t code_size = ftell(fp); 45 | fseek(fp, 0, SEEK_SET); 46 | std::string code; 47 | code.resize(code_size); 48 | CHECK_EQ(code_size, fread((void *)code.data(), 1, code_size, fp)); 49 | fclose(fp); 50 | return code; 51 | } 52 | 53 | inline static void file_write(const char *path, const std::string &code) { 54 | FILE *fp = fopen(path, "wb"); 55 | CHECK_EQ(true, fp != nullptr); 56 | CHECK_EQ(code.size(), fwrite((void *)code.data(), 1, code.size(), fp)); 57 | fclose(fp); 58 | } 59 | 60 | extern std::string __sdk_home__; 61 | 62 | static void update_sdk_home(const torch::Tensor &sdk_path) { 63 | CHECK_CPU(sdk_path); 64 | __sdk_home__ = static_cast(sdk_path.data_ptr()); 65 | } 66 | 67 | inline std::string sdk_path(const std::string &rel = "") { 68 | static std::string cuda_home, cc; 69 | if (cuda_home.size() == 0) { 70 | #if !defined(__HIP_PLATFORM_HCC__) 71 | cc = "bin/nvcc"; 72 | #else 73 | cc = "bin/hipcc"; 74 | #endif 75 | 76 | #if defined(__linux__) 77 | cuda_home = __sdk_home__ + std::string("/"); 78 | #else 79 | cuda_home = __sdk_home__ + std::string("\\"); 80 | #endif 81 | } 82 | if (rel.size() > 0) 83 | return cuda_home + rel; 84 | return cuda_home + cc; 85 | } 86 | 87 | static std::string nvcc_compile(const char *code, const std::string &arch) { 88 | #if defined(__linux__) 89 | char code_path[] = "/tmp/torch-tutel-XXXXXX.cu"; 90 | CHECK_NE(-1, mkstemps(code_path, 3)); 91 | 92 | file_write(code_path, code); 93 | std::string fatbin_path = code_path + std::string(".fatbin"); 94 | 95 | std::string entry = sdk_path(); 96 | if 
(access(entry.c_str(), F_OK) != 0) { 97 | LOG(FATAL) << "Failed to detect CUDA compiler file: " << entry << ", please set CUDA_HOME environment to configure CUDA SDK location correctly."; 98 | exit(1); 99 | } 100 | pid_t pid = fork(); 101 | if (pid == 0) { 102 | #if !defined(__HIP_PLATFORM_HCC__) 103 | CHECK_EQ(-1, execl(entry.c_str(), entry.c_str(), code_path, "-o", fatbin_path.c_str(), "--fatbin", "-O4", "-gencode", ("arch=compute_" + arch + ",code=sm_" + arch).c_str(), (char *)NULL)); 104 | #else 105 | CHECK_EQ(-1, execl(entry.c_str(), entry.c_str(), code_path, "-o", fatbin_path.c_str(), "--genco", "-O4", "-w", ("--amdgpu-target=" + arch).c_str(), (char *)NULL)); 106 | #endif 107 | exit(1); 108 | } else { 109 | wait(NULL); 110 | } 111 | auto image = file_read(fatbin_path.data()); 112 | unlink(fatbin_path.data()); 113 | unlink(code_path); 114 | return image; 115 | #else 116 | return ""; 117 | #endif 118 | } 119 | 120 | static std::string nvrtc_compile(const char *code, const std::string &arch) { 121 | #if !defined(__HIP_PLATFORM_HCC__) 122 | std::string arch_option = "--gpu-architecture=compute_" + arch, include_path = "--include-path=" + sdk_path("include"); 123 | std::vector param_cstrings = {"--restrict", include_path.c_str(), arch_option.c_str(), "--use_fast_math", "--extra-device-vectorization"}; 124 | #else 125 | std::string arch_option = "--gpu-architecture=" + arch; 126 | std::vector param_cstrings = {arch_option.c_str(), "-O4"}; 127 | #endif 128 | nvrtcProgram prog; 129 | 130 | CHECK_EQ(0, nvrtcCreateProgram(&prog, code, nullptr, 0, nullptr, nullptr)); 131 | nvrtcResult res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); 132 | 133 | size_t log_size; 134 | CHECK_EQ(0, nvrtcGetProgramLogSize(prog, &log_size)); 135 | std::string log; 136 | log.resize(log_size); 137 | CHECK_EQ(0, nvrtcGetProgramLog(prog, &log[0])); 138 | if (0 != res) { 139 | static bool once_flag = false; 140 | if (!once_flag) { 141 | once_flag = true; 142 | LOG(WARNING) << log << " Failed to use NVRTC for JIT compilation in this Pytorch version, try another approach using CUDA compiler.. 
(To always disable NVRTC, please: export USE_NVRTC=0)"; 143 | } 144 | CHECK_EQ(0, nvrtcDestroyProgram(&prog)); 145 | return ""; 146 | } 147 | 148 | size_t ptx_size; 149 | CHECK_EQ(0, nvrtcGetPTXSize(prog, &ptx_size)); 150 | 151 | std::string ptx; 152 | ptx.resize(ptx_size); 153 | CHECK_EQ(0, nvrtcGetPTX(prog, &ptx[0])); 154 | CHECK_EQ(0, nvrtcDestroyProgram(&prog)); 155 | return ptx; 156 | } 157 | 158 | struct ModuleConfig { 159 | // Handling JIT compilation in Multi-gpu cases 160 | std::vector hFunc; 161 | std::string code, fname; 162 | dim3 blocks, threads; 163 | }; 164 | 165 | extern std::vector _gms; 166 | 167 | inline static CUfunction jit_activate(int fd, int dev) { 168 | //std::cout << (&_gms) << " " << fd << std::endl; 169 | auto &gm = _gms[fd]; 170 | if (gm.hFunc.size() <= dev) 171 | gm.hFunc.resize(dev + 1); 172 | 173 | if (gm.hFunc[dev] == nullptr) { 174 | #if !defined(__HIP_PLATFORM_HCC__) 175 | int major, minor; 176 | CHECK_EQ(0, cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev)); 177 | CHECK_EQ(0, cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev)); 178 | std::string arch = std::to_string(major) + std::to_string(minor); 179 | #else 180 | hipDeviceProp_t prop; 181 | CHECK_EQ(0, hipGetDeviceProperties(&prop, dev)); 182 | std::string arch = prop.gcnArchName; 183 | #endif 184 | const char *source = gm.code.data(), *pos, *tail; 185 | 186 | int use_nvrtc = getenv("USE_NVRTC") ? std::atoi(getenv("USE_NVRTC")) : 0; 187 | std::string image; 188 | if (use_nvrtc || (image = nvcc_compile(source, arch)) == "") { 189 | image = nvrtc_compile(source, arch); 190 | } 191 | 192 | long launch_bound; 193 | { 194 | char tag[] = " __launch_bounds__("; 195 | const char *pos = strstr(source, tag); 196 | launch_bound = pos ? 
std::atol(pos + sizeof(tag) - 1) : 1024L; 197 | } 198 | 199 | static CUjit_option options[] = {CU_JIT_OPTIMIZATION_LEVEL, CU_JIT_THREADS_PER_BLOCK}; 200 | static void *values[] = {(void *)4L, (void *)launch_bound}; 201 | 202 | CUmodule hMod = nullptr; 203 | CHECK_EQ(0, cuModuleLoadDataEx(&hMod, image.c_str(), sizeof(options) / sizeof(*options), options, values)); 204 | CHECK_NE(nullptr, hMod); 205 | 206 | CHECK_NE(nullptr, (pos = strstr(source, " void "))); 207 | pos += 6; 208 | CHECK_NE(nullptr, (tail = strchr(pos, '('))); 209 | 210 | std::string fname = std::string(pos, tail - pos); 211 | gm.fname = fname; 212 | CHECK_EQ(0, cuModuleGetFunction(&gm.hFunc[dev], hMod, fname.c_str())); 213 | CHECK_NE(nullptr, gm.hFunc[dev]); 214 | } 215 | 216 | return gm.hFunc[dev]; 217 | } 218 | 219 | static void jit_execute(const std::vector &ppargs, int fd, int dev, const dim3 &blocks, const dim3 &threads, cudaStream_t stream = 0) { 220 | CUfunction hfunc = jit_activate(fd, dev); 221 | CHECK_EQ(0, cuLaunchKernel(hfunc, blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, 0, stream, (void **)ppargs.data(), nullptr)); 222 | } 223 | 224 | static void jit_execute_with_values(const std::vector &pargs, int fd, int dev, const dim3 &blocks, const dim3 &threads, cudaStream_t stream = 0) { 225 | std::vector ppargs(pargs.size()); 226 | for (int i = 0; i < ppargs.size(); ++i) 227 | ppargs[i] = &pargs[i]; 228 | jit_execute(ppargs, fd, dev, blocks, threads, stream); 229 | } 230 | 231 | static int inject_source(const std::string &headless_code) { 232 | int fd = _gms.size(); 233 | _gms.resize(fd + 1); 234 | 235 | auto &gm = _gms[fd]; 236 | #if !defined(__HIP_PLATFORM_HCC__) 237 | gm.code = "#include \n#include \n" + headless_code; 238 | #else 239 | gm.code = "#include \n" + headless_code; 240 | #endif 241 | 242 | const char *source = headless_code.c_str(); 243 | { 244 | char tag[] = "// [thread_extent] blockIdx.x = "; 245 | const char *pos = strstr(source, tag); 246 | gm.blocks.x = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; 247 | } 248 | { 249 | char tag[] = "// [thread_extent] blockIdx.y = "; 250 | const char *pos = strstr(source, tag); 251 | gm.blocks.y = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; 252 | } 253 | { 254 | char tag[] = "// [thread_extent] blockIdx.z = "; 255 | const char *pos = strstr(source, tag); 256 | gm.blocks.z = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; 257 | } 258 | { 259 | char tag[] = "// [thread_extent] threadIdx.x = "; 260 | const char *pos = strstr(source, tag); 261 | gm.threads.x = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; 262 | } 263 | { 264 | char tag[] = "// [thread_extent] threadIdx.y = "; 265 | const char *pos = strstr(source, tag); 266 | gm.threads.y = pos ? std::atoi(pos + sizeof(tag) - 1) : 1; 267 | } 268 | { 269 | char tag[] = "// [thread_extent] threadIdx.z = "; 270 | const char *pos = strstr(source, tag); 271 | gm.threads.z = pos ? 
std::atoi(pos + sizeof(tag) - 1) : 1; 272 | } 273 | 274 | return fd; 275 | } 276 | 277 | static void invoke(const std::vector &ts, const std::vector &args, const std::vector &blocks, int fd) { 278 | std::vector pargs(ts.size() + args.size()), ppargs(ts.size() + args.size()); 279 | for (int i = 0; i < (int)ts.size(); ++i) { 280 | CHECK_CUDA(ts[i]); 281 | pargs[i] = ts[i].data_ptr(), ppargs[i] = &pargs[i]; 282 | } 283 | for (int i = (int)ts.size(); i < (int)pargs.size(); ++i) { 284 | pargs[i] = (void *)args[i - ts.size()], ppargs[i] = &pargs[i]; 285 | } 286 | 287 | int dev = ts[0].device().index(); 288 | CHECK_EQ(0, cudaSetDevice(dev)); 289 | if (blocks.size() == 0) 290 | jit_execute(ppargs, fd, dev, _gms[fd].blocks, _gms[fd].threads, at::cuda::getDefaultCUDAStream().stream()); 291 | else if (blocks.size() == 1) 292 | jit_execute(ppargs, fd, dev, dim3(blocks[0]), _gms[fd].threads, at::cuda::getDefaultCUDAStream().stream()); 293 | else if (blocks.size() == 2) 294 | jit_execute(ppargs, fd, dev, dim3(blocks[0], blocks[1]), _gms[fd].threads, at::cuda::getDefaultCUDAStream().stream()); 295 | else 296 | jit_execute(ppargs, fd, dev, dim3(blocks[0], blocks[1], blocks[2]), _gms[fd].threads, at::cuda::getDefaultCUDAStream().stream()); 297 | } 298 | 299 | inline void jit_init(int g_local_rank){ 300 | if (mem_stride_copy_uint4_fd == -1) { 301 | std::string mem_stride_copy_cu = R"( 302 | extern "C" __global__ void memStrideCopyKernel( 303 | $T *__restrict__ out, const $T *__restrict__ in, 304 | const size_t size, const int height, const int width) { 305 | const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; 306 | for (size_t i = tid; i < size * height * width; i += gridDim.x * blockDim.x) { 307 | const size_t index = i / size, offset = i % size; 308 | const size_t j = (width * (index % height) + (index / height)) * size + offset; 309 | out[j] = in[i]; 310 | } 311 | } 312 | )"; 313 | mem_stride_copy_char_fd = inject_source(std::regex_replace(mem_stride_copy_cu, std::regex("\\$T"), "char")); 314 | mem_stride_copy_uint4_fd = inject_source(std::regex_replace(mem_stride_copy_cu, std::regex("\\$T"), "uint4")); 315 | CHECK_NE(-1, mem_stride_copy_char_fd); 316 | CHECK_NE(-1, mem_stride_copy_uint4_fd); 317 | CUfunction hfunc = jit_activate(mem_stride_copy_uint4_fd, g_local_rank); 318 | #if !defined(__HIP_PLATFORM_HCC__) 319 | CHECK_EQ(0, cuOccupancyMaxPotentialBlockSize(&mem_stride_copy_gridsize, &mem_stride_copy_blocksize, hfunc, 0, 0, 0)); 320 | #else 321 | CHECK_EQ(0, hipModuleOccupancyMaxPotentialBlockSize(&mem_stride_copy_gridsize, &mem_stride_copy_blocksize, hfunc, 0, 0)); 322 | #endif 323 | } 324 | } 325 | 326 | 327 | } // namespace jit 328 | #endif 329 | 330 | #endif -------------------------------------------------------------------------------- /PlanMoE/examples/fairseq_moe/README.md: -------------------------------------------------------------------------------- 1 | # Training WikiText-103 on fairseq with ScheMoE: 2 | 3 | ## Install fairseq 4 | 5 | ```shell 6 | git clone https://github.com/facebookresearch/fairseq --branch main 7 | cd fairseq/ && git checkout b5e7b250913120409b872a940fbafec4d43c7b13 8 | # This patch is an example to train Fairseq MoE transformers. 9 | # Note that the current patch only works for `legacy_ddp` backend, and `--checkpoint-activations` must be disabled. 10 | git apply ../fairseq_patch.diff 11 | python3 -m pip install omegaconf==2.0.5 hydra-core==1.0.7 12 | python3 -m pip install --no-deps --editable . 
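# Optional sanity check (a hypothetical command; assumes the ScheMoE/PlanMoE extension from this repository is already installed):
# python3 -c "import fairseq; from schemoe.moe import moe_layer; print('fairseq + ScheMoE imports OK')"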
13 | ``` 14 | 15 | ## Prepare the dataset 16 | 17 | Download [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/): 18 | 19 | ```shell 20 | curl -LO https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip && unzip wikitext-103-v1.zip 21 | ``` 22 | 23 | Preprocess the data: 24 | 25 | ```shell 26 | fairseq-preprocess \ 27 | --only-source \ 28 | --trainpref wikitext-103/wiki.train.tokens \ 29 | --validpref wikitext-103/wiki.valid.tokens \ 30 | --testpref wikitext-103/wiki.test.tokens \ 31 | --destdir ./wikitext-103 \ 32 | --workers 20 33 | 34 | ``` 35 | 36 | ## Train a Model with ScheMoE moe (MOE is moe-freq) 37 | 38 | ```shell 39 | 40 | # Example of Training with 8GPUs (FP32) 41 | MOE=1 L_AUX_WT=0.01 ../run_fairseq.sh ./wikitext-103 42 | 43 | # Example of Training with 8GPUs (FP16) 44 | FP16=1 NO_OVERFLOW=0 MOE=1 L_AUX_WT=0.01 ../run_fairseq.sh ./wikitext-103 45 | 46 | ``` 47 | -------------------------------------------------------------------------------- /PlanMoE/examples/fairseq_moe/fairseq_patch.diff: -------------------------------------------------------------------------------- 1 | diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py 2 | index 61aaa09..458bd40 100644 3 | --- a/fairseq/models/transformer/transformer_decoder.py 4 | +++ b/fairseq/models/transformer/transformer_decoder.py 5 | @@ -115,9 +115,14 @@ class TransformerDecoderBase(FairseqIncrementalDecoder): 6 | self.layers = LayerDropModuleList(p=self.decoder_layerdrop) 7 | else: 8 | self.layers = nn.ModuleList([]) 9 | + 10 | + def config_with_index(cfg, index): 11 | + cfg.transformer_index = index 12 | + return cfg 13 | + 14 | self.layers.extend( 15 | [ 16 | - self.build_decoder_layer(cfg, no_encoder_attn) 17 | + self.build_decoder_layer(config_with_index(cfg, _), no_encoder_attn) 18 | for _ in range(cfg.decoder.layers) 19 | ] 20 | ) 21 | diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py 22 | index 2e687b9..8c1166b 100644 23 | --- a/fairseq/modules/transformer_layer.py 24 | +++ b/fairseq/modules/transformer_layer.py 25 | @@ -324,18 +324,33 @@ class TransformerDecoderLayerBase(nn.Module): 26 | else None 27 | ) 28 | 29 | - self.fc1 = self.build_fc1( 30 | - self.embed_dim, 31 | - cfg.decoder.ffn_embed_dim, 32 | - self.quant_noise, 33 | - self.quant_noise_block_size, 34 | - ) 35 | - self.fc2 = self.build_fc2( 36 | - cfg.decoder.ffn_embed_dim, 37 | - self.embed_dim, 38 | - self.quant_noise, 39 | - self.quant_noise_block_size, 40 | - ) 41 | + self.moe_freq = int(torch.os.environ.get('MOE', 0)) 42 | + self.use_moe = (self.moe_freq > 0) and (cfg.transformer_index + 1) % self.moe_freq == 0 43 | + 44 | + if self.use_moe: 45 | + assert self.quant_noise == 0, "Unhandled quant_noise > 0.0 for MoE layer." 
46 | + from schemoe.moe import moe_layer 47 | + self.moe_ffn = moe_layer( 48 | + gate_type={'type' : 'top', 'k' : 2, 'capacity_factor': 0.0, 'fp32_gate': True, 'gate_noise': 1.0}, 49 | + model_dim=self.embed_dim, 50 | + experts={'count_per_node': 1,'type': 'ffn', 'hidden_size_per_expert': cfg.decoder.ffn_embed_dim, 51 | + 'activation_fn' : lambda x: 52 | + self.activation_dropout_module(x) if self.ffn_layernorm is None else self.ffn_layernorm(self.activation_dropout_module(x))}, 53 | + scan_expert_func = lambda name, param: setattr(param, 'expert', True), # The mask is only compatible with Fairseq based on legacy_ddp 54 | + ) 55 | + else: 56 | + self.fc1 = self.build_fc1( 57 | + self.embed_dim, 58 | + cfg.decoder.ffn_embed_dim, 59 | + self.quant_noise, 60 | + self.quant_noise_block_size, 61 | + ) 62 | + self.fc2 = self.build_fc2( 63 | + cfg.decoder.ffn_embed_dim, 64 | + self.embed_dim, 65 | + self.quant_noise, 66 | + self.quant_noise_block_size, 67 | + ) 68 | 69 | self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) 70 | self.need_attn = True 71 | @@ -504,11 +519,18 @@ class TransformerDecoderLayerBase(nn.Module): 72 | if self.normalize_before: 73 | x = self.final_layer_norm(x) 74 | 75 | - x = self.activation_fn(self.fc1(x)) 76 | - x = self.activation_dropout_module(x) 77 | - if self.ffn_layernorm is not None: 78 | - x = self.ffn_layernorm(x) 79 | - x = self.fc2(x) 80 | + if self.use_moe: 81 | + x = self.moe_ffn(x) 82 | + from schemoe import system 83 | + if x.l_aux.requires_grad: 84 | + system.cache().set(id(self.moe_ffn), (x.numel() // x.size(-1), x.l_aux)) 85 | + else: 86 | + x = self.activation_fn(self.fc1(x)) 87 | + x = self.activation_dropout_module(x) 88 | + if self.ffn_layernorm is not None: 89 | + x = self.ffn_layernorm(x) 90 | + x = self.fc2(x) 91 | + 92 | x = self.dropout_module(x) 93 | if self.w_resid is not None: 94 | residual = torch.mul(self.w_resid, residual) 95 | diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py 96 | index 2c4ee32..15c264a 100644 97 | --- a/fairseq/optim/fp16_optimizer.py 98 | +++ b/fairseq/optim/fp16_optimizer.py 99 | @@ -207,6 +207,11 @@ class _FP16OptimizerMixin(object): 100 | 101 | def step(self, closure=None, groups=None): 102 | """Performs a single optimization step.""" 103 | + if int(torch.os.environ.get('NO_OVERFLOW', 0)) > 0: 104 | + for x, y in zip(self.fp16_params, self.fp32_params): 105 | + x.grad[torch.isinf(x.grad)] = 0 106 | + y.grad[torch.isinf(y.grad)] = 0 107 | + 108 | self._sync_fp16_grads_to_fp32() 109 | 110 | if getattr(self, "supports_step_with_scale", False): 111 | diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py 112 | index 273dbdd..4c8f06e 100644 113 | --- a/fairseq/tasks/fairseq_task.py 114 | +++ b/fairseq/tasks/fairseq_task.py 115 | @@ -513,6 +513,16 @@ class FairseqTask(object): 116 | if ignore_grad: 117 | loss *= 0 118 | with torch.autograd.profiler.record_function("backward"): 119 | + from schemoe import system 120 | + l_aux_wt = float(torch.os.environ.get('L_AUX_WT', 0.0)) 121 | + if l_aux_wt: 122 | + l_aux = None 123 | + for samples, x in system.cache().get(): 124 | + x *= l_aux_wt * samples 125 | + l_aux = x if l_aux is None else l_aux + x 126 | + system.cache().reset() 127 | + if l_aux is not None: 128 | + loss += l_aux 129 | optimizer.backward(loss) 130 | return loss, sample_size, logging_output 131 | 132 | -------------------------------------------------------------------------------- /PlanMoE/examples/fairseq_moe/run_fairseq.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | cd $(dirname $0)/fairseq 4 | 5 | if [[ "$FP16" == "1" ]]; then 6 | FLAGS=${FLAGS:---fp16 --fp16-init-scale 4 --fp16-no-flatten-grads} 7 | fi 8 | 9 | python3 -m torch.distributed.launch --nproc_per_node=8 train.py ${@:-./wikitext-103} \ 10 | --ddp-backend legacy_ddp \ 11 | --task language_modeling --tokens-per-sample 256 --batch-size 8 \ 12 | --arch transformer_lm_gpt2_tiny \ 13 | --optimizer adam --adam-betas "(0.9,0.98)" \ 14 | --lr 0.0001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 15 | --max-update 500000 --log-format json --log-interval 100 \ 16 | ${FLAGS} \ 17 | --save-dir ./fairseq_checkpoints 18 | -------------------------------------------------------------------------------- /PlanMoE/examples/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | def main(): 6 | host_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 7 | host_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 8 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 9 | 10 | master_addr = os.environ['MASTER_ADDR'] if host_size > 1 else '127.0.0.1' 11 | master_port = int(os.environ.get('MASTER_PORT', 23232)) 12 | if 'OMP_NUM_THREADS' not in os.environ: 13 | os.environ['OMP_NUM_THREADS'] = str(local_size) 14 | cmd_args = [sys.executable, '-m'] + ['torch.distributed.launch',] + [ 15 | '--nproc_per_node=%d' % local_size, 16 | '--nnodes=%d' % host_size, 17 | '--node_rank=%d' % host_rank, 18 | '--master_addr=%s' % master_addr, 19 | '--master_port=%s' % master_port] + sys.argv[1:] 20 | 21 | os.execl(cmd_args[0], *cmd_args) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /PlanMoE/examples/megatron/README.md: -------------------------------------------------------------------------------- 1 | # Using ScheMoE with Megatron 2 | 3 | ``` 4 | git clone https://github.com/NVIDIA/Megatron-LM.git 5 | cd Megatron-LM 6 | git checkout core_r0.9.0 7 | git apply ../schemoe_megatron.diff 8 | bash train_schemoe.sh 9 | ``` 10 | -------------------------------------------------------------------------------- /PlanMoE/examples/megatron/schemoe_megatron.diff: -------------------------------------------------------------------------------- 1 | diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py 2 | index eeeb1e3d..0c754477 100644 3 | --- a/megatron/core/transformer/transformer_layer.py 4 | +++ b/megatron/core/transformer/transformer_layer.py 5 | @@ -324,7 +324,12 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): 6 | pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) 7 | 8 | # MLP. 9 | - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) 10 | + from megatron.training import get_args 11 | + args = get_args() 12 | + if args.schemoe: 13 | + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output), None 14 | + else: 15 | + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) 16 | 17 | # TODO: could we move `bias_dropout_add_exec_handler` itself 18 | # inside the module provided in the `bias_dropout_add_spec` module? 
19 | diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py 20 | index 0217d71e..7abb9d0f 100644 21 | --- a/megatron/training/arguments.py 22 | +++ b/megatron/training/arguments.py 23 | @@ -1277,6 +1277,18 @@ def _add_training_args(parser): 24 | help='Disables the Reduce-Scatter overlap with fprop GEMM.', 25 | dest='tp_comm_split_rs') 26 | 27 | + # add arguments 28 | + group.add_argument('--schemoe', action='store_true', 29 | + help='Use ScheMoE') 30 | + group.add_argument('--schemoe-overlap-degree', type=int, default=1, 31 | + help='ScheMoE a2a overlap degree') 32 | + group.add_argument('--schemoe-compress-name', type=str, default='no', 33 | + choices=['no', 'zfp', 'int8'], 34 | + help='ScheMoE compression name') 35 | + group.add_argument('--schemoe-comm-name', type=str, default='naive', 36 | + choices=['naive', 'pipe', 'dd', 'hetu'], 37 | + help='ScheMoE communication name') 38 | + 39 | return parser 40 | 41 | 42 | diff --git a/megatron/training/schemoe_moe_decorator.py b/megatron/training/schemoe_moe_decorator.py 43 | new file mode 100644 44 | index 00000000..d7bfb5b9 45 | --- /dev/null 46 | +++ b/megatron/training/schemoe_moe_decorator.py 47 | @@ -0,0 +1,55 @@ 48 | +from megatron.training import get_args 49 | +import torch.nn.functional as F 50 | +from schemoe.moe import moe_layer 51 | +import torch.distributed as dist 52 | +from .utils import print_rank_0 53 | + 54 | +def schmoe_moe(args, idx): 55 | + hidden_size = args.hidden_size 56 | + ffn_hidden_size = args.ffn_hidden_size 57 | + world_size = dist.get_world_size() 58 | + num_experts = args.num_experts 59 | + if args.moe_expert_capacity_factor is not None and args.moe_expert_capacity_factor > 0: 60 | + capacity_factor = args.moe_expert_capacity_factor 61 | + else: 62 | + capacity_factor = 0.0 63 | + print_rank_0(f"ScheMoE capacity factor: {capacity_factor}") 64 | + expert_per_node = num_experts // world_size 65 | + top_k = args.moe_router_topk 66 | + activation = F.gelu 67 | + compress_name = args.schemoe_compress_name 68 | + comm_name = args.schemoe_comm_name 69 | + overlap_degree = args.schemoe_overlap_degree 70 | + moe_ffn = moe_layer( 71 | + gate_type={ 72 | + 'type' : 'top', 'k' : top_k, 'capacity_factor': capacity_factor, 73 | + 'fp32_gate': True, 'gate_noise': 1.0 74 | + }, 75 | + model_dim=hidden_size, 76 | + experts={ 77 | + 'count_per_node': expert_per_node,'type': 'ffn', 78 | + 'hidden_size_per_expert': ffn_hidden_size, 79 | + 'activation_fn' : lambda x: activation(x) 80 | + }, 81 | + a2a_ffn_overlap_degree = overlap_degree, 82 | + compress_name = compress_name, 83 | + comm_name = comm_name, 84 | + scan_expert_func = lambda name, param: setattr(param, 'allreduce', False), 85 | + ) 86 | + return moe_ffn 87 | + 88 | + 89 | +def schemoe_model_provider(model_provider): 90 | + args = get_args() 91 | + def schemoe_model(pre_process=True, post_process=True): 92 | + model = model_provider() 93 | + 94 | + # for idx, l in enumerate(model.language_model.encoder.layers): 95 | + for idx, l in enumerate(model.decoder.layers): 96 | + l.mlp = schmoe_moe(args, idx) 97 | + 98 | + print_rank_0(f'ScheMoE model:\n{model}') 99 | + return model 100 | + 101 | + return schemoe_model 102 | + 103 | diff --git a/megatron/training/training.py b/megatron/training/training.py 104 | index 5556bb26..9df4c9a5 100644 105 | --- a/megatron/training/training.py 106 | +++ b/megatron/training/training.py 107 | @@ -269,6 +269,12 @@ def pretrain( 108 | args = get_args() 109 | timers = get_timers() 110 | 111 | + # Use ScheMoE 112 | + if 
args.schemoe: 113 | + from .schemoe_moe_decorator import schemoe_model_provider 114 | + print_rank_0(f"Use ScheMoE Model") 115 | + model_provider = schemoe_model_provider(model_provider) 116 | + 117 | # Track E2E metrics on pretrain start 118 | one_logger_utils.on_pretrain_start() 119 | 120 | -------------------------------------------------------------------------------- /PlanMoE/examples/megatron/train_schemoe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Runs the "175B" parameter model 4 | 5 | export CUDA_DEVICE_MAX_CONNECTIONS=1 6 | 7 | GPUS_PER_NODE=4 8 | # Change for multinode config 9 | MASTER_ADDR=localhost 10 | MASTER_PORT=6001 11 | NUM_NODES=1 12 | NODE_RANK=0 13 | WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) 14 | 15 | VOCAB_FILE=/workspace/datasets/megatron/gpt2-config/vocab.json 16 | MERGE_FILE=/workspace/datasets/megatron/gpt2-config/merges.txt 17 | DATA_PATH=/workspace/datasets/megatron/wikipedia/wikipedia_text_document 18 | 19 | DISTRIBUTED_ARGS=( 20 | --nproc_per_node $GPUS_PER_NODE 21 | --nnodes $NUM_NODES 22 | --master_addr $MASTER_ADDR 23 | --master_port $MASTER_PORT 24 | ) 25 | 26 | GPT_MODEL_ARGS=( 27 | --num-layers 24 28 | --hidden-size 1024 29 | --ffn-hidden-size 4096 30 | --num-attention-heads 16 31 | --seq-length 1024 32 | --max-position-embeddings 1024 33 | ) 34 | 35 | TRAINING_ARGS=( 36 | --micro-batch-size 2 37 | --global-batch-size 32 38 | --disable-bias-linear 39 | --train-iters 60000 40 | --weight-decay 1e-2 41 | --adam-beta1 0.9 42 | --adam-beta2 0.95 43 | --init-method-std 0.006 44 | --clip-grad 1.0 45 | --lr 0.00015 46 | --lr-decay-style cosine 47 | --min-lr 1.0e-5 48 | --lr-warmup-fraction .01 49 | --lr-decay-iters 320000 50 | --transformer-impl local 51 | ) 52 | 53 | MODEL_PARALLEL_ARGS=( 54 | --tensor-model-parallel-size 1 55 | --pipeline-model-parallel-size 1 56 | ) 57 | 58 | MOE_ARGS=( 59 | --num-experts 8 60 | --moe-router-topk 2 61 | --moe-token-dispatcher-type alltoall 62 | --expert-model-parallel-size $WORLD_SIZE 63 | --moe-expert-capacity-factor 1.5 64 | --schemoe 65 | --schemoe-overlap-degree 1 66 | --schemoe-compress-name 'no' 67 | --schemoe-comm-name 'naive' 68 | ) 69 | 70 | 71 | DATA_ARGS=( 72 | --data-path $DATA_PATH 73 | --vocab-file $VOCAB_FILE 74 | --merge-file $MERGE_FILE 75 | --split 949,50,1 76 | ) 77 | 78 | EVAL_AND_LOGGING_ARGS=( 79 | --log-interval 100 80 | --save-interval 10000 81 | --eval-interval 1000 82 | --eval-iters 10 83 | ) 84 | 85 | torchrun ${DISTRIBUTED_ARGS[@]} ./pretrain_gpt.py \ 86 | ${GPT_MODEL_ARGS[@]} \ 87 | ${TRAINING_ARGS[@]} \ 88 | ${MODEL_PARALLEL_ARGS[@]} \ 89 | ${DATA_ARGS[@]} \ 90 | ${MOE_ARGS[@]} \ 91 | ${EVAL_AND_LOGGING_ARGS[@]} \ 92 | 2>&1 | tee ./logs/schemoe.log 93 | 94 | -------------------------------------------------------------------------------- /PlanMoE/examples/pre_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT license. 
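# Micro-benchmark for the ScheMoE custom kernels: the script initializes torch.distributed
# (NCCL backend), builds a random capacity-shaped dispatch tensor, and for each step splits it
# along dim 1 into `a2a_ffn_overlap_degree` chunks, pushing every chunk through
# compress_operation -> comm_operation -> decompress_operation twice (the expert FFN matmuls in
# between are left commented out). Each step is timed with CUDA events, and rank 0 appends the
# per-step latencies to the file given by `--log`.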
5 | 6 | import os 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from torch import nn, Tensor 11 | import argparse 12 | import schemoe_custom_kernel 13 | import torch.distributed as dist 14 | import math 15 | from contextlib import nullcontext 16 | from typing import Any 17 | import time 18 | 19 | 20 | def decorate_trace_handler(args, rank): 21 | def trace_handler(prof): 22 | print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 23 | if rank == 0: 24 | prof.export_chrome_trace( 25 | "./batch_size" 26 | + str(args.batch_size) 27 | + "#num_tokens" 28 | + str(args.num_tokens) 29 | + "#model_dim" 30 | + str(args.model_dim) 31 | + "#hidden_size" 32 | + str(args.hidden_size) 33 | + "#num_local_experts" 34 | + str(args.num_local_experts) 35 | + "#capacity_factor" 36 | + str(args.capacity_factor) 37 | + "#a2a_ffn_overlap_degree" 38 | + str(args.a2a_ffn_overlap_degree) 39 | + "#step_num" 40 | + str(prof.step_num) 41 | + ".json" 42 | ) 43 | 44 | return trace_handler 45 | 46 | 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument("--local_rank", type=int, default=-1) 50 | parser.add_argument("--batch_size", type=int, default=16) 51 | parser.add_argument("--num_tokens", type=int, default=512) 52 | parser.add_argument("--model_dim", type=int, default=2048) 53 | parser.add_argument("--hidden_size", type=int, default=2048) 54 | parser.add_argument("--num_local_experts", type=int, default=2) 55 | parser.add_argument("--dtype", type=str, default="float32") 56 | parser.add_argument("--fp32_gate", default=False, action="store_true") 57 | parser.add_argument("--top", type=int, default=2) 58 | parser.add_argument("--a2a_ffn_overlap_degree", type=int, default=1) 59 | parser.add_argument("--num_steps", type=int, default=25) 60 | parser.add_argument("--capacity_factor", type=float, default=1.0) 61 | parser.add_argument("--parallel_type", type=str, default="auto") 62 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 63 | parser.add_argument("--use_2dh", default=False, action="store_true") 64 | parser.add_argument("--record_shapes", default=False, action="store_true") 65 | parser.add_argument("--with_stack", default=False, action="store_true") 66 | parser.add_argument("--log", type=str, default="test.log") 67 | parser.add_argument("--encode", type=str, default="no") 68 | 69 | 70 | args = parser.parse_args() 71 | 72 | dist.init_process_group("nccl") 73 | 74 | dist_rank, dist_world_size = dist.get_rank(), dist.get_world_size() 75 | 76 | args.local_rank = os.environ.get("LOCAL_RANK", 0) 77 | 78 | 79 | def dist_print(*args): 80 | if dist_rank == 0: 81 | print(*args) 82 | 83 | 84 | device = torch.device("cuda:%s" % args.local_rank) 85 | torch.cuda.set_device(device) 86 | 87 | torch.set_printoptions(sci_mode=False) 88 | 89 | if args.dtype == "float32": 90 | torch.set_default_dtype(torch.float32) 91 | elif args.dtype == "float64": 92 | torch.set_default_dtype(torch.float64) 93 | elif args.dtype == "float16": 94 | torch.set_default_dtype(torch.float16) 95 | elif args.dtype == "bfloat16": 96 | torch.set_default_dtype(torch.bfloat16) 97 | else: 98 | raise Exception("Unrecognized data type specified: %s" % args.dtype) 99 | 100 | from schemoe.impls import communicate as C 101 | 102 | torch.manual_seed(0) 103 | 104 | 105 | def single_case( 106 | batch_size, 107 | num_tokens, 108 | model_dim, 109 | hidden_size, 110 | num_local_experts, 111 | top_value, 112 | a2a_ffn_overlap_degree, 113 | 
capacity_factor, 114 | ): 115 | fc1_weight = torch.randn( 116 | num_local_experts, 117 | model_dim, 118 | hidden_size, 119 | dtype=torch.get_default_dtype(), 120 | device=device, 121 | ) 122 | fc2_weight = torch.randn( 123 | num_local_experts, 124 | hidden_size, 125 | model_dim, 126 | dtype=torch.get_default_dtype(), 127 | device=device, 128 | ) 129 | 130 | def zc(x, y): 131 | return (x + y - 1) // y * y 132 | 133 | expert_num = num_local_experts * dist_world_size 134 | x = torch.tensor( 135 | torch.randn( 136 | [ 137 | expert_num, 138 | zc( 139 | int(top_value * math.ceil(batch_size * num_tokens / expert_num) * capacity_factor), 140 | a2a_ffn_overlap_degree if args.encode != "zfp" else a2a_ffn_overlap_degree * 4, 141 | ), 142 | model_dim, 143 | ], 144 | dtype=torch.float32, 145 | device="cpu", 146 | ) 147 | .detach() 148 | .numpy(), 149 | dtype=torch.get_default_dtype(), 150 | requires_grad=False, 151 | device=device, 152 | ) 153 | lst = [] 154 | 155 | tuples = ( 156 | dist_world_size, 157 | args.dtype, 158 | model_dim, 159 | hidden_size, 160 | batch_size * num_tokens, 161 | num_local_experts, 162 | top_value, 163 | a2a_ffn_overlap_degree, 164 | capacity_factor, 165 | device, 166 | ) 167 | dist_print( 168 | "[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, capacity_factor = `%s`, device = `%s`" 169 | % tuples 170 | ) 171 | 172 | if dist_rank == 0: 173 | with open(args.log, "a+") as f: 174 | f.write(str(batch_size) + "," + str(num_tokens) + "," + str(model_dim) + "," + str(hidden_size) + "," + "Naive" + "," + str(capacity_factor) + "," + str(a2a_ffn_overlap_degree) + ",") 175 | C.AllToAllStatus.init(dist.group.WORLD, a2a_ffn_overlap_degree, 1) 176 | with torch.no_grad(): 177 | for _ in range(args.num_steps): 178 | cuda_start = torch.cuda.Event(enable_timing=True) 179 | cuda_end = torch.cuda.Event(enable_timing=True) 180 | schemoe_custom_kernel.clear_ptr_lst() 181 | # cuda_start.record() 182 | input = x.clone() 183 | # y = simple_all_to_all(y) 184 | # y = AllToAll2DAsync.apply(y) 185 | # cuda_end.record() 186 | # torch.cuda.synchronize() 187 | # if dist_rank == 0: 188 | # lst.append(cuda_start.elapsed_time(cuda_end)) 189 | cuda_start.record() 190 | split_size = input.shape[1] // a2a_ffn_overlap_degree 191 | input_split = list(input.split(split_size, dim=1)) 192 | for i in range(a2a_ffn_overlap_degree): 193 | input_split[i] = input_split[i].contiguous() 194 | 195 | # input_size = input_split[i].size() 196 | # input_split[i] = input_split[i].view((-1, input_size[-1])) 197 | # cuda_start.record() 198 | input_split[i] = schemoe_custom_kernel.compress_operation(input_split[i], args.encode, "naive") 199 | # print(input_split[i].storage()) 200 | input_split[i] = schemoe_custom_kernel.comm_operation(input_split[i]) 201 | for i in range(a2a_ffn_overlap_degree): 202 | input_split[i] = schemoe_custom_kernel.decompress_operation(input_split[i]) 203 | # input_split[i] = input_split[i].view(input_size) 204 | # input_split[i] = torch.matmul(input_split[i], fc1_weight) 205 | # input_split[i] = torch.nn.functional.relu(input_split[i]) 206 | # input_split[i] = torch.matmul(input_split[i], fc2_weight) 207 | # input_split[i] = input_split[i].view((-1, input_size[-1])) 208 | input_split[i] = schemoe_custom_kernel.compress_operation(input_split[i], args.encode, "naive") 209 | input_split[i] = schemoe_custom_kernel.comm_operation(input_split[i]) 210 | for i in range(a2a_ffn_overlap_degree): 211 | input_split[i] 
= schemoe_custom_kernel.decompress_operation(input_split[i]) 212 | # input_split[i] = input_split[i].view(input_size) 213 | output = torch.cat(input_split, dim=1) 214 | print(output - input) 215 | cuda_end.record() 216 | torch.cuda.synchronize() 217 | if dist_rank == 0: 218 | lst.append(cuda_start.elapsed_time(cuda_end)) 219 | torch.distributed.barrier() 220 | if dist_rank == 0: 221 | print("step:", _) 222 | if dist_rank == 0: 223 | with open(args.log, "a+") as f: 224 | f.write(str(lst[5:]) + "\n") 225 | 226 | 227 | # 512, 1024, 2048, 4096, 8192 228 | 229 | for batch_size in [ 230 | 8, 231 | ]: 232 | for num_tokens in [ 233 | 2048, 234 | ]: 235 | for model_dim in [ 236 | 1024, 237 | ]: 238 | for hidden_size in [ 239 | 1024, 240 | ]: 241 | for num_local_experts in [ 242 | 1, 243 | ]: 244 | for top_value in [ 245 | 2, 246 | ]: 247 | for capacity_factor in [ 248 | 1.2, 249 | ]: 250 | single_case( 251 | batch_size, 252 | num_tokens, 253 | model_dim, 254 | hidden_size, 255 | num_local_experts, 256 | top_value, 257 | args.a2a_ffn_overlap_degree, 258 | capacity_factor, 259 | ) 260 | -------------------------------------------------------------------------------- /PlanMoE/examples/run_mpi.sh: -------------------------------------------------------------------------------- 1 | PYTHON=/home/xinglinpan/miniconda3/envs/eurosys2024fall/bin/python 2 | LD_LIBRARY_PATH="/home/xinglinpan/nccl_2.12.12-1+cuda10.2_x86_64/lib:/home/xinglinpan/zfp/build/lib:/usr/local/cuda-10.2/lib64/" 3 | 4 | NNODES=${#ADDR_LIST[@]} 5 | MASTER_ADDR=${ADDR_LIST[0]} 6 | 7 | for a2a_ffn_overlap_degree in 2; do 8 | mpiexec -x PATH=$PATH -x CUDA_HOME=/usr/local/cuda-10.2/ -x NCCL_DEBUG=WARN -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x MASTER_ADDR=ethgpu9 -x LOCAL_SIZE=4 --prefix /home/xinglinpan/mpi/openmpi-4.1.4/ --host ethgpu9,ethgpu10 -bind-to none $PYTHON launch.py pre_test.py --a2a_ffn_overlap_degree=$a2a_ffn_overlap_degree --log='test.log' --encode='no' 9 | sleep 5s 10 | done 11 | -------------------------------------------------------------------------------- /PlanMoE/experts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/experts/ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | from ..net import zero_gather 6 | 7 | 8 | class FusedExpertsNetwork(torch.nn.Module): 9 | def __init__(self, hidden_size_per_expert, activation_fn=None, activation_fn_with_self=None, output_dim=None): 10 | super().__init__() 11 | self.skip_expert = (int(torch.os.environ.get('SKIP_EXPERT', '0')) != 0) 12 | self.hidden_size_per_expert = hidden_size_per_expert 13 | self.output_dim = output_dim 14 | 15 | if activation_fn_with_self is not None: 16 | assert activation_fn is None, "Option `activation_fn_with_self` has been specified, please keep exactly one of them." 17 | def activation_fn(x): return activation_fn_with_self(x, self) 18 | if activation_fn is None: 19 | def activation_fn(x): return F.relu(x) 20 | self.activation_fn = activation_fn 21 | 22 | def update(self, ctx): 23 | if ctx.sharded_count > 1: 24 | assert self.hidden_size_per_expert % ctx.sharded_count == 0, f"Can't evenly divide hidden_size_per_expert ({self.hidden_size_per_expert}) to {ctx.sharded_count} slices." 
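# Parameter layout built below: each of the local experts owns one slice of the batched tensors,
#   batched_fc1_w : [local_experts, model_dim, hidden_size]
#   batched_fc2_w : [local_experts, hidden_size, output_dim]
# with hidden_size = hidden_size_per_expert // sharded_count, so the per-expert hidden dimension
# is divided evenly among the `sharded_count` devices that share one expert.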
25 | 26 | hidden_size = self.hidden_size_per_expert // ctx.sharded_count 27 | model_dim = ctx.model_dim 28 | local_experts = ctx.num_local_experts 29 | self.output_dim = self.output_dim or model_dim 30 | 31 | fc1_weight = torch.empty(1, local_experts, model_dim, hidden_size) 32 | fc2_weight = torch.empty( 33 | 1, local_experts, hidden_size, self.output_dim) 34 | fc1_bias = torch.empty(1, local_experts, hidden_size) 35 | fc2_bias = torch.empty( 36 | 1, local_experts, (self.output_dim + ctx.sharded_count - 1) // ctx.sharded_count) 37 | 38 | for i in range(local_experts): 39 | fc1 = torch.nn.Linear(model_dim, hidden_size) 40 | fc2 = torch.nn.Linear(hidden_size, self.output_dim) 41 | fc1_weight[0, i, :, :], fc1_bias[0, 42 | i, :] = fc1.weight.t(), fc1.bias 43 | fc2_weight[0, i, :, :], fc2_bias[0, i, 44 | :] = fc2.weight.t(), fc2.bias[:fc2_bias.size(-1)] 45 | 46 | self.register_parameter( 47 | name='batched_fc1_w', param=torch.nn.Parameter(fc1_weight.squeeze(0))) 48 | self.register_parameter( 49 | name='batched_fc2_w', param=torch.nn.Parameter(fc2_weight.squeeze(0))) 50 | self.register_parameter(name='batched_fc1_bias', 51 | param=torch.nn.Parameter(fc1_bias.squeeze(0))) 52 | self.register_parameter(name='batched_fc2_bias', 53 | param=torch.nn.Parameter(fc2_bias.squeeze(0))) 54 | 55 | def extra_repr(self): 56 | return 'model_dim=%d, hidden_size=%d, output_dim=%d, local_experts=%d' % ( 57 | self.batched_fc1_w.size(1), 58 | self.batched_fc1_w.size(2), 59 | self.batched_fc2_w.size(2), 60 | self.batched_fc1_w.size(0) 61 | ) 62 | 63 | def forward(self, x, ctx): 64 | if self.skip_expert: 65 | return x 66 | 67 | # x = x.to(torch.float32) 68 | 69 | batched_fc1_w = self.batched_fc1_w 70 | batched_fc2_w = self.batched_fc2_w 71 | batched_fc1_bias = self.batched_fc1_bias.unsqueeze(1) 72 | batched_fc2_bias = self.batched_fc2_bias.unsqueeze(1) 73 | 74 | assert ctx.ffn_zero_group is None 75 | 76 | y = torch.add(torch.matmul(x, batched_fc1_w), batched_fc1_bias) 77 | y = self.activation_fn(y) 78 | y = torch.add(torch.matmul(y, batched_fc2_w), batched_fc2_bias) 79 | 80 | # y = y.to(torch.float16) 81 | 82 | return y 83 | 84 | def to(self, *args, **kwargs): 85 | self = super().to(*args, **kwargs) 86 | self.fc1_weight = self.fc1_weight.to(*args, **kwargs) 87 | self.fc2_weight = self.fc2_weight.to(*args, **kwargs) 88 | self.fc1_bias = self.fc1_bias.to(*args, **kwargs) 89 | self.fc2_bias = self.fc2_bias.to(*args, **kwargs) 90 | return self 91 | 92 | 93 | ExpertModule = FusedExpertsNetwork -------------------------------------------------------------------------------- /PlanMoE/gates/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/gates/cosine_top.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
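# Cosine-similarity top-k gate: tokens are projected to `proj_dim`, L2-normalized, and scored
# against a column-normalized learned expert embedding (`sim_matrix`); the cosine scores are
# multiplied by a learned temperature whose scale is clamped to at most 100 (= 1 / 0.01) before
# being returned as routing logits.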
3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | class CosineTopKGate(torch.nn.Module): 8 | def __init__(self, model_dim, num_global_experts, k=1, fp32_gate=False, proj_dim=256, init_t=0.5, **options): 9 | super(CosineTopKGate, self).__init__() 10 | self.top_k = min(num_global_experts, int(k)) 11 | self.fp32_gate = fp32_gate 12 | self.temperature = torch.nn.Parameter(torch.log(torch.full([1], 1.0 / init_t)), requires_grad=True) 13 | self.cosine_projector = torch.nn.Linear(model_dim, proj_dim) 14 | self.sim_matrix = torch.nn.Parameter(torch.randn(size=(proj_dim, num_global_experts)), requires_grad=True) 15 | torch.nn.init.normal_(self.sim_matrix, 0, 0.01) 16 | 17 | for opt in options: 18 | if opt not in ('capacity_factor', 'gate_noise'): 19 | raise Exception('Unrecognized argument provided to Gating module: %s' % opt) 20 | 21 | def forward(self, x): 22 | if self.fp32_gate: 23 | x = x.float() 24 | cosine_projector = self.cosine_projector.float() 25 | sim_matrix = self.sim_matrix.float() 26 | else: 27 | cosine_projector = self.cosine_projector 28 | sim_matrix = self.sim_matrix 29 | logits = torch.matmul(F.normalize(cosine_projector(x), dim=1), 30 | F.normalize(sim_matrix, dim=0)) 31 | logit_scale = torch.clamp(self.temperature, max=torch.log(torch.tensor(1. / 0.01))).exp() 32 | logits = logits * logit_scale 33 | return logits 34 | 35 | 36 | Gate = CosineTopKGate 37 | -------------------------------------------------------------------------------- /PlanMoE/gates/top.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | 6 | class LinearTopKGate(torch.nn.Module): 7 | def __init__(self, model_dim, num_global_experts, k=1, fp32_gate=False, **options): 8 | super().__init__() 9 | try: 10 | self.wg = torch.nn.Linear(model_dim, num_global_experts, bias=False, dtype=torch.float32 if fp32_gate else None) 11 | except: 12 | self.wg = torch.nn.Linear(model_dim, num_global_experts, bias=False) 13 | self.top_k = min(num_global_experts, int(k)) 14 | self.fp32_gate = fp32_gate 15 | 16 | for opt in options: 17 | if opt not in ('capacity_factor', 'gate_noise'): 18 | raise Exception('Unrecognized argument provided to Gating module: %s' % opt) 19 | 20 | def forward(self, x): 21 | if self.fp32_gate: 22 | x = x.float() 23 | wg = self.wg.float() 24 | else: 25 | wg = self.wg 26 | return wg(x) 27 | 28 | Gate = LinearTopKGate 29 | -------------------------------------------------------------------------------- /PlanMoE/impls/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/impls/fast_dispatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast 5 | 6 | import logging 7 | import torch 8 | from torch import Tensor 9 | 10 | from .jit_compiler import IS_HIP_EXTENSION 11 | from ..jit_kernels import sparse as jit_kernel 12 | from ..jit_kernels.gating import fast_cumsum_sub_one 13 | from .communicate import get_world_rank, simple_all_reduce 14 | from . 
import losses 15 | 16 | class GatingEncoder(torch.autograd.Function): 17 | @staticmethod 18 | def forward(ctx: Any, config: Any, reshaped_input: Tensor, *gates_): 19 | ctx.reshaped_input = reshaped_input 20 | ctx.config = config 21 | if gates_: 22 | ctx.gates_h2 = [x.view(-1, 1).repeat(1, 2) if x.dtype == torch.float16 else x for x in gates_] 23 | else: 24 | ctx.gates_h2 = [ctx.config.ones_helper] * len(ctx.config.indices_) 25 | 26 | dispatched_input = torch.zeros([ctx.config.num_global_experts * ctx.config.capacity, ctx.config.model_dim], dtype=reshaped_input.dtype, device=reshaped_input.device) 27 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 28 | ctx.config.func_fwd(g, i, l, reshaped_input, dispatched_input, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 29 | return dispatched_input 30 | 31 | @staticmethod 32 | def backward(ctx: Any, dispatched_input: Tensor): 33 | dispatched_input = dispatched_input.contiguous() 34 | last_result = None 35 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 36 | grad_data = torch.empty(ctx.reshaped_input.shape, dtype=dispatched_input.dtype, device=dispatched_input.device) 37 | ctx.config.func_bwd_data(g, i, l, grad_data, dispatched_input, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 38 | last_result = grad_data if last_result is None else last_result + grad_data 39 | 40 | grad_gates = [] 41 | if id(ctx.gates_h2[0]) != id(ctx.config.ones_helper): 42 | for i, l in zip(ctx.config.indices_, ctx.config.locations_): 43 | grad_gates1_s = torch.empty([ctx.config.sample_size,], dtype=dispatched_input.dtype, device=dispatched_input.device) 44 | ctx.config.func_bwd_gate(grad_gates1_s, i, l, ctx.reshaped_input, dispatched_input, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 45 | grad_gates.append(grad_gates1_s) 46 | return (None, last_result, *grad_gates) 47 | 48 | 49 | class GatingDecoder(torch.autograd.Function): 50 | @staticmethod 51 | def forward(ctx: Any, config: Any, expert_output: Tensor, *gates_): 52 | ctx.expert_output = expert_output 53 | ctx.config = config 54 | if gates_: 55 | ctx.gates_h2 = [x.view(-1, 1).repeat(1, 2) if x.dtype == torch.float16 else x for x in gates_] 56 | else: 57 | ctx.gates_h2 = [ctx.config.ones_helper] * len(ctx.config.indices_) 58 | 59 | last_result = None 60 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 61 | single_output = torch.empty([config.sample_size, config.model_dim], dtype=expert_output.dtype, device=expert_output.device) 62 | config.func_bwd_data(g, i, l, single_output, expert_output, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 63 | last_result = single_output if last_result is None else last_result + single_output 64 | return last_result 65 | 66 | @staticmethod 67 | def backward(ctx: Any, combined_output: Tensor): 68 | combined_output = combined_output.contiguous() 69 | grad_expert_output = torch.zeros(ctx.expert_output.shape, dtype=combined_output.dtype, device=combined_output.device) 70 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 71 | ctx.config.func_fwd(g, i, l, combined_output, grad_expert_output, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 72 | 73 | grad_gates = [] 74 | if id(ctx.gates_h2[0]) != id(ctx.config.ones_helper): 75 | for i, l in zip(ctx.config.indices_, ctx.config.locations_): 
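# Runs only when real gate values were supplied (not the constant ones_helper): for each top-k
# routing path, func_bwd_gate produces one scalar per token, which is effectively the dot product
# between the incoming combined-output gradient and the expert output stored at that token's
# dispatched (index, location) slot.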
76 | grad_gates1_s = torch.empty([ctx.config.sample_size,], dtype=combined_output.dtype, device=combined_output.device) 77 | ctx.config.func_bwd_gate(grad_gates1_s, i, l, combined_output, ctx.expert_output, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 78 | grad_gates.append(grad_gates1_s) 79 | return (None, grad_expert_output, *grad_gates) 80 | 81 | 82 | class TutelMoeFastDispatcher: 83 | 84 | kernel_pool = dict() 85 | ones_helper = None 86 | 87 | def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype): 88 | self.num_global_experts = int(num_global_experts) 89 | self.capacity = int(capacity) 90 | self.model_dim = int(model_dim) 91 | self.dtype = dispatch_dtype 92 | if IS_HIP_EXTENSION or dispatch_dtype != torch.float16: 93 | self.dtype = torch.float32 94 | self.original_dtype = dispatch_dtype 95 | self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1) 96 | self.is_cuda = None 97 | 98 | def update(self, indices_, locations_, gates_, capacity=None, is_postscore=True): 99 | self.indices_ = [x.to(torch.int32).view(-1) for x in indices_] 100 | self.locations_ = [x.to(torch.int32) for x in locations_] 101 | self.gates_ = [x.to(self.dtype) for x in gates_] 102 | self.is_postscore = is_postscore 103 | self.sample_size, self.capacity = int(self.indices_[0].size(0)), int(capacity) or self.capacity 104 | 105 | if self.is_cuda != indices_[0].is_cuda: 106 | self.is_cuda = indices_[0].is_cuda 107 | if self.is_cuda not in TutelMoeFastDispatcher.kernel_pool: 108 | self.func_fwd = jit_kernel.create_forward(self.dtype, indices_[0].is_cuda) 109 | self.func_bwd_data = jit_kernel.create_backward_data(self.dtype, indices_[0].is_cuda) 110 | self.func_bwd_gate = jit_kernel.create_backward_gate(self.dtype, indices_[0].is_cuda) 111 | TutelMoeFastDispatcher.kernel_pool[self.is_cuda] = self.func_fwd, self.func_bwd_data, self.func_bwd_gate 112 | else: 113 | self.func_fwd, self.func_bwd_data, self.func_bwd_gate = TutelMoeFastDispatcher.kernel_pool[self.is_cuda] 114 | 115 | if TutelMoeFastDispatcher.ones_helper is None or TutelMoeFastDispatcher.ones_helper.size(0) < self.sample_size: 116 | TutelMoeFastDispatcher.ones_helper = torch.ones([self.sample_size, 2], dtype=self.dtype, device=self.indices_[0].device) 117 | if TutelMoeFastDispatcher.ones_helper.is_cuda != self.indices_[0].is_cuda: 118 | TutelMoeFastDispatcher.ones_helper = torch.ones([TutelMoeFastDispatcher.ones_helper.size(0), 2], dtype=self.dtype, device=self.indices_[0].device) 119 | self.ones_helper = TutelMoeFastDispatcher.ones_helper 120 | 121 | def encode(self, data): 122 | if self.is_postscore: 123 | return GatingEncoder.apply(self, data.to(self.dtype)).to(self.original_dtype) 124 | else: 125 | return GatingEncoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype) 126 | 127 | def decode(self, data): 128 | if self.is_postscore: 129 | return GatingDecoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype) 130 | else: 131 | return GatingDecoder.apply(self, data.to(self.dtype)).to(self.original_dtype) 132 | 133 | fast_dispatcher = TutelMoeFastDispatcher 134 | 135 | def compute_sorted_location(x, importance_scores): 136 | sorted_x = x[importance_scores.argsort(dim=0)] 137 | sorted_cumsum = fast_cumsum_sub_one(sorted_x) * sorted_x 138 | return sorted_cumsum[importance_scores.argsort(dim=0).argsort(dim=0)] 139 | 140 | def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=1.0, batch_prioritized_routing=False, 
normalize_gate=True, group=None, alignment=1): 141 | num_global_experts = int(scores.size(1)) 142 | top_k, top_k_original = min(top_k, num_global_experts), top_k 143 | topk_indices = torch.topk(scores, top_k, dim=1).indices 144 | 145 | indices_s = [x.view(-1) for x in topk_indices.chunk(top_k, dim=1)] 146 | 147 | masks_se = [losses._one_hot_with_dtype(x, num_classes=num_global_experts, dtype=x.dtype) for x in indices_s] 148 | gates_s = [(scores * x).sum(dim=1) for x in masks_se] 149 | 150 | l_loss = loss_fn(scores, topk_indices) if loss_fn is not None else None 151 | 152 | if batch_prioritized_routing: 153 | importance_scores = -1 * scores.max(dim=1)[0] 154 | compute_location = lambda x: compute_sorted_location(x, importance_scores) 155 | else: 156 | compute_location = fast_cumsum_sub_one 157 | 158 | locations1 = compute_location(masks_se[0]) 159 | 160 | locations_s = [torch.sum(locations1 * masks_se[0], dim=1).to(torch.int32)] 161 | 162 | if top_k > 1: 163 | acc_base = None 164 | for k in range(1, top_k): 165 | acc_base = torch.sum(masks_se[k - 1], dim=0, keepdim=True) if acc_base is None else acc_base + torch.sum(masks_se[k - 1], dim=0, keepdim=True) 166 | locations2 = compute_location(masks_se[k]) 167 | locations2 += acc_base 168 | locations_s.append(torch.sum(locations2 * masks_se[k], dim=1).to(torch.int32)) 169 | 170 | if normalize_gate: 171 | denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps) 172 | gates_s = [x / denom_s for x in gates_s] 173 | 174 | indices_s = [x.to(torch.int32) for x in indices_s] 175 | 176 | #=== The count of Expert token === 177 | #print(torch.bincount(indices_s[0]) + torch.bincount(indices_s[1])) 178 | 179 | samples_per_expert = ((int(scores.size(0)) + num_global_experts - 1) // num_global_experts) 180 | if capacity_factor > 0: 181 | capacity = top_k * int(capacity_factor * samples_per_expert) 182 | else: 183 | capacity = torch.max(torch.cat(locations_s, dim=0)) 184 | capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX)) + 1 185 | if capacity_factor < 0: 186 | capacity = min(capacity, top_k * int(-capacity_factor * ((int(scores.size(0)) + num_global_experts - 1) // num_global_experts))) 187 | 188 | remainder = capacity % alignment 189 | if remainder > 0: 190 | capacity = capacity + alignment - remainder 191 | 192 | #if get_world_rank(group) == 0: 193 | # logging.info(f"Capacity = {capacity}, real-time capacity-factor for top-{top_k_original} = {capacity / (top_k * samples_per_expert)}") 194 | 195 | return (num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss 196 | 197 | def fast_encode(data, critial_data, is_postscore=True): 198 | num_global_experts = critial_data[0] 199 | dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype) 200 | dispatcher.update(*critial_data[1:], is_postscore=is_postscore) 201 | return dispatcher.encode(data).view(num_global_experts, -1, data.size(-1)) 202 | 203 | def fast_decode(data, critial_data, is_postscore=True): 204 | num_global_experts = critial_data[0] 205 | dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype) 206 | dispatcher.update(*critial_data[1:], is_postscore=is_postscore) 207 | return dispatcher.decode(data).view(-1, data.size(-1)) 208 | -------------------------------------------------------------------------------- /PlanMoE/impls/jit_compiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import os, tempfile 6 | 7 | try: 8 | import schemoe_custom_kernel 9 | except: 10 | raise Exception("Cannot import JIT optimized kernels. Did you forget to install Custom Kernel Extension?") 11 | 12 | try: 13 | from torch.utils.cpp_extension import IS_HIP_EXTENSION, CUDA_HOME, ROCM_HOME 14 | except: 15 | IS_HIP_EXTENSION = False 16 | 17 | if hasattr(schemoe_custom_kernel, 'update_sdk_home'): 18 | SDK_HOME = CUDA_HOME if not IS_HIP_EXTENSION else ROCM_HOME 19 | schemoe_custom_kernel.update_sdk_home(torch.tensor([ord(x) for x in SDK_HOME] + [0], dtype=torch.int8, device='cpu')) 20 | 21 | class JitCompiler: 22 | @staticmethod 23 | def create_raw(source): 24 | torch.cuda.init() 25 | if not hasattr(schemoe_custom_kernel, 'inject_source'): 26 | raise Exception('CUDA support is disabled during Tutel installation. Please configure CUDA correctly and reinstall Tutel to enable CUDA support, or report Tutel installation logs for help.') 27 | __ctx__ = schemoe_custom_kernel.inject_source(source) 28 | 29 | def func(*inputs, extra=[], blocks=[]): 30 | schemoe_custom_kernel.invoke(inputs, extra, blocks, __ctx__) 31 | return func 32 | 33 | @staticmethod 34 | def generate_kernel(keyword_dict, template): 35 | for key in keyword_dict: 36 | template = template.replace('@%s@' % key, str(keyword_dict[key])) 37 | return JitCompiler.create_raw(template) 38 | 39 | @staticmethod 40 | def generate_cpu_kernel(kernel_type): 41 | def func(*inputs, extra=[]): 42 | if inputs[0].dtype is torch.float32: 43 | schemoe_custom_kernel.invoke_cpu_fp32(inputs, extra, kernel_type) 44 | elif inputs[0].dtype is torch.float64: 45 | schemoe_custom_kernel.invoke_cpu_fp64(inputs, extra, kernel_type) 46 | else: 47 | raise Exception("CPU kernel only supports float32 and float64!") 48 | 49 | return func 50 | 51 | def create_cuda_kernel(source, keyword_dict={}): 52 | return JitCompiler.generate_kernel(keyword_dict, source) 53 | 54 | -------------------------------------------------------------------------------- /PlanMoE/impls/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
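# Quick illustration (standalone, not library code) of the `@key@` placeholder
# convention used by JitCompiler.generate_kernel() above: every `@name@` token in
# the CUDA source template is replaced by str(keyword_dict[name]) before the source
# is handed to schemoe_custom_kernel.inject_source().
template = 'extern "C" __global__ void execute(@dtype@* data) { /* ... */ }'
keyword_dict = {'dtype': 'float'}
for key in keyword_dict:
    template = template.replace('@%s@' % key, str(keyword_dict[key]))
assert template.startswith('extern "C" __global__ void execute(float* data)')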
3 | 4 | import torch 5 | from torch.distributions.normal import Normal 6 | 7 | def _one_hot_with_dtype(data, num_classes, dtype, hot_value=1): 8 | result = torch.zeros([data.size(0), num_classes], device=data.device, dtype=dtype) 9 | result.scatter_(1, data.unsqueeze(-1), hot_value) 10 | return result 11 | 12 | def gshard_loss(scores_w_noise, top_ids): 13 | num_samples, num_global_experts = int(scores_w_noise.size(0)), int(scores_w_noise.size(1)) 14 | mask = _one_hot_with_dtype(top_ids[:, 0], num_global_experts, dtype=scores_w_noise.dtype, 15 | hot_value=num_global_experts / num_samples) 16 | me = torch.sum(scores_w_noise, dim=0) 17 | ce = torch.sum(mask, dim=0) 18 | l_aux = torch.sum(me * ce) / num_samples 19 | return l_aux 20 | 21 | def load_importance_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise): 22 | def load_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise): 23 | normal = Normal( 24 | torch.tensor([0.0], device=scores_wo_noise.device), 25 | torch.tensor([gate_noise / num_global_experts], device=scores_wo_noise.device), 26 | ) 27 | threshold = topk_logits[:, -1].view(-1, 1).float() 28 | diff = scores_wo_noise.float() - threshold.float() 29 | prob = normal.cdf(diff) 30 | Load = prob.sum(0) 31 | l_load = Load.float().var() / (Load.float().mean() ** 2 + 1e-10) 32 | return l_load 33 | 34 | def importance_loss(scores_wo_noise): 35 | Impi = scores_wo_noise.float().sum(0) 36 | l_imp = Impi.float().var() / (Impi.float().mean() ** 2 + 1e-10) 37 | 38 | return l_imp 39 | 40 | l_imp = importance_loss(scores_wo_noise) 41 | l_load = load_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise) 42 | return (l_imp + l_load) / 2.0 -------------------------------------------------------------------------------- /PlanMoE/impls/moe_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast 5 | 6 | import copy 7 | import os 8 | import re 9 | import time 10 | import logging 11 | import collections 12 | import importlib 13 | 14 | import math 15 | import torch 16 | from torch import Tensor 17 | import torch.distributed as dist 18 | from torch.nn import ModuleList 19 | import torch.nn.functional as F 20 | 21 | from ..impls import communicate as C 22 | from ..impls.fast_dispatch import fast_encode, fast_decode, extract_critical 23 | from ..impls.overlap import a2a_ffn_overlap_forward 24 | from . import losses 25 | 26 | 27 | class MOELayer(torch.nn.Module): 28 | """Tutel optimized MOELayer 29 | """ 30 | @staticmethod 31 | def global_expert_count(num_local_experts, group=None): 32 | if not isinstance(num_local_experts, int): 33 | num_local_experts = -int(1 / (num_local_experts + 1e-5)) 34 | world_size = C.get_world_size(group) 35 | if num_local_experts == 0: 36 | raise Exception( 37 | "Invalid value of num_local_experts: %d" % num_local_experts) 38 | if num_local_experts > 0: 39 | return num_local_experts * world_size 40 | assert world_size % - \ 41 | num_local_experts == 0, "Excepting {-num_local_experts} devices to share an expert param, while global device count is {world_size}." 
42 | return world_size // -num_local_experts 43 | 44 | def __init__( 45 | self, 46 | gate_type, 47 | model_dim: int, 48 | experts=None, 49 | scan_expert_func=None, 50 | result_func=None, 51 | group=None, 52 | seeds=None, 53 | a2a_ffn_overlap_degree=1, 54 | is_postscore=True, 55 | batch_prioritized_routing=False, 56 | normalize_gate=True, 57 | is_gshard_loss=True, 58 | parallel_type='auto', 59 | use_2dh=False, 60 | index=0, 61 | compress_name='no', 62 | comm_name='naive', 63 | **kwargs 64 | ): 65 | super().__init__() 66 | assert model_dim % 2 == 0, "Model_dim (%s) must be even value, while this Model_dim mod 2 > 0." % model_dim 67 | group = group or dist.group.WORLD 68 | 69 | if 'pad_samples' in kwargs: 70 | logging.warning( 71 | f"`pad_samples` option in Tutel Moe-layer has been deprecated, as Tutel always assumes `pad_samples=False` for better efficiency.") 72 | kwargs.pop('pad_samples') 73 | for k in kwargs: 74 | raise Exception( 75 | 'Unrecognized argument provided to Tutel Moe-layer: %s' % k) 76 | 77 | self.group = group 78 | self.result_func = result_func 79 | self.skip_moe = (int(os.environ.get('SKIP_MOE', '0')) != 0) 80 | 81 | self.num_local_experts = experts.pop('count_per_node', 1) 82 | self.num_global_experts = MOELayer.global_expert_count( 83 | self.num_local_experts, self.group) 84 | 85 | self.world_size = C.get_world_size(self.group) 86 | if self.num_global_experts < self.world_size: 87 | sharded_count = self.world_size // self.num_global_experts 88 | self.num_local_experts = 1 89 | self.ffn_zero_group = C.create_groups_from_world( 90 | group_count=self.num_global_experts).model_group 91 | else: 92 | sharded_count = 1 93 | self.ffn_zero_group = None 94 | 95 | if sharded_count == 1: 96 | self.auto_parallel, self.use_model_parallel = False, False 97 | elif parallel_type == 'auto': 98 | self.auto_parallel, self.use_model_parallel = True, False 99 | else: 100 | self.auto_parallel, self.use_model_parallel = False, ( 101 | parallel_type == 'model') 102 | 103 | self.model_dim = model_dim 104 | self.sharded_count = sharded_count 105 | 106 | self.is_postscore = is_postscore 107 | self.batch_prioritized_routing = batch_prioritized_routing 108 | if int(os.environ.get('BATCH_PRIO', 0)) != 0: 109 | self.batch_prioritized_routing = True 110 | self.normalize_gate = normalize_gate 111 | self.is_gshard_loss = is_gshard_loss 112 | 113 | self.a2a_ffn_overlap_degree = a2a_ffn_overlap_degree 114 | self.use_2dh = use_2dh 115 | self.compress_name = compress_name 116 | self.comm_name = comm_name 117 | 118 | if seeds is not None and seeds[1] is not None: 119 | torch.manual_seed(seeds[1]) 120 | 121 | experts_type = experts.pop('type') 122 | if experts_type == 'custom': 123 | self.experts = cast(ModuleList, experts['module']) 124 | else: 125 | assert re.match( 126 | r'[a-zA-Z0-9\_]+', experts_type), "Expert type must only include digits, letters and underline characters." 127 | try: 128 | fused_experts = importlib.import_module( 129 | f'...experts.{experts_type}', __name__) 130 | except ModuleNotFoundError: 131 | raise Exception( 132 | 'Builtin expert type is not recognized: %s' % experts_type) 133 | 134 | if experts_type == 'ffn': 135 | assert 'fused_custom_fn' not in experts, "`fused_custom_fn` option for Tutel Moe-layer has been deprecated, please follows helloworld_from_scratch.py for custom construction instead." 
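# Construction sketch (annotation only; the argument values below are illustrative
# and not taken from examples/):
#
#   moe = MOELayer(
#       gate_type={'type': 'top', 'k': 2},
#       model_dim=1024,
#       experts={'type': 'ffn', 'count_per_node': 2, ...ffn-specific options...},
#       # or experts={'type': 'custom', 'module': ModuleList(...), 'count_per_node': N}
#       a2a_ffn_overlap_degree=2,
#       compress_name='no',
#       comm_name='naive',
#   )
#
# 'count_per_node' and 'type' are popped from `experts`; the remaining keys are
# forwarded to the chosen experts module (PlanMoE/experts/ffn.py for type 'ffn'),
# while any unrecognized keyword argument to MOELayer itself raises immediately.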
136 | assert 'implicit_dropout_p' not in experts, "`implicit_dropout_p` option for Tutel Moe-layer has been deprecated, please use torch.nn.Dropout(p=implicit_dropout_p) on custom activation_fn (for fc1_dropout) and after Tutel Moe-layer (for fc2_dropout) instead." 137 | 138 | self.experts = fused_experts.ExpertModule(**experts) 139 | 140 | self.experts.update(self) 141 | 142 | 143 | if scan_expert_func is not None: 144 | for n, p in self.experts.named_parameters(): 145 | scan_expert_func(n, p) 146 | for n, p in self.experts.named_parameters(): 147 | setattr(p, '_tutel_expert', True) 148 | 149 | if isinstance(gate_type, str): 150 | assert re.match( 151 | r'^Top[0-9]+Gate$', gate_type), "Unrecognized gate_type: %s" % gate_type 152 | top_k = int(gate_type[3:-4]) 153 | logging.warning( 154 | f"gate_type value `{gate_type}` in Tutel Moe-layer has been deprecated, please use gate_type = {{'type': 'top', 'k': {top_k}}} instead.") 155 | gate_type = {'type': 'top', 'k': top_k} 156 | 157 | if not isinstance(gate_type, list): 158 | gate_type = [gate_type] 159 | 160 | self.gates = [] 161 | for gi, single_gate_type in enumerate(gate_type): 162 | gate_type = single_gate_type['type'] 163 | single_gate_type.pop('type') 164 | assert re.match( 165 | r'[a-zA-Z0-9\_]+', gate_type), "Gate type must only include digits, letters and underline characters." 166 | 167 | if seeds is not None and seeds[0] is not None: 168 | torch.manual_seed(seeds[0] + gi) 169 | try: 170 | single_gate = importlib.import_module( 171 | f'...gates.{gate_type}', __name__) 172 | except ModuleNotFoundError: 173 | raise Exception("Unrecognized gate_type: %s" % gate_type) 174 | 175 | gate_module = single_gate.Gate( 176 | model_dim=self.model_dim, num_global_experts=self.num_global_experts, **single_gate_type) 177 | if not hasattr(gate_module, 'gate_noise'): 178 | gate_module.gate_noise = single_gate_type.get( 179 | 'gate_noise', 0.0) 180 | if not hasattr(gate_module, 'capacity_factor'): 181 | gate_module.capacity_factor = single_gate_type.get( 182 | 'capacity_factor', float(os.environ.get('CAP_FACTOR', 1.0))) 183 | 184 | self.gates += [gate_module] 185 | 186 | self.gates = ModuleList(self.gates) 187 | 188 | if seeds is not None and len(seeds) > 2 and seeds[2] is not None: 189 | torch.manual_seed(seeds[2]) 190 | self.save_count = 0 191 | 192 | def extra_repr(self): 193 | return 'Top-K(s) = %s, Total-Experts = %d [managed by %d device(s)],' % ( 194 | [f'k={x.top_k}, noise={x.gate_noise}' for x in self.gates], 195 | self.num_global_experts, 196 | self.world_size, 197 | ) 198 | 199 | def get_parameter_iterator(self, param_type): 200 | if param_type == 'gate': 201 | return self.gates.named_parameters() 202 | elif param_type == 'local_experts': 203 | return self.experts.named_parameters() 204 | else: 205 | raise Exception( 206 | "Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." 
% param_type) 207 | 208 | def expert_local(self, x, reserve_shape): 209 | y = self.experts(x.view(x.size(0), x.size(1), *reserve_shape), self) 210 | self.protected_shape = y.shape 211 | return y.reshape(y.size(0), y.size(1), -1) 212 | 213 | def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1): 214 | if self.skip_moe: 215 | result_output = input 216 | result_output.l_aux = None 217 | return self.result_func(result_output) if self.result_func is not None else result_output 218 | 219 | original_shape, original_dtype = input.shape, input.dtype 220 | assert len( 221 | original_shape) >= 2, "Input data must be at least 2D tensor: (s)amples, .., (m)odel_dim" 222 | 223 | x = input.reshape(-1, original_shape[-reserve_dims:].numel()) 224 | for p in self.experts.parameters(): 225 | x = x.to(p.dtype) 226 | break 227 | gctx = self.gates[gate_index] 228 | a2a_ffn_overlap_degree = a2a_ffn_overlap_degree if a2a_ffn_overlap_degree is not None else self.a2a_ffn_overlap_degree 229 | 230 | def routing(): 231 | logits = gctx(x) 232 | 233 | if self.training and gctx.gate_noise > 0: 234 | logits_w_noise = logits + gctx.gate_noise * \ 235 | torch.randn_like(logits) / self.num_global_experts 236 | else: 237 | logits_w_noise = logits 238 | 239 | scores = F.softmax(logits_w_noise, dim=1) 240 | if self.is_gshard_loss: 241 | def _loss_fn(gates, topk_ids): return losses.gshard_loss( 242 | gates, topk_ids) 243 | else: 244 | def _loss_fn(gates, topk_ids): return losses.load_importance_loss( 245 | F.softmax(logits, dim=1), logits_w_noise.gather( 246 | index=topk_ids, dim=1), 247 | self.num_global_experts, gctx.gate_noise) 248 | return logits.dtype, extract_critical(scores, 249 | top_k=gctx.top_k if top_k is None else top_k, 250 | loss_fn=_loss_fn, 251 | capacity_factor=gctx.capacity_factor if capacity_factor is None else capacity_factor, 252 | batch_prioritized_routing=self.batch_prioritized_routing, 253 | normalize_gate=self.normalize_gate, 254 | group=self.group, 255 | alignment=4 * self.sharded_count * a2a_ffn_overlap_degree 256 | ) 257 | 258 | if x.is_cuda: 259 | with torch.cuda.amp.autocast(enabled=False): 260 | logits_dtype, (crit, l_aux) = routing() 261 | else: 262 | logits_dtype, (crit, l_aux) = routing() 263 | 264 | # x = x.to(torch.float16) 265 | y = fast_encode(x.to(logits_dtype), crit, 266 | self.is_postscore).to(x.dtype) 267 | #y = ((y - _min) / (_max - _min) * 255).to(torch.uint8) 268 | 269 | if self.auto_parallel: 270 | self.use_model_parallel = (y.numel( 271 | ) * (self.sharded_count - 1) * 2 < sum([x.numel() for x in self.experts.parameters()])) 272 | 273 | if self.num_global_experts < self.world_size: 274 | if self.use_model_parallel: 275 | y = y.repeat(1, self.sharded_count, 1).view( 276 | self.world_size, -1, y.size(2)) 277 | else: 278 | y = y.view(self.world_size, -1, y.size(2)) 279 | 280 | # if a2a_ffn_overlap_degree > 1 and y.is_cuda: 281 | # def expert_fn(expert_input): 282 | # return self.expert_local(expert_input, original_shape[-reserve_dims:]) 283 | # y = a2a_ffn_overlap_forward(y, expert_fn=expert_fn, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, use_2dh=self.use_2dh, group=self.group) 284 | # else: 285 | # y = C.all_to_all(y, 1, 0, use_2dh=self.use_2dh, group=self.group) 286 | # y = self.expert_local(y, original_shape[-reserve_dims:]) 287 | # y = C.all_to_all(y, 0, 1, use_2dh=self.use_2dh, group=self.group) 288 | if self.training: 289 | self.save_count = self.save_count + 1 290 | # is_compress = self.training and 
self.save_count > 100000 291 | is_compress = True 292 | 293 | def expert_fn(expert_input): 294 | return self.expert_local(expert_input, original_shape[-reserve_dims:]) 295 | y = a2a_ffn_overlap_forward(y, expert_fn=expert_fn, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, 296 | use_2dh=self.use_2dh, group=self.group, compress_name=self.compress_name, 297 | comm_name=self.comm_name) 298 | 299 | if self.num_global_experts < self.world_size: 300 | if self.use_model_parallel: 301 | y = torch.sum(y.view(self.num_global_experts, 302 | self.sharded_count, -1, y.size(2)), dim=1) 303 | else: 304 | y = y.view(self.num_global_experts, -1, y.size(2)) 305 | 306 | y = fast_decode(y.to(logits_dtype), crit, self.is_postscore) 307 | 308 | y = y.view(list(original_shape)).to(original_dtype) 309 | #y = y.view(list(original_shape[:-reserve_dims]) + list(self.protected_shape[-reserve_dims:])).to(original_dtype) 310 | self.l_aux = y.l_aux = l_aux 311 | return self.result_func(y) if self.result_func is not None else y 312 | 313 | 314 | moe_layer = MOELayer 315 | -------------------------------------------------------------------------------- /PlanMoE/impls/overlap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import schemoe_custom_kernel 6 | from time import time 7 | from torch.distributed import get_rank 8 | 9 | from ..impls import communicate as C 10 | 11 | 12 | class Compress(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, input, compress_name, comm_name): 15 | input = schemoe_custom_kernel.compress_operation(input, compress_name, comm_name) 16 | return input 17 | 18 | @staticmethod 19 | def backward(ctx, grad): 20 | grad = schemoe_custom_kernel.decompress_operation(grad) 21 | return grad, None, None 22 | 23 | 24 | class Decompress(torch.autograd.Function): 25 | @staticmethod 26 | def forward(ctx, input, compress_name, comm_name): 27 | ctx.compress_name = compress_name 28 | ctx.comm_name = comm_name 29 | input = schemoe_custom_kernel.decompress_operation(input) 30 | return input 31 | 32 | @staticmethod 33 | def backward(ctx, grad): 34 | return schemoe_custom_kernel.compress_operation(grad, ctx.compress_name, ctx.comm_name), None, None 35 | 36 | 37 | class Comm(torch.autograd.Function): 38 | @staticmethod 39 | def forward(ctx, input): 40 | return schemoe_custom_kernel.comm_operation(input) 41 | 42 | @staticmethod 43 | def backward(ctx, grad): 44 | return schemoe_custom_kernel.comm_operation(grad) 45 | 46 | 47 | def a2a_ffn_overlap_forward(input, expert_fn, a2a_ffn_overlap_degree, use_2dh, group, compress_name, comm_name): 48 | split_dim = 1 49 | assert a2a_ffn_overlap_degree <= C.AllToAllStatus.max_num_split, "Excepting a2a_ffn_overlap_degree (%d) <= AllToAllStatus.max_num_split (%d)." % ( 50 | a2a_ffn_overlap_degree, C.AllToAllStatus.max_num_split) 51 | assert input.shape[split_dim] % a2a_ffn_overlap_degree == 0, "Excepting input.shape[%d] (%d) be multiple of a2a_ffn_overlap_degree (%d)." 
% ( 52 | split_dim, input.shape[split_dim], a2a_ffn_overlap_degree) 53 | C.AllToAllStatus.init(group, a2a_ffn_overlap_degree, split_dim) 54 | 55 | split_size = input.shape[split_dim] // a2a_ffn_overlap_degree 56 | input_split = list(input.split(split_size, dim=split_dim)) 57 | schemoe_custom_kernel.clear_ptr_lst() 58 | 59 | for i in range(a2a_ffn_overlap_degree): 60 | input_split[i] = input_split[i].contiguous() 61 | input_split[i] = Compress.apply(input_split[i], compress_name, comm_name) 62 | # for i in range(a2a_ffn_overlap_degree): 63 | input_split[i] = Comm.apply(input_split[i]) 64 | 65 | for i in range(a2a_ffn_overlap_degree): 66 | input_split[i] = Decompress.apply(input_split[i], compress_name, comm_name) 67 | input_split[i] = C.post_expert_permute( 68 | expert_fn(C.pre_expert_permute(input_split[i], group=group)), group=group 69 | ) 70 | input_split[i] = Compress.apply(input_split[i], compress_name, comm_name) 71 | # for i in range(a2a_ffn_overlap_degree): 72 | input_split[i] = Comm.apply(input_split[i]) 73 | for i in range(a2a_ffn_overlap_degree): 74 | input_split[i] = Decompress.apply(input_split[i], compress_name, comm_name) 75 | output = torch.cat(input_split, dim=split_dim).contiguous() 76 | return output 77 | -------------------------------------------------------------------------------- /PlanMoE/jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .impls.jit_compiler import create_cuda_kernel 5 | -------------------------------------------------------------------------------- /PlanMoE/jit_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/jit_kernels/gating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import torch 6 | import logging 7 | from ..impls.jit_compiler import schemoe_custom_kernel 8 | 9 | torch.ops.load_library(schemoe_custom_kernel.__file__) 10 | 11 | use_fast_cumsum = (int(os.environ.get('FAST_CUMSUM', '1')) == 1) 12 | 13 | def torch_cumsum_sub_one(mask1): 14 | locations1 = torch.cumsum(mask1, dim=0) - 1 15 | return locations1 16 | 17 | def fast_cumsum_sub_one(data, dim=0): 18 | if data.dim() != 2 or dim != 0: 19 | raise Exception("Unimplemented fast_cumsum_sub_one() of data = %s and dim = %s" % (data.size(), dim)) 20 | if not data.is_cuda or not use_fast_cumsum: 21 | return torch_cumsum_sub_one(data) 22 | return torch.ops.tutel_ops.cumsum(data) 23 | -------------------------------------------------------------------------------- /PlanMoE/jit_kernels/sparse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
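# Tiny reference (plain torch, standalone) for what fast_cumsum_sub_one() in
# jit_kernels/gating.py above computes: a per-column cumulative count minus one,
# i.e. torch.cumsum(x, dim=0) - 1, which torch.ops.tutel_ops.cumsum mirrors on GPU.
import torch

one_hot = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.int32)
locations = torch.cumsum(one_hot, dim=0) - 1
# At each token's chosen expert column, the value is that token's slot index within
# the expert's capacity bucket: [[0,-1],[0,0],[1,0],[1,1]] for the mask above.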
3 | 4 | import torch 5 | from ..impls.jit_compiler import JitCompiler 6 | 7 | 8 | def get_kernel_dtype(param_dtype): 9 | if param_dtype == torch.float16: 10 | return '__half2' 11 | elif param_dtype == torch.float32: 12 | return 'float' 13 | else: 14 | raise Exception("Unrecognized data type: %s" % param_dtype) 15 | 16 | 17 | def create_forward(param_dtype, is_cuda=True): 18 | if not is_cuda: 19 | return JitCompiler.generate_cpu_kernel(kernel_type=0) 20 | 21 | return JitCompiler.generate_kernel({'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' 22 | #define __dtype @dtype@ 23 | 24 | extern "C" __global__ __launch_bounds__(1024) void execute(__dtype* __restrict__ gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, __dtype* __restrict__ reshaped_input, __dtype* __restrict__ dispatched_input, int samples, int hidden, int capacity) { 25 | // [thread_extent] blockIdx.x = 512 26 | // [thread_extent] threadIdx.x = 1024 27 | 28 | for (int i = blockIdx.x; i < samples; i += gridDim.x) 29 | if (locations1_s[i] < capacity && indices1_s[i] >= 0) { 30 | #pragma unroll 31 | for (int j = threadIdx.x; j < hidden; j += 1024) 32 | atomicAdd(&dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * (hidden) + j], gates1_s[i] * reshaped_input[i * (hidden) + j]); 33 | } 34 | } 35 | ''') 36 | 37 | 38 | def create_backward_data(param_dtype, is_cuda=True): 39 | if not is_cuda: 40 | return JitCompiler.generate_cpu_kernel(kernel_type=1) 41 | 42 | return JitCompiler.generate_kernel({'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' 43 | #define __dtype @dtype@ 44 | 45 | extern "C" __global__ __launch_bounds__(1024) void execute(__dtype* __restrict__ gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, __dtype* __restrict__ grad_reshaped_input, __dtype* __restrict__ dispatched_input, int samples, int hidden, int capacity) { 46 | // [thread_extent] blockIdx.x = 512 47 | // [thread_extent] threadIdx.x = 1024 48 | 49 | for (int i = blockIdx.x; i < samples; i += gridDim.x) 50 | if (locations1_s[i] < capacity && indices1_s[i] >= 0) { 51 | #pragma unroll 52 | for (int j = threadIdx.x; j < hidden; j += 1024) 53 | grad_reshaped_input[i * hidden + j] = gates1_s[i] * dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * (hidden) + j]; 54 | } else { 55 | #pragma unroll 56 | for (int j = threadIdx.x; j < hidden; j += 1024) 57 | #if @IS_FLOAT@ 58 | grad_reshaped_input[i * hidden + j] = __dtype(0); 59 | #else 60 | grad_reshaped_input[i * hidden + j] = __dtype(0, 0); 61 | #endif 62 | } 63 | } 64 | ''') 65 | 66 | 67 | def create_backward_gate(param_dtype, is_cuda=True): 68 | if not is_cuda: 69 | return JitCompiler.generate_cpu_kernel(kernel_type=2) 70 | 71 | return JitCompiler.generate_kernel({'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' 72 | #define __dtype @dtype@ 73 | 74 | extern "C" __global__ __launch_bounds__(32) void execute(void* __restrict__ grad_gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, __dtype* __restrict__ reshaped_input, __dtype* __restrict__ dispatched_input, int samples, int hidden, int capacity) { 75 | // [thread_extent] blockIdx.x = 512 76 | // [thread_extent] threadIdx.x = 32 77 | for (int index = blockIdx.x; index < samples; index += gridDim.x) { 78 | if (locations1_s[index] >= capacity || indices1_s[index] < 0) { 79 | if (((int)threadIdx.x) == 0) 80 | #if @IS_FLOAT@ 81 | 
((float*)grad_gates1_s)[index] = 0; 82 | #else 83 | ((half*)grad_gates1_s)[index] = __float2half_rn(0.000000e+00f); 84 | #endif 85 | continue; 86 | } 87 | int indice = indices1_s[index] * capacity + locations1_s[index]; 88 | #if @IS_FLOAT@ 89 | __dtype grad_gates1_s_rf = 0.000000e+00f; 90 | #else 91 | __dtype grad_gates1_s_rf = __dtype(0, 0); 92 | #endif 93 | for (int i = threadIdx.x; i < hidden; i += 32) 94 | grad_gates1_s_rf += dispatched_input[indice * (hidden) + i] * reshaped_input[index * (hidden) + i]; 95 | 96 | #if !defined(__HIPCC__) 97 | __dtype red_buf0[1]; 98 | unsigned int mask[1]; 99 | __dtype t0[1]; 100 | red_buf0[(0)] = grad_gates1_s_rf; 101 | mask[(0)] = __activemask(); 102 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 16, 32); 103 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 104 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32); 105 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 106 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32); 107 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 108 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32); 109 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 110 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32); 111 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 112 | red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32); 113 | #else 114 | __shared__ __dtype red_buf0[32]; 115 | __syncthreads(); 116 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = grad_gates1_s_rf; 117 | if (((int)threadIdx.x) < 16) { 118 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 16))])); 119 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 8))])); 120 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 4))])); 121 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 2))])); 122 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 1))])); 123 | } 124 | __syncthreads(); 125 | #endif 126 | if (((int)threadIdx.x) == 0) 127 | #if @IS_FLOAT@ 128 | ((float*)grad_gates1_s)[index] = red_buf0[(0)]; 129 | #else 130 | ((half*)grad_gates1_s)[index] = red_buf0[(0)].x + red_buf0[(0)].y; 131 | #endif 132 | } 133 | } 134 | ''') 135 | -------------------------------------------------------------------------------- /PlanMoE/launcher/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/launcher/execl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
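# Reference semantics (plain-Python sketch, standalone, with arbitrary routing
# values) for the create_forward kernel generated in jit_kernels/sparse.py above:
# token i, scaled by gates1_s[i], is accumulated into row
# (indices1_s[i] * capacity + locations1_s[i]) of the dispatched buffer, provided
# its location fits under the capacity. The CUDA kernel does the same with a
# grid-stride loop over tokens and atomicAdd per hidden element.
import torch

samples, hidden, experts, capacity = 6, 4, 2, 3
gates = torch.rand(samples)
indices = torch.randint(0, experts, (samples,))
locations = torch.randint(0, capacity + 1, (samples,))   # arbitrary, for illustration
x = torch.randn(samples, hidden)

dispatched = torch.zeros(experts * capacity, hidden)
for i in range(samples):
    if locations[i] < capacity and indices[i] >= 0:
        dispatched[indices[i] * capacity + locations[i]] += gates[i] * x[i]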
3 | 4 | import os, re, sys 5 | import logging 6 | import argparse 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('-m', default=False, action='store_true') 11 | parser.add_argument('rest', nargs=argparse.REMAINDER) 12 | args = parser.parse_args() 13 | 14 | local_rank = int(os.environ['LOCAL_RANK']) 15 | local_size = int(os.environ['LOCAL_SIZE']) 16 | 17 | if int(os.environ.get('TUTEL_CUDA_SANDBOX', 0)) == 2: 18 | os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank) 19 | else: 20 | os.environ['TUTEL_CUDA_SANDBOX'] = '1' 21 | 22 | skip_numa = int(os.environ.get('OMP_NUM_THREADS', '1')) > 1 23 | cmd_args = [] 24 | try: 25 | if skip_numa or not os.path.exists('/usr/bin/numactl'): 26 | raise 27 | local_size = int(os.environ['LOCAL_SIZE']) 28 | cpu_nodes = sorted([str(x[4:]) for x in os.listdir('/sys/devices/system/node') if re.match('node[0-9]+', x)]) 29 | if len(cpu_nodes) <= local_size: 30 | sel_nodes = cpu_nodes[(local_rank // (local_size // len(cpu_nodes))) % len(cpu_nodes)] 31 | else: 32 | sel_nodes = cpu_nodes[local_rank::local_size] 33 | sel_nodes = ','.join(sel_nodes) 34 | 35 | cmd_args = ['/usr/bin/numactl', '--cpunodebind=%s' % sel_nodes] 36 | except Exception as ex: 37 | if local_rank == 0: 38 | logging.warning('`numactl` is not enabled by tutel.launcher.execl') 39 | 40 | cmd_args += [sys.executable, '-m'] if args.m else [] 41 | cmd_args += args.rest 42 | os.execl(cmd_args[0], *cmd_args) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /PlanMoE/launcher/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os, sys 5 | 6 | def main(): 7 | host_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 8 | host_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 9 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 10 | 11 | if 'TUTEL_ALLTOALL_ALGO' not in os.environ: 12 | if host_size >= 64 and local_size >= 8: 13 | os.environ['TUTEL_ALLTOALL_ALGO'] = '2D' 14 | 15 | master_addr = os.environ['MASTER_ADDR'] if host_size > 1 else 'localhost' 16 | master_port = int(os.environ.get('MASTER_PORT', 23232)) 17 | 18 | if 'OMP_NUM_THREADS' not in os.environ: 19 | os.environ['OMP_NUM_THREADS'] = '1024' 20 | 21 | try: 22 | from torch.distributed import run 23 | launch_mode = ['torch.distributed.run'] 24 | except: 25 | launch_mode = ['torch.distributed.launch', '--use_env'] 26 | 27 | cmd_args = [sys.executable, '-m'] + launch_mode + [ 28 | '--nproc_per_node=%d' % local_size, 29 | '--nnodes=%d' % host_size, 30 | '--node_rank=%d' % host_rank, 31 | '--master_addr=%s' % master_addr, 32 | '--master_port=%s' % master_port, 33 | '-m', 'tutel.launcher.execl', 34 | ] + sys.argv[1:] 35 | os.execl(cmd_args[0], *cmd_args) 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /PlanMoE/moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
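# Annotation with a small sketch (hostnames, script name and sizes are invented):
# launcher/run.py above is meant to be started once per host under mpiexec; it reads
# OMPI_COMM_WORLD_SIZE / OMPI_COMM_WORLD_RANK plus LOCAL_SIZE and re-execs into
# torch.distributed.run, which in turn runs tutel.launcher.execl to optionally
# numactl-bind each local rank before exec-ing the user script.
host_size, host_rank, local_size = 2, 0, 8
cmd = [
    'python3', '-m', 'torch.distributed.run',
    '--nproc_per_node=%d' % local_size,
    '--nnodes=%d' % host_size,
    '--node_rank=%d' % host_rank,
    '--master_addr=node-0', '--master_port=23232',
    '-m', 'tutel.launcher.execl', 'train.py',
]
print(' '.join(cmd))   # roughly the command run.py assembles on each host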
3 | 4 | 5 | # Level-level Ops 6 | from .jit_kernels.gating import fast_cumsum_sub_one 7 | from .impls.fast_dispatch import fast_dispatcher, extract_critical, fast_encode, fast_decode 8 | 9 | top_k_routing = extract_critical 10 | 11 | # High-level Ops 12 | from .impls.moe_layer import moe_layer 13 | -------------------------------------------------------------------------------- /PlanMoE/net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import logging 5 | 6 | from .impls.communicate import get_world_size, get_world_rank, create_groups_from_world, barrier 7 | # Communication without Backward Compute 8 | from .impls.communicate import simple_all_reduce, simple_all_to_all,simple_split, simple_reduce_scatter, simple_all_gather 9 | # Communication with Backward Compute 10 | from .impls.communicate import all_to_all, all_to_all_single, all_gather, zero_gather, zero_scatter, spatial_split, reduce_scatter, allreduce_forward, allreduce_backward 11 | 12 | 13 | class TutelDistributedOptimizer: 14 | def __init__(self, params, group=None, average_shared=False): 15 | params = [x for x in params] 16 | self.params = [x for x in params if not hasattr(x, '_tutel_expert')] 17 | self.expert_params = [x for x in params if hasattr(x, '_tutel_expert')] 18 | self.shapes = [x.shape for x in self.params] 19 | self.group = group 20 | self.average_shared = average_shared 21 | 22 | def chunk_param(self): 23 | mocks = [] 24 | for p in self.params: 25 | mocks += [zero_scatter(p.data, simple_split, group=self.group)[0]] 26 | self.virt_params = mocks 27 | 28 | def chunk_grad(self): 29 | for i, p in enumerate(self.params): 30 | if hasattr(p, 'grad') and p.grad is not None: 31 | if self.average_shared: 32 | grad = p.grad.view(-1) / get_world_size(self.group) 33 | else: 34 | grad = p.grad.view(-1) 35 | self.virt_params[i].grad, _ = zero_scatter(grad, simple_reduce_scatter, group=self.group) 36 | 37 | def restore(self): 38 | for i, p in enumerate(self.virt_params): 39 | data = simple_all_gather(p.data, group=self.group).view(-1) 40 | self.params[i].data = data[:self.shapes[i].numel()].view(self.shapes[i]) 41 | 42 | def warp_local(self, local_optim, *args, **kwargs): 43 | self.chunk_param() 44 | self.local_optim = local_optim(self.virt_params + self.expert_params, *args, **kwargs) 45 | return self 46 | 47 | def zero_grad(self): 48 | for p in self.params + self.expert_params: 49 | if hasattr(p, 'grad') and p.grad is not None: 50 | p.grad.detach_() 51 | p.grad.zero_() 52 | 53 | def step(self): 54 | self.chunk_grad() 55 | self.local_optim.step() 56 | self.restore() 57 | -------------------------------------------------------------------------------- /PlanMoE/parted/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/parted/backend/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/parted/backend/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
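# Usage sketch (import path assumed from this repo's layout; requires an already
# initialized torch.distributed process group) for net.TutelDistributedOptimizer
# above: shared parameters are ZeRO-style sharded across ranks, while parameters
# tagged `_tutel_expert` go to the wrapped local optimizer untouched.
import torch
from PlanMoE.net import TutelDistributedOptimizer

model = torch.nn.Linear(16, 16)     # placeholder model
optim = TutelDistributedOptimizer(model.parameters()).warp_local(torch.optim.SGD, lr=1e-3)

loss = model(torch.randn(4, 16)).sum()
optim.zero_grad()
loss.backward()
optim.step()   # reduce-scatter shared grads, local step, then all-gather params back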
2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /PlanMoE/parted/backend/torch/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os, sys 5 | import re 6 | import json 7 | import torch 8 | 9 | def get_input_definition(name, shape, stat_dim, dtype, is_param, device=None): 10 | return f'E.sharded_randn({shape}, {stat_dim}, dtype=torch.{dtype}, requires_grad=True, is_param={is_param}, device={device})' 11 | 12 | def get_execute_cmd(group_size, glob_size, device_type, program_path): 13 | if glob_size == 1: 14 | os_command = f'PYTHONWARNINGS=ignore OMP_NUM_THREADS=1 {sys.executable} {program_path}' 15 | else: 16 | host_names = os.environ.get('HOSTS', 'localhost').split(',') 17 | assert glob_size % len(host_names) == 0, f"Cannot evenly launch {glob_size} instances on {len(host_names)} hosts." 18 | local_size = glob_size // len(host_names) 19 | os_command = f'mpiexec --allow-run-as-root -host {",".join(host_names)} -x MASTER_ADDR={host_names[0]} -x LOCAL_SIZE={local_size} {sys.executable} -m tutel.launcher.run {sys.executable} {program_path}' 20 | return os_command 21 | 22 | def link(name, input_dim, output_dim, is_param=False, output_shape=None): 23 | if input_dim is None: 24 | return f'C.allreduce_forward({name}, group=E.parallel_env.model_group)' if output_dim == -1 else f'C.reduce_scatter({name}, {output_dim}, E.parallel_env.model_group)' 25 | if output_dim is None: 26 | return f'E.warp_bwd_allreduce({name}, {is_param})' 27 | if input_dim == -2: 28 | return f'C.zero_gather({name}, {output_shape}, E.parallel_env.model_group)' 29 | if input_dim == -1: 30 | return f'C.spatial_split({name}, {output_dim}, E.parallel_env.model_group)' 31 | if output_dim == -1: 32 | return f'C.all_gather({name}, {input_dim}, E.parallel_env.model_group)' 33 | return f'C.all_to_all({name}, {input_dim}, {output_dim}, E.parallel_env.model_group)' 34 | 35 | def generate_framework_code(device_type, group_size, group_count, run_mode, compute_name, headers, input_list, param_list, graph_prog): 36 | headers = '\n'.join(headers).strip() + '\n' if headers else '' 37 | graph_prog = '\n '.join(graph_prog) 38 | 39 | input_args = ', '.join([name for name, code in input_list]) 40 | input_list = '\n '.join([f'inputs["{name}"] = {code}' for name, code in input_list]) 41 | 42 | for name, _ in param_list: 43 | graph_prog = re.sub(fr'\b{name}\b', f'self.{name}', graph_prog) 44 | 45 | param_list = '\n '.join([f'self.register_parameter(name="{name}", param={code})' for name, code in param_list]) 46 | 47 | source = f'''import torch 48 | 49 | from tutel import net as C 50 | from tutel.parted.backend.torch import executor as E 51 | 52 | {headers} 53 | class DistModel(torch.nn.Module): 54 | compute_name = '{compute_name}' 55 | 56 | def __init__(self): 57 | super().__init__() 58 | {param_list} 59 | 60 | def forward(self, {input_args}): 61 | {graph_prog} 62 | return {compute_name} 63 | 64 | @staticmethod 65 | def synthetic_inputs(): 66 | inputs = dict() 67 | {input_list} 68 | return inputs 69 | 70 | 71 | if __name__ == '__main__': 72 | E.init_session(group_size={group_size}, group_count={group_count}, device_type='{device_type}') 73 | E.model_executor(DistModel, is_training={run_mode == 'train'}) 74 | ''' 75 | return source 76 | -------------------------------------------------------------------------------- 
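# Small check (standalone; only the import path is assumed) of how link() above maps
# an (input_dim, output_dim) transition onto a collective expression: dims >= 0 are
# partitioned dimensions, -1 means fully replicated, -2 is the flattened ZeRO-style
# layout, and None marks the forward/backward allreduce endpoints.
from PlanMoE.parted.backend.torch.config import link

print(link('x', 0, 1))      # C.all_to_all(x, 0, 1, E.parallel_env.model_group)
print(link('x', 0, -1))     # C.all_gather(x, 0, E.parallel_env.model_group)
print(link('x', -1, 2))     # C.spatial_split(x, 2, E.parallel_env.model_group)
print(link('x', None, -1))  # C.allreduce_forward(x, group=E.parallel_env.model_group)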
/PlanMoE/parted/backend/torch/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os, sys 5 | import time 6 | import json 7 | import torch 8 | import torch.distributed as dist 9 | 10 | from tutel import system 11 | from tutel import net as C 12 | 13 | def warp_bwd_allreduce(data, is_param): 14 | if is_param: 15 | fusable_params.add(id(data)) 16 | return C.allreduce_backward(data, group=parallel_env.global_group) 17 | return C.allreduce_backward(data, group=parallel_env.model_group) 18 | 19 | def sharded_randn(shape, dim, dtype, requires_grad=False, is_param=False, device=None): 20 | if device is None: 21 | device = parallel_env.local_device 22 | torch.manual_seed(1) 23 | complete_tensor = torch.tensor(torch.randn(shape, dtype=dtype, device='cpu').numpy(), device=device, requires_grad=requires_grad) 24 | if dim >= 0: 25 | result = torch.chunk(complete_tensor, chunks=parallel_env.model_size, dim=dim)[parallel_env.model_rank].contiguous() 26 | elif dim == -2: 27 | numel = complete_tensor.numel() 28 | assert numel % parallel_env.model_size == 0 29 | result = complete_tensor.view(parallel_env.model_size, -1)[parallel_env.model_rank].contiguous() 30 | else: 31 | result = complete_tensor.contiguous() 32 | if is_param: 33 | result = torch.nn.Parameter(result * 1e-3) 34 | result.is_param = True 35 | if dim == -2: 36 | result._full_shape = shape 37 | result.is_param = True 38 | result.dim_state = dim 39 | return result 40 | 41 | def init_session(group_size, group_count=1, device_type='cuda'): 42 | global parallel_env, fusable_params 43 | parallel_env = system.init_data_model_parallel(group_count=group_count, backend='nccl' if device_type == 'cuda' else 'gloo') 44 | fusable_params = set() 45 | assert parallel_env.model_size == group_size, f"This codegen is designed for distributed parallelism = {group_size}, while current session only activates {parallel_env.model_size} device.\n\nPlease retry with command: mpiexec --allow-run-as-root -host localhost -x MASTER_ADDR=localhost -x LOCAL_SIZE={group_size} {sys.executable} -m tutel.launcher.run {sys.executable} {' '.join(sys.argv)}" 46 | 47 | def model_executor(module, is_training=True): 48 | name = module.compute_name 49 | model = module().to(parallel_env.local_device) 50 | inputs = module.synthetic_inputs() 51 | output = model(**inputs) 52 | params = model.parameters() 53 | 54 | verbose = int(os.environ.get('VERBOSE', '0')) 55 | is_cuda = (parallel_env.local_device.type == 'cuda') 56 | is_training = is_training and isinstance(output, torch.Tensor) 57 | start_result = output.contiguous().view(-1)[0] if isinstance(output, torch.Tensor) else -1 58 | 59 | if verbose: 60 | sys.stderr.write('[%d] %g %g .. 
%g (%s)\n' % (parallel_env.model_rank, output.flatten()[0], output.flatten()[1], output.flatten()[-1], output.shape)) 61 | 62 | if is_training: 63 | torch.manual_seed(1) 64 | label = torch.LongTensor(output.size(0)).random_(1).to(output.device) 65 | if params: 66 | optimizer = torch.optim.SGD(params, lr=1e-5) 67 | else: 68 | optimizer = model_executor 69 | optimizer.zero_grad = optimizer.step = lambda *x: None 70 | 71 | def next_step(): 72 | if parallel_env.group_count > 1: 73 | dist.barrier() 74 | if is_cuda: 75 | torch.cuda.synchronize(parallel_env.local_device) 76 | t_start = time.time() 77 | 78 | if is_training: 79 | optimizer.zero_grad() 80 | result = model(**inputs).contiguous() 81 | result = torch.nn.functional.log_softmax(result.view(result.size(0), -1), dim=1) 82 | result = torch.nn.functional.nll_loss(result, label) 83 | if parallel_env.model_rank == 0 and verbose: 84 | sys.stderr.write(f' Loss = {result} ({output.shape}, {label.shape})\n') 85 | result.backward(retain_graph=True) 86 | if parallel_env.group_count > 1: 87 | for p in params: 88 | if id(p) not in fusable_params: 89 | p.grad = simple_all_reduce(p.grad, group=parallel_env.data_group) 90 | optimizer.step() 91 | else: 92 | result = model(**inputs) 93 | result = result.contiguous().view(-1)[0] if isinstance(result, torch.Tensor) else -1 94 | 95 | if parallel_env.group_count > 1: 96 | dist.barrier() 97 | if is_cuda: 98 | torch.cuda.synchronize(parallel_env.local_device) 99 | t_stop = time.time() 100 | 101 | step_time = t_stop - t_start 102 | if parallel_env.model_rank == 0 and verbose: 103 | sys.stderr.write('Result(is_training=%s) = %g, cost = %s\n' % (is_training, result, step_time)) 104 | return step_time 105 | 106 | for i in range(5): 107 | next_step() 108 | average_step_time = sum([next_step() for _ in range(5)]) / 5 109 | if parallel_env.model_rank == 0: 110 | sys.stderr.write(' [%s] digest = %g .., time = %g\n' % (name, start_result, average_step_time)) 111 | result = json.dumps({'name': name, 'step_time': average_step_time}) 112 | if 'CONFIG_STORE_PATH' in os.environ: 113 | with open(os.environ['CONFIG_STORE_PATH'], 'w') as fp: 114 | fp.write(result) 115 | print(result) 116 | -------------------------------------------------------------------------------- /PlanMoE/parted/patterns.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
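# Plain-torch illustration (no process group; model_size/model_rank are invented)
# of the sharding rule implemented by sharded_randn() in executor.py above:
# dim >= 0 chunks the full tensor along that dimension, dim == -2 shards the
# flattened tensor ZeRO-style, and dim == -1 keeps a full replica on every rank.
import torch

full = torch.randn(4, 6)
model_size, model_rank = 2, 0
shard_dim0 = torch.chunk(full, chunks=model_size, dim=0)[model_rank]   # dim >= 0
shard_flat = full.view(model_size, -1)[model_rank]                     # dim == -2
replica    = full                                                      # dim == -1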
3 | 4 | from .solver import register_primitive 5 | 6 | def is_replicas(dim): 7 | return dim == -1 8 | 9 | def is_partition(dim): 10 | return dim >= 0 11 | 12 | @register_primitive("BAR") 13 | def primitive_pass_through(sess, node, output_dim, group_size, rank): 14 | if not is_replicas(output_dim) and not is_partition(output_dim): 15 | return 16 | source_dims, num_partitions = node.parser.emit_dims_by_id(output_dim) 17 | 18 | if is_replicas(output_dim) and num_partitions == 0: 19 | yield (0, source_dims, {}) 20 | return 21 | 22 | connectors = dict([(inp, sess.backend.link('$', -1, None, is_param=(node.inputs[inp].op_type == "param"))) for inp in source_dims if is_replicas(source_dims[inp])]) 23 | yield (0, source_dims, connectors) 24 | 25 | @register_primitive("FAR") 26 | def primitive_fwd_allreduce_sum(sess, node, output_dim, group_size, rank): 27 | if not is_replicas(output_dim): 28 | return 29 | if node.parser.reduce_type != '+': 30 | return 31 | 32 | for i, ax in enumerate(node.parser.get_reduce_axes()): 33 | if rank is not None and i != rank: 34 | continue 35 | try: 36 | source_dims, num_partitions = node.parser.emit_dims_by_name(ax) 37 | except NotImplementedError: 38 | continue 39 | assert num_partitions > 0, "It is unexpected that no certain input is parted." 40 | connectors = dict([(inp, sess.backend.link('$', -1, None, is_param=(node.inputs[inp].op_type == "param"))) for inp in source_dims if is_replicas(source_dims[inp])]) 41 | connectors[''] = sess.backend.link('$', None, -1) 42 | yield (i, source_dims, connectors) 43 | 44 | @register_primitive("RS") 45 | def primitive_fwd_reduce_scatter_sum(sess, node, output_dim, group_size, rank): 46 | if not is_partition(output_dim): 47 | return 48 | if node.parser.reduce_type != '+': 49 | return 50 | 51 | for i, ax in enumerate(node.parser.get_reduce_axes()): 52 | if rank is not None and i != rank: 53 | continue 54 | try: 55 | source_dims, num_partitions = node.parser.emit_dims_by_name(ax) 56 | except NotImplementedError: 57 | continue 58 | assert num_partitions > 0, "It is unexpected that no certain input is parted." 59 | connectors = dict([(inp, sess.backend.link('$', -1, None, is_param=(node.inputs[inp].op_type == "param"))) for inp in source_dims if is_replicas(source_dims[inp])]) 60 | connectors[''] = sess.backend.link('$', None, output_dim) 61 | yield (i, source_dims, connectors) 62 | 63 | @register_primitive("SPLIT") 64 | def primitive_fwd_spatial_split(sess, node, output_dim, group_size, rank): 65 | if not is_partition(output_dim): 66 | return 67 | source_dims, num_partitions = node.parser.emit_dims_by_id(-1) 68 | assert num_partitions == 0, "It is unexpected that certain input is parted." 
69 | connectors = dict([('', sess.backend.link('$', -1, output_dim))]) 70 | yield (0, source_dims, connectors) 71 | 72 | @register_primitive("AG") 73 | def primitive_fwd_all_gather(sess, node, output_dim, group_size, rank): 74 | if not is_replicas(output_dim): 75 | return 76 | for i in range(len(node.shape)): 77 | if rank is not None and i != rank: 78 | continue 79 | try: 80 | if node.shape[i] % group_size != 0: 81 | continue 82 | source_dims, num_partitions = node.parser.emit_dims_by_id(i) 83 | except NotImplementedError: 84 | continue 85 | if num_partitions == 0: # Handled by fwd_pass_through as well 86 | continue 87 | connectors = dict([(inp, sess.backend.link('$', -1, None, is_param=(node.inputs[inp].op_type == "param"))) for inp in source_dims if is_replicas(source_dims[inp])]) 88 | connectors[''] = sess.backend.link('$', rank, -1) 89 | yield (i, source_dims, connectors) 90 | 91 | @register_primitive("A2A") 92 | def primitive_alltoall(sess, node, output_dim, group_size, rank): 93 | if not is_partition(output_dim): 94 | return 95 | shape = node.shape 96 | if len(shape) < 2 or shape[output_dim] % group_size != 0: 97 | return 98 | for i in range(len(node.shape)): 99 | if rank is not None and i != rank: 100 | continue 101 | if shape[i] % group_size != 0 or output_dim == i: 102 | continue 103 | try: 104 | source_dims, num_partitions = node.parser.emit_dims_by_id(i) 105 | connectors = dict([(inp, sess.backend.link('$', -1, None, is_param=(node.inputs[inp].op_type == "param"))) for inp in source_dims if is_replicas(source_dims[inp])]) 106 | connectors[''] = sess.backend.link('$', i, output_dim) 107 | yield (i, source_dims, connectors) 108 | except NotImplementedError: 109 | continue 110 | 111 | @register_primitive("ZERO") 112 | def primitive_zero(sess, node, output_dim, group_size, rank): 113 | if not is_partition(output_dim): 114 | return 115 | source_dims, num_partitions = node.parser.emit_dims_by_id(output_dim) 116 | if num_partitions == 0: 117 | return 118 | has_params, connectors = False, {} 119 | for inp in source_dims: 120 | if not is_replicas(source_dims[inp]): 121 | continue 122 | if node.inputs[inp].op_type == 'param': 123 | source_dims[inp] = -2 124 | has_params, connectors[inp] = True, sess.backend.link('$', -2, -1, output_shape=node.inputs[inp].shape) 125 | else: 126 | connectors[inp] = sess.backend.link('$', -1, None, is_param=(node.inputs[inp].op_type == "param")) 127 | if not has_params: 128 | return 129 | yield (0, source_dims, connectors) 130 | -------------------------------------------------------------------------------- /PlanMoE/parted/solver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import copy, hashlib 5 | import os, sys 6 | import re 7 | import json 8 | 9 | spmd_primitives_dict = dict() 10 | 11 | def register_primitive(name=None): 12 | if not name: 13 | name = 'custom_%d' % len(spmd_primitives_dict) 14 | def register_primitive_instance(func): 15 | assert name not in spmd_primitives_dict, f"Parallel Pattern with name `{name}` already exists." 
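# Usage note (annotation): the parallel patterns in patterns.py register themselves
# through this decorator, e.g.
#
#     @register_primitive("A2A")
#     def primitive_alltoall(sess, node, output_dim, group_size, rank):
#         ...
#         yield (i, source_dims, connectors)
#
# Each registered generator yields (rank, per-input partition dims, connector
# expressions); solve_partition() below enumerates these candidates per node when
# searching for a sharding plan.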
16 | spmd_primitives_dict[name] = func 17 | return register_primitive_instance 18 | 19 | def solve_partition(sess, compute_groups, input_nodes, split_pref, kwargs): 20 | marked_nodes, marked_comm = dict(), dict() 21 | last_node = compute_groups[-1][0][-1] 22 | 23 | run_mode = kwargs['run_mode'] 24 | group_size = kwargs['spmd_nodes'] 25 | glob_size = kwargs['total_nodes'] 26 | print(f'\nDistributed for total_nodes = {glob_size}, spmd_nodes = {group_size}, run_mode = `{run_mode}`\n') 27 | 28 | FL = dict() 29 | for input in input_nodes: 30 | FL[input.name] = dict() 31 | if input.name in split_pref: 32 | dim = split_pref[input.name] 33 | FL[input.name][dim] = (0.0, {input.name: (dim, '')}) 34 | continue 35 | FL[input.name][-1] = (0.0, {input.name: (-1, '')}) 36 | FL[input.name][-2] = (0.0, {input.name: (-2, '')}) 37 | for dim in range(len(input.shape)): 38 | if input.shape[dim] % group_size == 0: 39 | FL[input.name][dim] = (0.0, {input.name: (dim, '')}) 40 | 41 | def do_merge(base_config, new_config): 42 | if new_config is None: 43 | return None 44 | new_config = new_config[1] 45 | for k in new_config: 46 | if k not in base_config: 47 | base_config[k] = new_config[k] 48 | elif base_config[k] != new_config[k]: 49 | return None 50 | return base_config 51 | 52 | for compute_nodes, multi_used in compute_groups: 53 | enum_nums = 1 54 | for node in multi_used: 55 | enum_nums *= len(node.shape) + 1 56 | 57 | final_FL = dict() 58 | 59 | for enum_inst in range(enum_nums): 60 | looping_restricted_config = dict() 61 | remainder = enum_inst 62 | for node in multi_used: 63 | jump_val = len(node.shape) + 1 64 | looping_restricted_config[node.name] = remainder % jump_val - 1 65 | remainder //= jump_val 66 | 67 | ##### Looping once 68 | for node in compute_nodes: 69 | output_name = node.name 70 | output_shape = node.shape 71 | FL[output_name] = dict() 72 | 73 | if group_size == 1: 74 | left, right = -1, 0 75 | elif output_name in split_pref: 76 | assert isinstance(split_pref[output_name], int) 77 | left, right = split_pref[output_name], split_pref[output_name] + 1 78 | else: 79 | left, right = -1, len(output_shape) 80 | 81 | for dim in range(left, right): 82 | if looping_restricted_config.get(output_name, dim) != dim: 83 | continue 84 | if dim >= 0 and output_shape[dim] % group_size != 0: 85 | continue 86 | programs = [] 87 | for key in spmd_primitives_dict: 88 | rule_func = spmd_primitives_dict[key] 89 | try: 90 | merged_config = None 91 | for rank, source_dims, connectors in rule_func(sess, node, dim, group_size, None): 92 | merged_config = {node.name: (dim, f'{key}:{rank}')} 93 | for input_id in source_dims: 94 | state = source_dims[input_id] 95 | from_record = FL[node.inputs[input_id].name].get(state, None) 96 | if from_record is not None and looping_restricted_config.get(node.inputs[input_id].name, state) == state: 97 | merged_config = do_merge(merged_config, from_record) 98 | else: 99 | merged_config = None 100 | if not merged_config: 101 | break 102 | if merged_config: 103 | break 104 | if not merged_config: 105 | continue 106 | prog = node.compile(merged_config, **kwargs) 107 | if prog: 108 | programs += [(prog, merged_config)] 109 | except NotImplementedError: 110 | pass 111 | 112 | best_result = (float('inf'), None) 113 | for index, (prog, cfg) in enumerate(programs): 114 | print(f'>> Try `{output_name}:{dim} [ENUM:{enum_inst+1}/{enum_nums}]` ({index+1}/{len(programs)}), config = {json.dumps(cfg)}') 115 | 116 | # Evaluate Program 117 | if enum_nums == 1 and (len(programs) == 1) and (output_name != 
last_node.name): 118 | model_cost = -1 119 | else: 120 | print('>> Program Snapshot:') 121 | print(prog.code) 122 | model_cost = prog.execute() 123 | model_cost = model_cost.get('step_time', float('inf')) 124 | 125 | if model_cost < best_result[0]: 126 | best_result = (model_cost, cfg) 127 | 128 | if best_result[1] is not None: 129 | FL[output_name][dim] = best_result 130 | print(f'>> FL_{output_name}_{dim} [ENUM:{enum_inst+1}/{enum_nums}] = {best_result}\n') 131 | 132 | # Update enum history best 133 | for dim in FL[compute_nodes[-1].name]: 134 | if dim not in final_FL or final_FL[dim][0] > FL[compute_nodes[-1].name][dim][0]: 135 | final_FL[dim] = FL[compute_nodes[-1].name][dim] 136 | 137 | # Persistent enum best 138 | for node in compute_nodes: 139 | FL[node.name] = None 140 | print(f'>> Updating Stage `{output_name}:{dim}`; Stage Enum = {enum_inst}/{enum_nums}; Valid FL_State[*] Count: {len(FL)}') 141 | sys.stdout.flush() 142 | FL[compute_nodes[-1].name] = final_FL 143 | 144 | return [(dim, FL[last_node.name].get(dim, None)) for dim in range(-1, len(last_node.shape))] 145 | 146 | -------------------------------------------------------------------------------- /PlanMoE/parted/spmdx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os, sys, json, re 5 | import tempfile 6 | import copy 7 | import inspect 8 | import logging 9 | import importlib 10 | 11 | from . import solver 12 | from . import patterns 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | session = None 16 | 17 | def init(backend_name): 18 | global session 19 | if session is not None: 20 | raise Exception('Function `init()` can be only invoked once.') 21 | if not re.match('^[a-zA-Z0-9]+$', backend_name): 22 | raise Exception('Only letters and digits are allowed for backend_name, get: %s' % backend_name) 23 | session = init 24 | session.backend = importlib.import_module('..backend.%s.config' % backend_name, __name__) 25 | session.is_strict_fmt = int(os.environ.get('STRICT_FMT', 0)) > 0 26 | session.ptype = os.environ.get('PTYPE', '') 27 | session.custom_dict = dict() 28 | 29 | manual_config = os.environ.get('CONFIG', '') 30 | manual_config = json.loads(manual_config) if manual_config else {} 31 | manual_config = dict([(x, manual_config[x] if isinstance(manual_config[x], int) else manual_config[x][0]) for x in manual_config]) 32 | session.manual_config = manual_config 33 | try: 34 | extra = importlib.import_module('..backend.%s' % backend_name, __name__) 35 | except: 36 | extra = None 37 | return extra 38 | 39 | def new_dependency(header_content, depends=[]): 40 | header_content = header_content.strip() + '\n' 41 | depends = depends if isinstance(depends, list) else [depends] 42 | return {"data": header_content, "depends": depends} 43 | 44 | def product(arrlist): 45 | result = 1 46 | for x in arrlist: result *= int(x) 47 | return result 48 | 49 | class Mapper2D: 50 | def __init__(self, item): 51 | def split_dim(item): 52 | parts = item.replace(')', '(').split('(') 53 | for i in range(len(parts)): 54 | if i % 2 == 0: 55 | for x in parts[i]: 56 | if x.strip(): 57 | yield x 58 | else: 59 | x = [x for x in parts[i] if x.strip()] 60 | yield x if len(x) > 1 else x[0] 61 | 62 | iter = split_dim(item) 63 | self.id2ax = [x for x in iter] 64 | self.ax2id = dict([(x, i) for i, x in enumerate(self.id2ax) if isinstance(x, str) and x != '*']) 65 | for i, x in enumerate(self.id2ax): 66 | if not isinstance(x, 
str): 67 | for j, ax in enumerate(x): 68 | self.ax2id[ax] = (i, j) 69 | 70 | class Parser: 71 | def __init__(self, irs): 72 | left, rights = irs.split('=') 73 | reduce_type = '' 74 | if left[-1] in ('+', '<', '>', '[', ']'): 75 | left, reduce_type = left[:-1], left[-1] 76 | 77 | self.reduce_type = reduce_type 78 | self.left = Mapper2D(left) 79 | self.rights = [Mapper2D(x) for x in rights.split(',')] 80 | self.num_inputs = len(self.rights) 81 | 82 | def get_leading_target(self, target): 83 | return target if isinstance(target, str) else target[0] 84 | 85 | def get_reduce_axes(self): 86 | reduce_axes = set() 87 | for right in self.rights: 88 | for k in right.ax2id: 89 | if k not in self.left.ax2id: 90 | reduce_axes.add(k) 91 | return reduce_axes 92 | 93 | def emit_dims_by_name(self, ax_name): 94 | if ax_name == '*': 95 | raise NotImplementedError() 96 | target_ax = self.get_leading_target(ax_name) 97 | source_dims, parted = dict(), 0 98 | for i, right in enumerate(self.rights): 99 | if target_ax not in right.ax2id: 100 | source_dims[i] = -1 101 | continue 102 | ids = right.ax2id[target_ax] 103 | if isinstance(ids, int): 104 | source_dims[i] = ids 105 | elif ids[1] == 0: 106 | source_dims[i] = ids[0] 107 | else: 108 | raise NotImplementedError() 109 | parted += 1 110 | return source_dims, parted 111 | 112 | def emit_dims_by_id(self, output_dim): 113 | if output_dim == -1: 114 | return dict([(i, -1) for i in range(self.num_inputs)]), 0 115 | if output_dim == -2 or self.left.id2ax[output_dim] == '*': 116 | raise NotImplementedError() 117 | if output_dim >= 0: 118 | return self.emit_dims_by_name(self.left.id2ax[output_dim]) 119 | raise NotImplementedError() 120 | 121 | 122 | class Program: 123 | def __init__(self, code, kwargs): 124 | self.code = code 125 | self.kwargs = kwargs 126 | 127 | def save(self, path): 128 | with open(path, 'w') as fp: 129 | fp.write('# Copyright (c) Microsoft Corporation.\n') 130 | fp.write('# Licensed under the MIT license.\n\n') 131 | fp.write(self.code) 132 | 133 | def execute(self, save_file_path=None): 134 | is_tempfile = save_file_path is None 135 | if is_tempfile: 136 | save_file_path = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.gettempdir(), suffix='.py').name 137 | 138 | def remove_file(filenames): 139 | if isinstance(filenames, str): 140 | filenames = [filenames] 141 | for filename in filenames: 142 | try: 143 | os.unlink(filename) 144 | except FileNotFoundError: 145 | pass 146 | 147 | remove_file(save_file_path) 148 | 149 | model_program = self.code 150 | glob_size = self.kwargs['total_nodes'] 151 | device_type = self.kwargs['device_type'] 152 | group_size = self.kwargs['spmd_nodes'] 153 | 154 | with open(save_file_path, 'w') as fp: 155 | fp.write(model_program) 156 | 157 | log_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.gettempdir(), suffix='.log').name 158 | os.environ['CONFIG_STORE_PATH'] = log_file 159 | remove_file(log_file) 160 | os_command = session.backend.get_execute_cmd(group_size, glob_size, device_type, save_file_path) 161 | 162 | try: 163 | result = '' 164 | logging.info('Executing: %s' % os_command) 165 | assert 0 == os.system(os_command), f"Failed to execute command: {os_command}" 166 | with open(log_file, 'r') as fp: 167 | result = fp.read().strip() 168 | result = json.loads(result) 169 | except: 170 | import traceback 171 | print(traceback.format_exc()) 172 | print(result) 173 | result = {} 174 | if is_tempfile: 175 | remove_file(save_file_path) 176 | return result 177 | 178 | class Custom: 179 | __t_builtins__ 
= dict() 180 | __t_ids__ = dict() 181 | __t_ops__ = dict() 182 | 183 | def __init__(self, data, fw_ops=None, inputs=None, op_name=None, shape_fn=None, flops=None, depends=[]): 184 | self.op_type = op_name or inspect.currentframe().f_back.f_code.co_name 185 | if not re.match('^[a-zA-Z0-9]+$', self.op_type): 186 | self.op_type = 'Custom' 187 | assert self.op_type[0].isupper(), f'The leading charactor of the operator name must be uppercase letter (received: "{self.op_type}").' 188 | rank_dict = (Custom.__t_ops__ if self.op_type != 'Builtin' else Custom.__t_builtins__) if self.op_type != 'Id' else Custom.__t_ids__ 189 | 190 | rank_dict[self] = len(rank_dict) 191 | self.name = f'{self.op_type[0].lower()}{self.op_type[1:]}{rank_dict[self]}' 192 | self.depends = depends if isinstance(depends, list) else [depends] 193 | 194 | if fw_ops is not None: 195 | self.fw_ops = fw_ops.replace('@@', '') 196 | 197 | if inputs is None: 198 | assert self.fw_ops is not None, 'At least one property in "fw_ops" and inputs should be specified.' 199 | fw_ops = fw_ops.split('@@') 200 | input_names = [] 201 | for x in range(1, len(fw_ops), 2): 202 | if fw_ops[x] not in input_names: 203 | input_names.append(fw_ops[x]) 204 | self.inputs = [session.custom_dict[x] for x in input_names] 205 | else: 206 | self.inputs = inputs 207 | 208 | self.outputs = [] 209 | self.data = data 210 | 211 | if isinstance(data, dict): 212 | self.op_type = 'data' 213 | if data["is_param"]: 214 | self.name += '_' 215 | self.op_type = 'param' 216 | self.inputs = [] 217 | 218 | self.shape = data["shape"] 219 | self.dtype = data["dtype"] 220 | self.flops = flops or 0 221 | else: 222 | self.op_type = 'compute' 223 | self.parser = Parser(data) 224 | 225 | if shape_fn is not None: 226 | self.shape, self.dtype = shape_fn(self.inputs) 227 | else: 228 | try: 229 | infershape = dict() 230 | for i, x in enumerate(self.parser.rights): 231 | for ax in x.ax2id: 232 | infershape[ax] = self.inputs[i].shape[x.ax2id[ax]] 233 | self.shape = [infershape[x] if not isinstance(x, list) else product([infershape[y] for y in x]) for x in self.parser.left.id2ax] 234 | self.dtype = self.inputs[0].dtype 235 | except: 236 | raise Exception(f'Cannot auto-infershape for op {self.name} due to unknown dimension size by tensor format: {self.data}') 237 | # logging.info(f'Shape dict of {self.name} = {self.shape}:{self.dtype}') 238 | 239 | if flops is None: 240 | self.flops = product(self.shape) 241 | if self.parser.reduce_type: 242 | infershape = dict() 243 | for i, x in enumerate(self.parser.rights): 244 | for ax in x.ax2id: 245 | if isinstance(ax, str): 246 | infershape[ax] = self.inputs[i].shape[x.ax2id[ax]] 247 | self.flops *= product([infershape[x] for x in self.parser.get_reduce_axes()]) 248 | self.flops <<= 1 249 | else: 250 | self.flops = flops 251 | 252 | assert self.name not in session.custom_dict, f"Node with name `{self.name}` has already existed in current session." 
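# Register the node in the process-wide session registry. `fw_ops` templates reference nodes through `@@name@@` placeholders (see __str__), and those placeholders are resolved back to node objects via this dictionary.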
253 | session.custom_dict[self.name] = self 254 | 255 | def __del__(self): 256 | try: 257 | session.custom_dict.pop(self.name) 258 | except: 259 | pass 260 | 261 | def update_config(self, parent, **kwargs): 262 | if parent is not None and parent not in self.outputs: 263 | self.outputs.append(parent) 264 | node_name = self.name 265 | if kwargs['spmd_nodes'] == 1: 266 | self.config = -1 267 | elif session.ptype == 'dp': 268 | self.config = -1 if self.op_type == 'param' else 0 269 | elif session.ptype == 'zero': 270 | self.config = -2 if self.op_type == 'param' else 0 271 | elif node_name in session.manual_config: 272 | self.config = session.manual_config[node_name] 273 | for input in self.inputs: 274 | input.update_config(self, **kwargs) 275 | 276 | def __str__(self): 277 | return f'@@{self.name}@@' 278 | 279 | def numel(self): 280 | return int(product(self.shape)) 281 | 282 | def parse_inputs(self): 283 | if isinstance(self.data, dict): 284 | return [] 285 | results, patt = [], self.data 286 | while True: 287 | pos = re.search(r'\b[a-z][a-zA-Z0-9_]*\b', patt) 288 | if not pos: 289 | break 290 | results += [patt[pos.start():pos.end()]] 291 | patt = patt[pos.end() + 1:] 292 | return results 293 | 294 | def get_leading_dim(self): 295 | return [i for i, x in enumerate(self.shape) if x > 1][0] 296 | 297 | def get_input_by_name(self, name): 298 | for inp in self.inputs: 299 | if inp.name == name: 300 | return inp 301 | raise Exception(f'Node input with name `{name}` not found!') 302 | 303 | def autotune(self, config_file=None, **kwargs): 304 | config = Config.load_from_file(config_file) 305 | if config: 306 | return config 307 | kwargs, results = optimize(self, **kwargs) 308 | valid_configs = [sol for dim, sol in results if sol is not None] 309 | if not valid_configs: 310 | raise Exception('No valid configuration found!') 311 | best_time, best_config = min(valid_configs) 312 | config = Config.create(best_config, kwargs, best_time) 313 | if config_file is not None: 314 | config.save(config_file) 315 | return config 316 | 317 | def articulare_analyse(self): 318 | low, dfn, cut = dict(), dict(), dict() 319 | pcnt, root, st = [0], self, [] 320 | 321 | ##### Mask Articulation Points 322 | def mask_dfs(u): 323 | tot = 0 324 | st.append(u) 325 | pcnt[0] += 1 326 | dfn[u] = low[u] = pcnt[0] 327 | 328 | for v in u.inputs + u.outputs: 329 | # Assume every param tensor is unshared 330 | if v.op_type == 'param': 331 | continue 332 | if v not in dfn: 333 | tot += 1 334 | mask_dfs(v) 335 | low[u] = min(low[u], low[v]) 336 | if ((u == root and tot > 1) or (u != root and low[v] >= dfn[u])): 337 | cut[u] = cut.get(u, 0) + 1 338 | if low[v] >= dfn[u]: 339 | while st.pop() != v: 340 | continue 341 | else: 342 | low[u] = min(low[u], dfn[v]) 343 | cut[u] = cut.get(u, 0) + 1 344 | 345 | mask_dfs(self) 346 | 347 | ##### Partition Computations into Groups 348 | pcnt, visited, group_export = [0], set(), dict() 349 | 350 | def compute_dfs(u, vid, is_leader): 351 | if u in visited: 352 | return 353 | if u.op_type != 'compute': 354 | return 355 | if is_leader: 356 | group_export[vid] = [u] 357 | else: 358 | group_export[vid].append(u) 359 | 360 | visited.add(u) 361 | for v in u.inputs: 362 | if cut.get(v, 0) > 1: 363 | pcnt[0] += 1 364 | compute_dfs(v, pcnt[0], True) 365 | else: 366 | compute_dfs(v, vid, False) 367 | 368 | compute_dfs(self, pcnt[0], True) 369 | 370 | compute_groups = [] 371 | for _, members in sorted(group_export.items(), reverse=True): 372 | for x in members: 373 | multi_used = set() 374 | for y in 
x.inputs: 375 | if len(y.outputs) > 1: 376 | multi_used.add(y) 377 | compute_groups.append(([x for x in reversed(members)], multi_used)) 378 | return compute_groups 379 | 380 | def get_data_parallel_config(self, **kwargs): 381 | visited = set() 382 | config = dict() 383 | 384 | def property_dfs(node): 385 | visited.add(id(node)) 386 | for inp in node.inputs: 387 | if id(inp) not in visited: 388 | property_dfs(inp) 389 | config[node.name] = [-1, ""] if node.op_type == 'param' else [0, "BAR:0"] 390 | 391 | property_dfs(self) 392 | return Config.create(config, environ_config(kwargs)) 393 | 394 | def serialize(self, **kwargs): 395 | node = self 396 | node.update_config(None, **kwargs) 397 | 398 | compute_groups = node.articulare_analyse() 399 | 400 | input_nodes, compute_nodes, config = [], [], {} 401 | visited = set() 402 | 403 | def property_dfs(node): 404 | visited.add(id(node)) 405 | node_name = node.name 406 | for inp in node.inputs: 407 | if id(inp) not in visited: 408 | property_dfs(inp) 409 | if hasattr(node, 'config'): 410 | config[node_name] = getattr(node, 'config') 411 | if isinstance(node.data, dict): 412 | input_nodes.append(node) 413 | else: 414 | compute_nodes.append(node) 415 | property_dfs(node) 416 | 417 | return compute_groups, compute_nodes, input_nodes, config 418 | 419 | def compile(self, config, **kwargs): 420 | if not isinstance(config, dict): 421 | assert config.config['v'] == Config.VERSION, f"Unmatched configuration file version: expect {Config.VERSION}, got {config.config['v']}" 422 | for k in kwargs: 423 | config.config['kwargs'][k] = kwargs[k] 424 | kwargs = config.config['kwargs'] 425 | config = config.config['b'] 426 | 427 | run_mode = kwargs['run_mode'] 428 | device_type = kwargs['device_type'] 429 | total_nodes = kwargs['total_nodes'] 430 | spmd_nodes = kwargs['spmd_nodes'] 431 | assert total_nodes % spmd_nodes == 0, f"`total_nodes` must by evenly divided by `spmd_nodes`, got: {total_nodes} % {spmd_nodes} != 0" 432 | 433 | if True: 434 | _, compute_nodes, input_nodes, restricted_state = self.serialize(**kwargs) 435 | 436 | # Verify restricted_state & extra padding 437 | for node in compute_nodes + input_nodes: 438 | node_state = config[node.name][0] 439 | if restricted_state.get(node.name, node_state) != node_state: 440 | raise Exception(f"Unstatisfied sharding state requirements on node `{node.name}`") 441 | if node_state >= 0 and node.shape[node_state] % spmd_nodes != 0: 442 | raise Exception(f"Unstatisfied slicing chunks `{node.shape[node_state]} // {spmd_nodes}` on node `{node.name}`") 443 | 444 | # Construct Inputs 445 | input_list, param_list = [], [] 446 | for node in input_nodes: 447 | shard_dim, _ = config[node.name] 448 | if node.op_type != 'param': 449 | input_list.append((node.name, session.backend.get_input_definition(node.name, node.shape, shard_dim, node.dtype, is_param=False))) 450 | else: 451 | param_list.append((node.name, session.backend.get_input_definition(node.name, node.shape, shard_dim, node.dtype, is_param=True))) 452 | 453 | def apply_communicate(item_name, comm_op): 454 | return re.sub(fr'\$', item_name, comm_op).strip() 455 | 456 | # Construct Computes 457 | graph_prog, temp_ids = [], 0 458 | for node in compute_nodes: 459 | output_dim, key = config[node.name] 460 | if ':' in key: 461 | key, rank = key.split(':') 462 | rank = int(rank) 463 | else: 464 | rank = None 465 | rule_func = solver.spmd_primitives_dict[key] 466 | conn_sol, conn_src = None, None 467 | try: 468 | valid_count = 0 469 | for rank, source_dims, connectors in 
rule_func(session, node, output_dim, spmd_nodes, rank): 470 | valid_count += 1 471 | assert valid_count <= 1, f"Ambiguous solution `{key}` for node with `{node.name}` at dimension {output_dim}" 472 | conn_sol, conn_src = connectors, source_dims 473 | except NotImplementedError: 474 | pass 475 | assert conn_sol is not None, f"No statisfied parallel pattern `{key}` applying on node `{node.name}`" 476 | 477 | graph_prog += [f'{node.name} = {node.fw_ops}',] 478 | for index in range(len(node.inputs)): 479 | input_item = node.inputs[index] 480 | item_name = input_item.name 481 | from_state = config[item_name][0] 482 | prim_state = conn_src[index] 483 | if from_state != prim_state: 484 | extra = {'output_shape': node.inputs[index].shape, 'is_param': node.inputs[index].op_type == 'param'} 485 | if from_state == -2 and prim_state >= 0: 486 | item_name = session.backend.link(item_name, -2, -1, **extra) 487 | item_name = session.backend.link(item_name, -1, prim_state, **extra) 488 | else: 489 | item_name = session.backend.link(item_name, from_state, prim_state, **extra) 490 | 491 | if index in conn_sol: 492 | item_name = apply_communicate(item_name, conn_sol[index]) or item_name 493 | 494 | if item_name != input_item.name: 495 | temp_ids = temp_ids + 1 496 | graph_prog[-1] = f'_temp{temp_ids} = {item_name}; ' + re.sub(fr'\b{input_item.name}\b', f'_temp{temp_ids}', graph_prog[-1]) 497 | 498 | aggr_output = apply_communicate(node.name, conn_sol.get('', '')) 499 | if aggr_output: 500 | graph_prog += [f'{node.name} = {aggr_output}'] 501 | 502 | depends, headers = set(), [] 503 | def compute_dependencies(nodes): 504 | for node in nodes: 505 | if id(node) in depends: 506 | continue 507 | depends.add(id(node)) 508 | for dep in node["depends"]: 509 | compute_dependencies(dep) 510 | headers.append(node["data"]) 511 | 512 | for node in compute_nodes: 513 | compute_dependencies(node.depends) 514 | 515 | program_strings = session.backend.generate_framework_code(device_type, spmd_nodes, total_nodes // spmd_nodes, run_mode, self.name, headers, input_list, param_list, graph_prog) 516 | return Program(program_strings, kwargs) 517 | 518 | def environ_config(kwargs): 519 | if 'spmd_nodes' not in kwargs: 520 | kwargs['spmd_nodes'] = kwargs['total_nodes'] 521 | if 'device_type' not in kwargs: 522 | kwargs['device_type'] = os.environ.get('DEVICE', 'cuda') 523 | if 'run_mode' not in kwargs: 524 | kwargs['run_mode'] = os.environ.get('MODE', 'train') 525 | assert kwargs['total_nodes'] % kwargs['spmd_nodes'] == 0, "`total_nodes` must be exactly divided by `spmd_nodes`." 
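# At this point kwargs always contains total_nodes, spmd_nodes (defaulting to total_nodes), device_type (env DEVICE, default 'cuda') and run_mode (env MODE, default 'train').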
526 | return kwargs 527 | 528 | def optimize(node, **kwargs): 529 | kwargs = environ_config(kwargs) 530 | 531 | if session.is_strict_fmt: 532 | node = Id(node, op_name='Builtin') 533 | node.config = 0 534 | 535 | compute_groups, compute_nodes, input_nodes, config = node.serialize(**kwargs) 536 | 537 | print('<< TUNE Graph >>\n') 538 | print('\n'.join([f'| {x.name} <- new_{x.op_type}() | {x.dtype}{x.shape} | {getattr(x, "config", None)} |' for x in input_nodes])) 539 | print('---------------------------------------------------') 540 | print('\n'.join([f'| {x.name} <- {", ".join([x.name for x in x.inputs])} | {x.dtype}{x.shape} | "{x.data}" | {getattr(x, "config", None)} |' for x in compute_nodes])) 541 | print('\n>> config = %s\n' % (json.dumps(config),)) 542 | sys.stdout.flush() 543 | return kwargs, solver.solve_partition(session, compute_groups, input_nodes=input_nodes, split_pref=config, kwargs=kwargs) 544 | 545 | class Config: 546 | VERSION = '0.1' 547 | 548 | @staticmethod 549 | def load_from_file(filename): 550 | if filename is not None and os.path.exists(filename): 551 | return Config(filename) 552 | return None 553 | 554 | @staticmethod 555 | def create(config, environ, timecost=0): 556 | return Config({'v': Config.VERSION, 't': timecost, 'b': config, 'kwargs': environ}) 557 | 558 | def __init__(self, config): 559 | if isinstance(config, dict): 560 | self.set_config(config) 561 | elif isinstance(config, str): 562 | with open(config, 'r') as fp: 563 | config = json.load(fp) 564 | self.set_config(config) 565 | else: 566 | raise Exception('Unsupported config value: %s' % config) 567 | 568 | def set_config(self, config): 569 | if config['v'] != Config.VERSION: 570 | raise Exception('Incompatible config version: expect %s, got %s' % (Config.VERSION, config['v'])) 571 | self.config = config 572 | 573 | def __str__(self): 574 | return json.dumps(self.config) 575 | 576 | def save(self, filepath): 577 | with open(filepath, 'w') as fp: 578 | json.dump(self.config, fp) 579 | 580 | def Id(x, op_name=None): 581 | layout = ''.join([chr(ord('a') + i) for i in range(len(x.shape))]) 582 | return Custom(f'{layout} = {layout}', f'{x}', op_name=op_name) 583 | 584 | def Tensor(shape, dtype, is_param=False): 585 | inp = Custom({"shape": shape, "dtype": dtype, "is_param": is_param}, inputs=[]) 586 | if not is_param and session.is_strict_fmt: 587 | config = getattr(inp, 'config', session.manual_config.get(inp.name, None)) 588 | if config is not None: 589 | if inp.name in session.manual_config: 590 | session.manual_config.pop(inp.name) 591 | inp.config = 0 592 | inp = Id(inp, op_name="Builtin") 593 | inp.config = config 594 | else: 595 | inp.config = 0 596 | inp = Id(inp, op_name="Builtin") 597 | return inp 598 | -------------------------------------------------------------------------------- /PlanMoE/system.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
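# Runtime helpers for PlanMoE: NUMA-aware CPU affinity setup, data/model parallel group initialization, a small process-local cache, and a timer that optionally synchronizes CUDA before reading the clock.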
3 | 4 | import os, sys 5 | import re 6 | import logging 7 | 8 | TUTEL_CUDA_SANDBOX = int(os.environ.get('TUTEL_CUDA_SANDBOX', 0)) 9 | 10 | def init_affinity_at_program_beginning(): 11 | if TUTEL_CUDA_SANDBOX: 12 | return 13 | try: 14 | numa_type = int(os.environ.get('NUMA_TYPE', '1')) 15 | if numa_type <= 0: 16 | return 17 | group_rank = int(os.environ.get('LOCAL_RANK', '0')) 18 | nodes = sorted([int(x[4:]) for x in os.listdir('/sys/devices/system/node') if re.match('node[0-9]+', x)]) 19 | cpus = [sorted([int(x[3:]) for x in os.listdir('/sys/devices/system/node/node%d' % node_id) if re.match('cpu[0-9]+', x)]) for node_id in nodes] 20 | sel_node = (group_rank // numa_type) % len(nodes) 21 | os.sched_setaffinity(0, cpus[sel_node]) 22 | logging.info('LOCAL_RANK %d is bound to NUMA node: %d (total NUMA nodes = %d)' % (group_rank, sel_node, len(nodes))) 23 | except Exception as ex: 24 | if os.environ.get('LOCAL_RANK', '0') == '0': 25 | logging.warning('Failed to set NUMA status: %s' % ex) 26 | 27 | def init_data_model_parallel(group_count=1, backend='nccl'): 28 | from . import net as C 29 | result = C.create_groups_from_world(group_count=group_count, include_init=backend) 30 | result.is_cuda = (result.local_device.type == 'cuda') 31 | 32 | logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}') 33 | init_data_model_parallel.default_env = result 34 | 35 | def on_quit(): 36 | sys.stdout.flush() 37 | sys.stderr.flush() 38 | # Builtin dist.all_to_all_single in torch is unstable in some versions. 39 | # Temp workaround: https://github.com/pytorch/pytorch/issues/56390 40 | if getattr(C.simple_all_to_all, '_use_builtins', False): 41 | os._exit(0) 42 | 43 | import atexit 44 | atexit.register(lambda *args: on_quit()) 45 | return result 46 | 47 | class LocalCache: 48 | _CACHE = dict() 49 | 50 | @staticmethod 51 | def reset(): 52 | LocalCache._CACHE = dict() 53 | 54 | @staticmethod 55 | def set(key, val): 56 | LocalCache._CACHE[key] = val 57 | 58 | @staticmethod 59 | def get(key=None): 60 | if key not in LocalCache._CACHE: 61 | return [LocalCache._CACHE[x] for x in LocalCache._CACHE] 62 | return LocalCache._CACHE[key] 63 | 64 | def cache(): 65 | return LocalCache 66 | 67 | def get_local_session(): 68 | if not hasattr(init_data_model_parallel, 'default_env'): 69 | raise Exception("Current session is not initialized with system.init_data_model_parallel(). Alternatively, call system.record_time(is_cuda=False) if no session is needed.") 70 | return init_data_model_parallel.default_env 71 | 72 | def record_time(is_cuda=None): 73 | import time 74 | is_cuda = is_cuda if is_cuda is not None else get_local_session().is_cuda 75 | if is_cuda: 76 | import torch 77 | torch.cuda.synchronize() 78 | return time.time() 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![image](https://github.com/user-attachments/assets/fb0425a4-db58-4b98-998c-f2f8df8366f5) 2 | 3 | # PlanMoE 4 | Large language models (LLMs) have achieved significant breakthroughs in many natural language processing (NLP) and computer vision tasks with increasing model sizes. However, scaling LLMs requires a linear increase in compute with respect to the model size.
Recently, the sparsely activated mixture-of-experts (MoE) technique, first proposed in the 1990s, has been integrated into LLMs to scale the model size to trillions of parameters while requiring only a sub-linear increase in computation. 5 | 6 | However, training MoE models on a large-scale GPU/TPU cluster introduces critical performance issues that make the distributed training system scale poorly. Specifically, when training MoE models, the input data (e.g., tokens) of MoE layers must be dynamically routed (every mini-batch) to different experts for computation, and those experts may reside on different workers when a single worker (e.g., a GPU) cannot hold all experts. 7 | 8 | # Our Contributions 9 | In this work, we propose PlanMoE, an extensible and efficient MoE training system equipped with several features: 10 | 11 | - PlanMoE provides a generic scheduling framework that allows the communication and computation tasks in MoE training to be scheduled in an optimal way. 12 | - PlanMoE integrates our proposed novel all-to-all collective, which better utilizes intra- and inter-node bandwidths. 13 | - PlanMoE supports easy extension with customized all-to-all collectives and data compression approaches while still benefiting from our scheduling algorithm. 14 | # Code Design 15 | The PlanMoE system is designed to be extensible and efficient. To this end, we have made the following design decisions: 16 | 17 | - We modularize the time-consuming operations, including data compression (a computing task), collective communication (a communication task), and expert computation (a computing task), so that these operations can easily be customized with newly designed implementations. 18 | - Based on the modularized operations, we propose an adaptive optimal scheduling algorithm that pipelines the communication and computing tasks to improve training efficiency. 19 | - We design a novel all-to-all algorithm, Pipe-A2A, that pipelines intra-node and inter-node communications so that intra-node and inter-node bandwidths are utilized simultaneously to improve communication efficiency. 20 | 21 | 22 | 23 | 24 | The development of this code refers to [tutel](https://github.com/microsoft/tutel). 25 | 26 | ## Prerequisite 27 | 28 | torch>=1.9.1 29 | 30 | ## How to Install 31 | 32 | ```Shell 33 | # Install zfp 34 | cd zfp 35 | mkdir build 36 | cd build 37 | cmake .. 38 | cmake --build . --config Release 39 | ctest 40 | cd ../.. 41 | 42 | 43 | cd PlanMoE 44 | # You may need to change include_dirs and library_dirs in setup.py 45 | pip install -e . 46 | ``` 47 | 48 | ## How to Use 49 | 50 | ```Shell 51 | # Single machine: 52 | python3 -m torch.distributed.run --nproc_per_node=4 -m planmoe.examples.pre_test --batch_size=16 53 | # Distributed: 54 | # Please refer to planmoe/examples/run_mpi.sh 55 | ``` 56 | 57 | ## How to Add a New Compressor 58 | 59 | 1. Navigate to the planmoe/custom/compressor/ directory. 60 | 61 | 2. Create a new compressor class that inherits from the AbstractCompressor class. 62 | 63 | 3. Implement the virtual functions declared in abstract.h within your new compressor class. 64 | 65 | ## How to Add a New AllToAll Communication Algorithm 66 | 67 | 1. Navigate to the planmoe/custom/comm/ directory. 68 | 69 | 2. Create a new comm class that inherits from the AbstractComm class. 70 | 71 | 3. Implement the virtual functions declared in abstract.h within your new comm class, as sketched below.
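To make steps 2 and 3 concrete, here is a minimal, illustrative C++ sketch of a new all-to-all backend. The stand-in base class and the `all_to_all` signature below are assumptions made only so the example is self-contained; the real pure-virtual interface must be copied from planmoe/custom/comm/abstract.h, and the existing backends (naive.cpp, pipe.cpp, hetu.cpp) are the authoritative references. A new compressor derived from AbstractCompressor follows the same pattern.

```cpp
// Sketch only: `AbstractCommSketch` is a stand-in for the real AbstractComm
// declared in abstract.h, and `all_to_all` is a hypothetical virtual method.
// Mirror the actual declarations from abstract.h in a real backend.
#include <cstddef>

struct AbstractCommSketch {
    virtual ~AbstractCommSketch() = default;
    // Exchange `bytes` of payload with every peer (placeholder signature).
    virtual void all_to_all(const void* send_buf, void* recv_buf, std::size_t bytes) = 0;
};

// A new algorithm only overrides the virtual functions; the surrounding
// scheduling and compression pipeline stays unchanged.
struct MyComm : public AbstractCommSketch {
    void all_to_all(const void* send_buf, void* recv_buf, std::size_t bytes) override {
        // Split the buffer into intra-node and inter-node chunks here and
        // overlap the two transfers (e.g., with NCCL/RCCL primitives).
        (void)send_buf; (void)recv_buf; (void)bytes;  // placeholder body
    }
};
```

After implementing the class against the real abstract.h interface, add the new .cpp/.cu file to the sources list of the schemoe_custom_kernel extension in setup.py so that it is compiled into the custom kernel module.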
72 | 73 | ## Test Environment 74 | 75 | - g++==7.5.0 76 | - cuda==10.2 77 | - gpu==2080Ti 78 | 79 | 80 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT License. 4 | 5 | """The setuptools based setup module. 6 | 7 | Reference: 8 | https://packaging.python.org/guides/distributing-packages-using-setuptools/ 9 | """ 10 | 11 | import os, sys 12 | import subprocess 13 | import platform as pf 14 | 15 | from typing import List, Tuple 16 | 17 | from setuptools import setup, find_packages, Command 18 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 19 | 20 | try: 21 | from torch.utils.cpp_extension import IS_HIP_EXTENSION 22 | except: 23 | IS_HIP_EXTENSION = False 24 | 25 | if len(sys.argv) <= 1: 26 | sys.argv += ["install", "--user"] 27 | 28 | root_path = os.path.dirname(sys.argv[0]) 29 | root_path = root_path if root_path else "." 30 | 31 | os.chdir(root_path) 32 | 33 | 34 | class Tester(Command): 35 | """Cmdclass for `python setup.py test`. 36 | Args: 37 | Command (distutils.cmd.Command): 38 | Abstract base class for defining command classes. 39 | """ 40 | 41 | description = "test the code using pytest" 42 | user_options: List[Tuple[str, str, str]] = [] 43 | 44 | def initialize_options(self): 45 | """Set default values for options that this command supports.""" 46 | pass 47 | 48 | def finalize_options(self): 49 | """Set final values for options that this command supports.""" 50 | pass 51 | 52 | def run(self): 53 | """Run pytest.""" 54 | subprocess.check_call("python3 -m pytest -v -s tests/", shell=True) 55 | 56 | 57 | def install(use_cuda, use_nccl): 58 | ext_libs, ext_args = [ 59 | "zfp", 60 | ], { 61 | "cxx": ( 62 | [ 63 | "-Wno-sign-compare", 64 | "-Wno-unused-but-set-variable", 65 | "-Wno-terminate", 66 | "-Wno-unused-function", 67 | ] 68 | if pf.system() == "Linux" 69 | else [] 70 | ) 71 | } 72 | if not use_cuda: 73 | use_nccl = False 74 | extension = CppExtension 75 | else: 76 | ext_libs += ["cuda", "nvrtc"] if not IS_HIP_EXTENSION else [] 77 | ext_args["cxx"] += ["-DUSE_GPU"] 78 | extension = CUDAExtension 79 | 80 | if use_nccl: 81 | if IS_HIP_EXTENSION: 82 | ext_libs += ["rccl"] 83 | else: 84 | ext_libs += ["nccl"] 85 | ext_args["cxx"] += ["-DUSE_NCCL"] 86 | 87 | setup( 88 | name="schemoe", 89 | version="0.1", 90 | description="An Optimized Mixture-of-Experts Implementation.", 91 | license="MIT", 92 | classifiers=[ 93 | "Development Status :: 2 - Pre-Alpha", 94 | "Environment :: GPU", 95 | "Intended Audience :: Developers", 96 | "Intended Audience :: Education", 97 | "Intended Audience :: Science/Research", 98 | "License :: OSI Approved :: MIT License", 99 | "Programming Language :: Python :: 3", 100 | "Programming Language :: Python :: 3 :: Only", 101 | "Programming Language :: Python :: 3.5", 102 | "Programming Language :: Python :: 3.6", 103 | "Programming Language :: Python :: 3.7", 104 | "Programming Language :: Python :: 3.8", 105 | "Programming Language :: Python :: 3.9", 106 | "Topic :: Scientific/Engineering", 107 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 108 | "Topic :: Software Development", 109 | "Topic :: Software Development :: Libraries", 110 | "Topic :: Software Development :: Libraries :: Python Modules", 111 | ], 112 | keywords=["Mixture of Experts", "MoE", "Optimization"], 113 | 
packages=find_packages(exclude=["tests"]), 114 | python_requires=">=3.6, <4", 115 | install_requires=[], 116 | extras_require={ 117 | "test": [ 118 | "GPUtil>=1.4.0", 119 | "pytest-subtests>=0.4.0", 120 | "pytest>=6.2.2", 121 | ], 122 | }, 123 | ext_modules=[ 124 | extension( 125 | "schemoe_custom_kernel", 126 | include_dirs=[ 127 | "/home/xinglinpan/zfp/include/", 128 | ], 129 | sources=[ 130 | "./schemoe/custom/comm/abstract.cpp", 131 | "./schemoe/custom/comm/hetu.cpp", 132 | "./schemoe/custom/comm/layout_transform.cu", 133 | "./schemoe/custom/comm/naive.cpp", 134 | "./schemoe/custom/comm/pipe.cpp", 135 | "./schemoe/custom/compressor/abstract.cpp", 136 | "./schemoe/custom/compressor/gpulz.cu", 137 | "./schemoe/custom/compressor/int8.cpp", 138 | "./schemoe/custom/compressor/lz.cpp", 139 | "./schemoe/custom/compressor/no.cpp", 140 | "./schemoe/custom/compressor/zfpc.cpp", 141 | "./schemoe/custom/custom_kernel.cpp", 142 | "./schemoe/custom/dd_comm.cpp", 143 | "./schemoe/custom/jit.cpp", 144 | ], 145 | library_dirs=[ 146 | "/home/xinglinpan/zfp/build/lib", 147 | "/usr/local/cuda-10.2/lib64/stubs", 148 | ], 149 | libraries=ext_libs, 150 | extra_compile_args=ext_args, 151 | ) 152 | ], 153 | cmdclass={ 154 | "build_ext": BuildExtension, 155 | "test": Tester, 156 | }, 157 | project_urls={ 158 | "Source": "https://github.com/Fragile-azalea/ScheMoE", 159 | }, 160 | ) 161 | 162 | 163 | install(use_cuda=True, use_nccl=True) 164 | --------------------------------------------------------------------------------