├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── cuda ├── .gitignore ├── balancing.cu ├── balancing.cuh ├── fastermoe │ ├── smart_schedule.cpp │ ├── smart_schedule.h │ └── status.h ├── fmoe_cuda.cpp ├── global_exchange.cpp ├── global_exchange.h ├── local_exchange.cu ├── local_exchange.cuh ├── parallel_linear.cu ├── parallel_linear.cuh ├── stream_manager.cpp ├── stream_manager.h ├── tests │ ├── .gitignore │ ├── Makefile │ ├── assign.cu │ ├── counting.cu │ ├── limit.cu │ └── prune_gate.cu └── utils │ ├── cublas_wrapper.h │ ├── fmoe_utils.h │ ├── helper_cuda.h │ └── timer.hh ├── doc ├── fastermoe │ ├── README.md │ └── smartsch.png ├── installation-guide.md ├── logo │ ├── rect.png │ └── square.png ├── parallelism │ ├── README.md │ ├── fastmoe_data_parallel.png │ ├── fastmoe_expert_parallel.png │ └── parallelism.png ├── readme-cn.md └── release-note.md ├── examples ├── .gitignore ├── README.md ├── megatron │ ├── README.md │ ├── clip-grad-v2.2.patch │ ├── fmoefy-v2.2.patch │ ├── v2.5.patch │ └── v3.0.2.patch └── transformer-xl │ ├── LICENSE │ ├── README.md │ ├── data_utils.py │ ├── eval.py │ ├── mem_transformer.py │ ├── scripts │ ├── getdata.sh │ ├── run_enwik8_base.sh │ ├── run_enwik8_base_moe.sh │ ├── run_enwik8_large.sh │ ├── run_lm1b_base.sh │ ├── run_lm1b_large.sh │ ├── run_text8_base.sh │ ├── run_text8_large.sh │ ├── run_wt103_base.sh │ └── run_wt103_large.sh │ ├── train.py │ └── utils │ ├── adaptive_softmax.py │ ├── data_parallel.py │ ├── exp_utils.py │ ├── log_uniform_sampler.py │ ├── proj_adaptive_softmax.py │ └── vocabulary.py ├── fmoe ├── __init__.py ├── balance.py ├── distributed.py ├── fastermoe │ ├── __init__.py │ ├── config.py │ ├── expert_utils.py │ ├── schedule.py │ └── shadow_policy.py ├── functions.py ├── gates │ ├── __init__.py │ ├── base_gate.py │ ├── dc_gate.py │ ├── faster_gate.py │ ├── gshard_gate.py │ ├── naive_gate.py │ ├── noisy_gate.py │ ├── swipe_gate.py │ ├── switch_gate.py │ ├── utils.py │ └── zero_gate.py ├── layers.py ├── linear.py ├── megatron │ ├── Megatron.LICENSE │ ├── __init__.py │ ├── balance.py │ ├── checkpoint.py │ ├── distributed.py │ ├── layers.py │ ├── patch.py │ └── utils.py ├── transformer.py └── utils.py ├── requirements.txt ├── setup.py └── tests ├── README.md ├── benchmark_mlp.py ├── moe.py ├── test.sh ├── test_comm.py ├── test_ddp.py ├── test_dp.py ├── test_faster_gate.py ├── test_faster_schedule.py ├── test_faster_shadow.py ├── test_gates.py ├── test_local_exchange.py ├── test_mimo.py ├── test_numerical.py ├── test_swipe.py └── test_zero.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Compile with "..." 16 | 2. Run "..." with "..." processes on "..." nodes 17 | 18 | **Expected behavior** 19 | A clear and concise description of what you expected to happen. 20 | 21 | **Logs** 22 | If applicable, add logs to help explain your problem. 23 | 24 | **Platform** 25 | - Device: [e.g. NVIDIA V100] 26 | - OS: [e.g. Debian 10.2 buster] 27 | - CUDA version: [e.g. 11.1] 28 | - NCCL version: [e.g. 2.7.8-1] 29 | - PyTorch version: [e.g. 
1.8.0] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data/ 3 | libtorch-shared-with-deps-* 4 | pytorch/cuda/build 5 | exp/ 6 | .vscode/ 7 | a.out 8 | *.egg-info 9 | *.egg 10 | build 11 | *swp 12 | logs 13 | dist 14 | **/.DS_Store 15 | .idea 16 | -------------------------------------------------------------------------------- /cuda/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | build 3 | -------------------------------------------------------------------------------- /cuda/balancing.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "balancing.cuh" 3 | #include "global_exchange.h" 4 | #include 5 | 6 | /* 7 | * note that due to limit of cuda atomic operator, capacity should be int32 8 | */ 9 | torch::Tensor _limit_by_capacity( 10 | torch::Tensor expert_count, torch::Tensor capacity, 11 | long n_expert, long n_worker) { 12 | CHECK_INPUT(expert_count); 13 | CHECK_INPUT(capacity); 14 | auto expert_count_ack = torch::empty_like(expert_count); 15 | auto smgr = getCudaStreamManager(expert_count.device().index()); 16 | fmoe_cuda_limit_by_capacity_impl( 17 | expert_count.data_ptr(), 18 | capacity.data_ptr(), 19 | expert_count_ack.data_ptr(), 20 | n_expert, n_worker, smgr); 21 | return expert_count_ack; 22 | } 23 | 24 | torch::Tensor _prune_gate_by_capacity( 25 | torch::Tensor gate_idx, torch::Tensor expert_count, 26 | long n_expert, long n_worker) { 27 | auto smgr = getCudaStreamManager(expert_count.device().index()); 28 | auto batch_size = gate_idx.numel(); 29 | auto opt = torch::TensorOptions() 30 | .dtype(gate_idx.dtype()) 31 | .device(gate_idx.device()); 32 | auto new_gate_idx = torch::empty(gate_idx.sizes(), opt); 33 | fmoe_cuda_prune_gate_by_capacity_impl( 34 | gate_idx.data_ptr(), 35 | new_gate_idx.data_ptr(), 36 | expert_count.data_ptr(), 37 | batch_size, n_expert, n_worker, smgr); 38 | return new_gate_idx; 39 | } 40 | 41 | template 42 | T* _cudamalloc(size_t sz) { 43 | T* dptr; 44 | cudaMalloc(&dptr, sz * sizeof(T)); 45 | return dptr; 46 | } 47 | 48 | template 49 | T* _h2d(const T* hptr, T* dptr, size_t sz) { 50 | cudaMemcpy(dptr, hptr, sz * sizeof(T), cudaMemcpyHostToDevice); 51 | return dptr; 52 | } 53 | template 54 | T* _h2d(T* hptr, size_t sz) { 55 | T* dptr = _cudamalloc(sz); 56 | return _h2d(hptr, dptr, sz); 57 | } 58 | template 59 | T* _d2h(const T* dptr, T* hptr, 
size_t sz) { 60 | cudaMemcpy(hptr, dptr, sz * sizeof(T), cudaMemcpyDeviceToHost); 61 | return hptr; 62 | } 63 | template 64 | T* _d2h(const T* dptr, size_t sz) { 65 | T* hptr = new T[sz]; 66 | return _d2h(dptr, hptr, sz); 67 | } 68 | 69 | #ifdef FMOE_USE_NCCL 70 | 71 | #include 72 | 73 | #define UPDATE_COUNTERS(__count__) { \ 74 | if (i == rank) { \ 75 | lec[j] += (__count__); \ 76 | } \ 77 | if (j == rank) { \ 78 | gec[i] += (__count__); \ 79 | cap -= (__count__); \ 80 | } \ 81 | } 82 | 83 | std::vector _swipe_once( 84 | torch::Tensor gate_idx, torch::Tensor capacity, 85 | long n_expert, long n_worker, long bias) { 86 | auto device_idx = gate_idx.device().index(); 87 | auto smgr = getCudaStreamManager(device_idx); 88 | int rank; 89 | ncclCommUserRank(smgr->ncclcomm, &rank); 90 | cudaSetDevice(device_idx); 91 | 92 | auto capacity_new = capacity.clone(); 93 | auto cap = capacity_new.item(); 94 | 95 | long batch_size = gate_idx.size(0); 96 | auto gate_idx_cpu = gate_idx.cpu(); 97 | long* gidx = gate_idx_cpu.data_ptr(); 98 | 99 | /* Local count and exchange */ 100 | long *lec = new long[n_worker]; 101 | memset(lec, 0, n_worker * sizeof(long)); 102 | for (long i = 0; i < batch_size; ++i) { 103 | ++lec[gidx[i] / n_expert]; 104 | } 105 | long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc(n_worker); 106 | fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr); 107 | smgr->syncTorch(); 108 | long *gec = _d2h(d_gec, n_worker); 109 | 110 | /* Limit number of incoming samples */ 111 | long *drop_count = new long[n_worker]; 112 | memset(drop_count, 0, n_worker * sizeof(long)); 113 | for (long i = 0; i < n_worker; ++i) { 114 | if (cap >= gec[i]) { 115 | drop_count[i] = 0; 116 | cap -= gec[i]; 117 | } else { 118 | drop_count[i] = gec[i] - cap; 119 | gec[i] = cap; 120 | cap = 0; 121 | } 122 | } 123 | 124 | /* Send limit information back */ 125 | _h2d(gec, d_gec, n_worker); 126 | fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr); 127 | smgr->syncTorch(); 128 | _d2h(d_lec, lec, n_worker); 129 | 130 | auto d_dropcount = _h2d(drop_count, n_worker); 131 | ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum, 132 | smgr->ncclcomm, smgr->torchStream()); 133 | smgr->syncTorch(); 134 | _d2h(d_dropcount, drop_count, n_worker); 135 | 136 | auto d_gcap = _cudamalloc(n_worker); 137 | _h2d(&cap, d_gcap + rank, 1); 138 | ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64, 139 | smgr->ncclcomm, smgr->torchStream()); 140 | smgr->syncTorch(); 141 | auto gcap = _d2h(d_gcap, n_worker); 142 | 143 | /* Re-assign and update counters */ 144 | for (long i = 0, j = 0; i < n_worker; ++i) { 145 | while (drop_count[i] > 0) { 146 | if (drop_count[i] > gcap[j]) { 147 | drop_count[i] -= gcap[j]; 148 | UPDATE_COUNTERS(gcap[j]); 149 | ++j; 150 | } else { 151 | gcap[j] -= drop_count[i]; 152 | UPDATE_COUNTERS(drop_count[i]); 153 | break; 154 | } 155 | } 156 | } 157 | for (long i = 0; i < batch_size; ++i) { 158 | auto widx = gidx[i] / n_expert; 159 | if (lec[widx] > 0) { 160 | --lec[widx]; 161 | } else { 162 | gidx[i] = -1; 163 | } 164 | } 165 | for (long i = 0, k = 0; i < batch_size; ++i) { 166 | if (gidx[i] != -1) { 167 | continue; 168 | } 169 | for (; lec[k] == 0; ++k); 170 | --lec[k]; 171 | gidx[i] = k * n_expert + bias; 172 | } 173 | *capacity_new.data_ptr() = cap; 174 | 175 | delete [] drop_count; 176 | delete [] lec; 177 | delete [] gec; 178 | delete [] gcap; 179 | 180 | cudaFree(d_dropcount); 181 | cudaFree(d_lec); 182 | cudaFree(d_gec); 183 | cudaFree(d_gcap); 184 | 185 | return 
{gate_idx_cpu, capacity_new}; 186 | } 187 | 188 | #undef UPDATE_COUNTERS 189 | 190 | #endif 191 | -------------------------------------------------------------------------------- /cuda/balancing.cuh: -------------------------------------------------------------------------------- 1 | #include "stream_manager.h" 2 | #include "utils/fmoe_utils.h" 3 | #include 4 | 5 | __global__ 6 | void limit_by_capacity_kernel(const long* ec, int* cap, long* eca, 7 | const long n_expert, const long n_worker) { 8 | int eid = blockIdx.y; 9 | int wid = blockIdx.x * blockDim.x + threadIdx.x; 10 | if (wid < n_worker) { 11 | int proposal = ec[wid * n_expert + eid]; 12 | int cap_left = atomicSub(cap + eid, proposal); 13 | if (cap_left >= proposal) { 14 | eca[wid * n_expert + eid] = proposal; 15 | } else if (cap_left >= 0) { 16 | eca[wid * n_expert + eid] = cap_left; 17 | } else { 18 | eca[wid * n_expert + eid] = 0; 19 | } 20 | } 21 | } 22 | 23 | void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap, 24 | long* eca, const long n_expert, const long n_worker, 25 | CudaStreamManager* smgr) { 26 | dim3 grid_dim(CEIL(n_worker, 1024), n_expert); 27 | dim3 block_dim(1024); 28 | limit_by_capacity_kernel<<torchStream()>>>( 29 | ec, cap, eca, n_expert, n_worker); 30 | } 31 | 32 | __global__ 33 | void prune_gate_by_capacity_kernel(const long* gate_idx, long* new_gate_idx, 34 | int* ec, 35 | const long batch_size, const long n_expert, const long n_worker) { 36 | int i = blockIdx.x * blockDim.x + threadIdx.x; 37 | if (i < batch_size) { 38 | int orig_cap = atomicSub(ec + gate_idx[i], 1); 39 | if (orig_cap <= 0) { 40 | new_gate_idx[i] = -1; 41 | } else { 42 | new_gate_idx[i] = gate_idx[i]; 43 | } 44 | } 45 | } 46 | 47 | void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx, 48 | int* ec, 49 | const long batch_size, const long n_expert, const long n_worker, 50 | CudaStreamManager* smgr) { 51 | dim3 grid_dim(CEIL(batch_size, 1024)); 52 | dim3 block_dim(1024); 53 | prune_gate_by_capacity_kernel<<torchStream()>>>( 54 | gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker 55 | ); 56 | } 57 | -------------------------------------------------------------------------------- /cuda/fastermoe/status.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef FASTER_STATUS_H 3 | #define FASTER_STATUS_H 4 | 5 | int isSmartSchEnabled(); 6 | void setSmartSchEnabled(int); 7 | 8 | #endif // FASTER_STATUS_H 9 | -------------------------------------------------------------------------------- /cuda/fmoe_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | // global_exchange 7 | #ifdef FMOE_USE_NCCL 8 | 9 | #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \ 10 | (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)) 11 | #include 12 | #include 13 | #else 14 | #include 15 | #endif 16 | 17 | torch::Tensor _expert_exchange( 18 | torch::Tensor local_expert_count, 19 | long n_expert, long n_workers); 20 | torch::Tensor _global_scatter( 21 | torch::Tensor input_buf, 22 | torch::Tensor local_expert_count, 23 | torch::Tensor global_expert_count, 24 | long batch_size, long n_workers); 25 | torch::Tensor _global_gather( 26 | torch::Tensor output_buf, 27 | torch::Tensor local_expert_count, 28 | torch::Tensor global_expert_count, 29 | long batch_size, long n_workers); 30 | #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2) 31 | void 
_ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t); 32 | #else 33 | void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t); 34 | #endif // TORCH_VERSION 35 | 36 | #endif // FMOE_USE_NCCL 37 | 38 | // local_exchange 39 | void _assign_pos( 40 | torch::Tensor cum_count, 41 | torch::Tensor gate, 42 | torch::Tensor pos); 43 | void _expert_count( 44 | torch::Tensor gate_idx, 45 | torch::Tensor expert_count); 46 | 47 | // parallel_linear 48 | torch::Tensor _linear_forward( 49 | torch::Tensor input_buf, 50 | torch::Tensor expert_count, 51 | torch::Tensor weight, 52 | at::optional bias 53 | ); 54 | std::vector _linear_backward( 55 | torch::Tensor grad_output_buf, 56 | torch::Tensor input_buf, 57 | torch::Tensor expert_count, 58 | torch::Tensor weight, 59 | at::optional bias 60 | ); 61 | 62 | // balancing 63 | torch::Tensor _limit_by_capacity( 64 | torch::Tensor expert_count, torch::Tensor capacity, 65 | long n_expert, long n_experts); 66 | torch::Tensor _prune_gate_by_capacity( 67 | torch::Tensor gate_idx, torch::Tensor expert_count, 68 | long n_expert, long n_worker); 69 | std::vector _swipe_once( 70 | torch::Tensor gate_idx, torch::Tensor capacity_tensor, 71 | long n_expert, long n_worker, long bias); 72 | 73 | // smart scheduling 74 | std::vector _smart_sch_forward( 75 | torch::Tensor input_buf, 76 | torch::Tensor local_expert_count, 77 | torch::Tensor global_expert_count, 78 | torch::Tensor stored_models, 79 | long global_batch_size, 80 | long expert_size, 81 | long n_workers, 82 | py::function forward_fn, 83 | py::function get_param_fn, 84 | py::function stash_fn, 85 | py::function pop_fn); 86 | torch::Tensor _smart_sch_backward( 87 | torch::Tensor grad_out, 88 | torch::Tensor local_expert_count, 89 | torch::Tensor global_expert_count, 90 | torch::Tensor stored_models, 91 | long buf_batch_size, 92 | long global_batch_size, 93 | long n_workers, 94 | py::function backward_fn, 95 | py::function stash_fn, 96 | py::function pop_fn, 97 | py::function collect_fn, 98 | py::function set_grad_fn); 99 | void _reduce_grad( 100 | torch::Tensor t, 101 | long root, 102 | long expert_size); 103 | 104 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 105 | #ifdef FMOE_USE_NCCL 106 | m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)"); 107 | m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)"); 108 | m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)"); 109 | m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm"); 110 | m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)"); 111 | 112 | m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling"); 113 | m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling"); 114 | m.def("reduce_grad", &_reduce_grad, "Reduce gradients over FastMoE's communication stream"); 115 | #endif 116 | 117 | m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)"); 118 | m.def("assign_pos", &_assign_pos, "FastMoE assign pos by gate (CUDA)"); 119 | 120 | m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)"); 121 | m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)"); 122 | 123 | m.def("limit_by_capacity", &_limit_by_capacity, "FastMoE limit experts by capacity(CUDA)"); 124 | m.def("prune_gate_by_capacity", &_prune_gate_by_capacity, "FastMoE prune gate by capacity(CUDA)"); 125 | } 126 | -------------------------------------------------------------------------------- 
/cuda/global_exchange.cpp: -------------------------------------------------------------------------------- 1 | #include "global_exchange.h" 2 | #include "utils/fmoe_utils.h" 3 | #include 4 | 5 | #ifdef FMOE_USE_NCCL 6 | #include 7 | 8 | 9 | void fmoe_cuda_expert_exchange_impl( 10 | const long* local_expert_count, 11 | long* global_expert_count, 12 | int n_expert, int world_size, 13 | CudaStreamManager* smgr) { 14 | NCCL_SAFE_CALL(ncclGroupStart()); 15 | for (int i = 0; i < world_size; ++i) { 16 | NCCL_SAFE_CALL(ncclSend( 17 | local_expert_count + n_expert * i, 18 | n_expert, 19 | ncclInt64, 20 | i, 21 | smgr->ncclcomm, 22 | smgr->torchStream())); 23 | NCCL_SAFE_CALL(ncclRecv( 24 | global_expert_count + n_expert * i, 25 | n_expert, 26 | ncclInt64, 27 | i, 28 | smgr->ncclcomm, 29 | smgr->torchStream())); 30 | } 31 | NCCL_SAFE_CALL(ncclGroupEnd()); 32 | } 33 | 34 | torch::Tensor _expert_exchange( 35 | torch::Tensor local_expert_count, 36 | long n_expert, long n_workers) { 37 | auto global_expert_count = torch::empty_like(local_expert_count); 38 | auto smgr = getCudaStreamManager(local_expert_count.device().index()); 39 | 40 | fmoe_cuda_expert_exchange_impl( 41 | local_expert_count.data_ptr(), 42 | global_expert_count.data_ptr(), 43 | n_expert, n_workers, 44 | smgr); 45 | return global_expert_count; 46 | } 47 | 48 | torch::Tensor _global_scatter( 49 | torch::Tensor input_buf, 50 | torch::Tensor local_expert_count, 51 | torch::Tensor global_expert_count, 52 | long batch_size, long n_workers) { 53 | CHECK_INPUT(input_buf); 54 | 55 | auto n_expert = local_expert_count.size(0) / n_workers; 56 | auto in_feat = input_buf.size(1); 57 | auto global_input_buf = input_buf.new_empty({batch_size, in_feat}); 58 | auto smgr = getCudaStreamManager(input_buf.device().index()); 59 | 60 | AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, 61 | input_buf.scalar_type(), "fmoe_cuda_global_scatter", ([&] { 62 | fmoe_cuda_global_scatter_impl( 63 | input_buf.data_ptr(), 64 | local_expert_count.data_ptr(), 65 | global_expert_count.data_ptr(), 66 | global_input_buf.data_ptr(), 67 | in_feat, n_expert, n_workers, 68 | smgr 69 | ); 70 | })); 71 | return global_input_buf; 72 | } 73 | 74 | torch::Tensor _global_gather( 75 | torch::Tensor output_buf, 76 | torch::Tensor local_expert_count, 77 | torch::Tensor global_expert_count, 78 | long batch_size, long n_workers) { 79 | CHECK_INPUT(output_buf); 80 | 81 | auto n_expert = local_expert_count.size(0) / n_workers; 82 | auto out_feat = output_buf.size(1); 83 | auto local_output_buf = output_buf.new_empty({batch_size, out_feat}); 84 | auto smgr = getCudaStreamManager(output_buf.device().index()); 85 | 86 | AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, 87 | output_buf.scalar_type(), "fmoe_cuda_global_gather", ([&] { 88 | fmoe_cuda_global_gather_impl( 89 | output_buf.data_ptr(), 90 | local_expert_count.data_ptr(), 91 | global_expert_count.data_ptr(), 92 | local_output_buf.data_ptr(), 93 | out_feat, n_expert, n_workers, 94 | smgr 95 | ); 96 | })); 97 | return local_output_buf; 98 | } 99 | 100 | #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \ 101 | (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)) 102 | #include 103 | #include 104 | #else 105 | #include 106 | #endif 107 | 108 | class HackNCCLGroup: public c10d::ProcessGroupNCCL { 109 | public: 110 | ncclComm_t getcomm(at::Device dev) { 111 | ncclUniqueId ncclID; 112 | int rank = getRank(); 113 | if (rank == 0) { 114 | ncclGetUniqueId(&ncclID); 
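// Only rank 0 generates the NCCL unique ID; the broadcastUniqueNCCLID() call
// below shares it with every rank through the ProcessGroupNCCL machinery, so
// that all ranks can pass the same ID to ncclCommInitRank() and join a
// communicator dedicated to FastMoE.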
115 | } 116 | #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \ 117 | (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 12)) 118 | broadcastUniqueNCCLID(&ncclID, 119 | false, 120 | "fastmoe_nccl_comm", 121 | rank); 122 | #elif defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \ 123 | (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 8)) 124 | broadcastUniqueNCCLID(&ncclID, 125 | c10d::OpType::SEND, 126 | "fastmoe_nccl_comm", 127 | rank); 128 | #else 129 | broadcastUniqueNCCLID(&ncclID); 130 | #endif 131 | ncclComm_t comm; 132 | NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank)); 133 | return comm; 134 | } 135 | }; 136 | 137 | #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2) 138 | void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) { 139 | #else 140 | void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) { 141 | #endif // TORCH_VERSION 142 | auto smgr = getCudaStreamManager(t.device().index()); 143 | if (smgr->ncclgood) { 144 | return; 145 | } 146 | #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2) 147 | HackNCCLGroup* h = (HackNCCLGroup*)(void*) 148 | (p.getBackend(c10d::ProcessGroup::NCCL).get()); 149 | #else 150 | HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p; 151 | #endif // TORCH_VERSION 152 | smgr->ncclcomm = h->getcomm(t.device()); 153 | if (smgr->ncclcomm != 0) { 154 | smgr->ncclgood = 1; 155 | } else { 156 | std::cerr << "Nccl initialization failed\n"; 157 | } 158 | } 159 | 160 | #endif // FMOE_USE_NCCL 161 | -------------------------------------------------------------------------------- /cuda/global_exchange.h: -------------------------------------------------------------------------------- 1 | #include "stream_manager.h" 2 | #ifdef FMOE_USE_NCCL 3 | 4 | void fmoe_cuda_expert_exchange_impl( 5 | const long* local_expert_count, 6 | long* global_expert_count, 7 | int n_expert, int world_size, 8 | CudaStreamManager* smgr); 9 | 10 | 11 | template 12 | void fmoe_cuda_global_scatter_impl( 13 | const scalar_t* local_input_buf, 14 | const long* local_expert_count, 15 | const long* global_expert_count, 16 | scalar_t* input_buf, 17 | size_t in_feat, size_t n_expert, size_t world_size, 18 | CudaStreamManager* smgr) { 19 | // assert world_size > 1 20 | int recv_ptr = 0; 21 | /* TODO: may save for backward */ 22 | long*expert_ptr = new long[n_expert * world_size]; 23 | expert_ptr[0] = 0; 24 | for (size_t i = 1; i < n_expert * world_size; ++i) { 25 | expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1]; 26 | } 27 | 28 | for (size_t i = 0; i < n_expert; ++i) { 29 | NCCL_SAFE_CALL(ncclGroupStart()); 30 | for (size_t j = 0; j < world_size; ++j) { 31 | int idx = i + j * n_expert; 32 | if (local_expert_count[idx]) { 33 | NCCL_SAFE_CALL(ncclSend( 34 | local_input_buf + expert_ptr[idx] * in_feat, 35 | local_expert_count[idx] * in_feat * sizeof(scalar_t), 36 | ncclChar, 37 | j, 38 | smgr->ncclcomm, 39 | smgr->torchStream())); 40 | } 41 | if (global_expert_count[idx]) { 42 | NCCL_SAFE_CALL(ncclRecv( 43 | input_buf + recv_ptr * in_feat, 44 | global_expert_count[idx] * in_feat * sizeof(scalar_t), 45 | ncclChar, 46 | j, 47 | smgr->ncclcomm, 48 | smgr->torchStream())); 49 | recv_ptr += global_expert_count[idx]; 50 | } 51 | } 52 | NCCL_SAFE_CALL(ncclGroupEnd()); 53 | } 54 | delete [] expert_ptr; 55 | } 56 | 57 | template 58 | void fmoe_cuda_global_gather_impl( 59 | const scalar_t* output_buf, 60 | const long* local_expert_count, 61 | const long* global_expert_count, 62 | scalar_t* local_output_buf, 63 | size_t 
out_feat, size_t n_expert, size_t world_size, 64 | CudaStreamManager* smgr) { 65 | long send_ptr = 0; 66 | /* TODO: may save for backward */ 67 | long *expert_ptr = new long[n_expert * world_size]; 68 | expert_ptr[0] = 0; 69 | for (size_t i = 1; i < n_expert * world_size; ++i) { 70 | expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1]; 71 | } 72 | 73 | for (size_t i = 0; i < n_expert; ++i) { 74 | NCCL_SAFE_CALL(ncclGroupStart()); 75 | for (size_t j = 0; j < world_size; ++j) { 76 | int idx = i + j * n_expert; 77 | if (global_expert_count[idx]) { 78 | NCCL_SAFE_CALL(ncclSend( 79 | output_buf + send_ptr * out_feat, 80 | global_expert_count[idx] * out_feat * sizeof(scalar_t), 81 | ncclChar, 82 | j, 83 | smgr->ncclcomm, 84 | smgr->torchStream())); 85 | send_ptr += global_expert_count[idx]; 86 | } 87 | if (local_expert_count[idx]) { 88 | NCCL_SAFE_CALL(ncclRecv( 89 | local_output_buf + expert_ptr[idx] * out_feat, 90 | local_expert_count[idx] * out_feat * sizeof(scalar_t), 91 | ncclChar, 92 | j, 93 | smgr->ncclcomm, 94 | smgr->torchStream())); 95 | } 96 | } 97 | NCCL_SAFE_CALL(ncclGroupEnd()); 98 | } 99 | delete [] expert_ptr; 100 | } 101 | 102 | 103 | #endif // FMOE_USE_NCCL 104 | -------------------------------------------------------------------------------- /cuda/local_exchange.cu: -------------------------------------------------------------------------------- 1 | #include "local_exchange.cuh" 2 | #include "utils/fmoe_utils.h" 3 | #include 4 | 5 | void _assign_pos( 6 | torch::Tensor cum_count, 7 | torch::Tensor gate, 8 | torch::Tensor pos) { 9 | auto smgr = getCudaStreamManager(cum_count.device().index()); 10 | auto gate_shp = gate.sizes(); 11 | size_t batch_size = gate_shp[0], topk = 1; 12 | if (gate_shp.size() == 2) { 13 | topk = gate_shp[1]; 14 | } 15 | fmoe_cuda_assign_pos_impl( 16 | cum_count.data_ptr(), 17 | gate.data_ptr(), 18 | pos.data_ptr(), 19 | batch_size, topk, smgr); 20 | } 21 | 22 | void _expert_count( 23 | torch::Tensor gate_idx, 24 | torch::Tensor expert_count) { 25 | auto smgr = getCudaStreamManager(gate_idx.device().index()); 26 | auto batch_size = gate_idx.numel(); 27 | auto n_expert = expert_count.numel(); 28 | fmoe_cuda_expert_count_impl( 29 | gate_idx.data_ptr(), 30 | expert_count.data_ptr(), 31 | batch_size, n_expert, smgr); 32 | } 33 | -------------------------------------------------------------------------------- /cuda/local_exchange.cuh: -------------------------------------------------------------------------------- 1 | #include "stream_manager.h" 2 | #include "utils/helper_cuda.h" 3 | #include "utils/fmoe_utils.h" 4 | 5 | __global__ 6 | void assign_pos_kernel(int* cum_count, const long* gate, long* pos, 7 | size_t numel, size_t topk) { 8 | size_t idx = threadIdx.x + blockIdx.x * blockDim.x; 9 | if (idx < numel) { 10 | long gate_idx = gate[idx]; 11 | if (gate_idx > -1) { 12 | int p = atomicSub(cum_count + gate_idx, 1); 13 | pos[p - 1] = (long)idx; 14 | } 15 | } 16 | } 17 | 18 | void fmoe_cuda_assign_pos_impl( 19 | int* cum_count, const long* gate, long* pos, 20 | const size_t batch_size, const size_t topk, 21 | CudaStreamManager* smgr) { 22 | size_t numel = batch_size * topk; 23 | assign_pos_kernel 24 | <<torchStream()>>> 25 | (cum_count, gate, pos, numel, topk); 26 | } 27 | 28 | #define PERTHREAD_EXPERTS 256 29 | 30 | #ifdef FMOE_USE_HIP 31 | #define WARP_SIZE 64 32 | #else 33 | #define WARP_SIZE 32 34 | #endif 35 | 36 | __global__ 37 | void expert_count_kernel(const long* gate_idx, int* expert_count, 38 | const size_t batch_size, const size_t 
n_expert) { 39 | int res_tmp[PERTHREAD_EXPERTS] = {0}; 40 | long expert_min = blockIdx.x * PERTHREAD_EXPERTS; 41 | long expert_max = expert_min + PERTHREAD_EXPERTS; 42 | if (expert_max > n_expert) { 43 | expert_max = n_expert; 44 | } 45 | for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { 46 | long idx = gate_idx[i]; 47 | if (idx == -1) { 48 | continue; 49 | } 50 | if (idx < expert_min || idx >= expert_max) { 51 | continue; 52 | } 53 | res_tmp[idx - expert_min] += 1; 54 | } 55 | for (int i = expert_min; i < expert_max; ++i) { 56 | int x = res_tmp[i - expert_min]; 57 | #pragma unroll 58 | for (int j = 1; j < WARP_SIZE; j <<= 1) { 59 | #ifdef FMOE_USE_HIP 60 | x = x + __shfl_down(x, j); 61 | #else 62 | x = x + __shfl_down_sync(-1u, x, j); 63 | #endif 64 | } 65 | if (threadIdx.x % WARP_SIZE == 0) { 66 | atomicAdd(expert_count + i, x); 67 | } 68 | } 69 | } 70 | 71 | void fmoe_cuda_expert_count_impl( 72 | const long* gate_idx, int* expert_count, 73 | const size_t batch_size, const size_t n_expert, 74 | CudaStreamManager* smgr) { 75 | expert_count_kernel 76 | <<torchStream()>>> 77 | (gate_idx, expert_count, batch_size, n_expert); 78 | } 79 | -------------------------------------------------------------------------------- /cuda/parallel_linear.cu: -------------------------------------------------------------------------------- 1 | #include "parallel_linear.cuh" 2 | #include "utils/fmoe_utils.h" 3 | #include 4 | 5 | torch::Tensor _linear_forward( 6 | torch::Tensor input_buf, 7 | torch::Tensor expert_count, 8 | torch::Tensor weight, 9 | at::optional bias 10 | ) { 11 | auto smgr = getCudaStreamManager(input_buf.device().index()); 12 | const auto batch_size = input_buf.size(0); 13 | const auto num_expert = weight.size(0); 14 | const auto out_feat = weight.size(1); 15 | const auto in_feat = weight.size(2); 16 | 17 | #ifdef MOE_DEBUG 18 | printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 19 | num_expert, in_feat, out_feat); 20 | #endif 21 | 22 | torch::Tensor output; 23 | 24 | if (bias.has_value()) { 25 | output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0); 26 | } else{ 27 | auto out_options = torch::TensorOptions() 28 | .device(input_buf.device()) 29 | .dtype(input_buf.dtype()); 30 | output = torch::empty({batch_size, out_feat}, out_options); 31 | } 32 | 33 | AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, 34 | input_buf.scalar_type(), "moe_forward_cuda", 35 | ([&] { 36 | fmoe_cuda_linear_forward_impl( 37 | input_buf.data_ptr(), 38 | weight.data_ptr(), 39 | expert_count.data_ptr(), 40 | output.data_ptr(), 41 | bias.has_value(), 42 | in_feat, 43 | out_feat, 44 | num_expert, 45 | smgr 46 | ); 47 | })); 48 | 49 | return output; 50 | } 51 | 52 | 53 | std::vector _linear_backward( 54 | torch::Tensor grad_output_buf, 55 | torch::Tensor input_buf, 56 | torch::Tensor expert_count, 57 | torch::Tensor weight, 58 | at::optional bias 59 | ) { 60 | auto smgr = getCudaStreamManager(input_buf.device().index()); 61 | const auto batch_size = input_buf.size(0); 62 | const auto num_expert = weight.size(0); 63 | const auto out_feat = weight.size(1); 64 | const auto in_feat = weight.size(2); 65 | 66 | #ifdef MOE_DEBUG 67 | printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, " 68 | "out_feat (d_ffn)=%ld\n", 69 | batch_size, num_expert, in_feat, out_feat); 70 | #endif 71 | 72 | auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); 73 | auto grad_weight = grad_output_buf.new_empty({num_expert, 
out_feat, in_feat}); 74 | auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat}); 75 | 76 | AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, 77 | input_buf.scalar_type(), "moe_cuda_backward", ([&] { 78 | fmoe_cuda_linear_backward_impl( 79 | grad_output_buf.data_ptr(), 80 | input_buf.data_ptr(), 81 | weight.data_ptr(), 82 | expert_count.data_ptr(), 83 | grad_input_buf.data_ptr(), 84 | grad_weight.data_ptr(), 85 | grad_bias.data_ptr(), 86 | bias.has_value(), 87 | batch_size, 88 | in_feat, 89 | out_feat, 90 | num_expert, 91 | smgr 92 | ); 93 | })); 94 | 95 | return {grad_input_buf, grad_weight, grad_bias}; 96 | } 97 | 98 | -------------------------------------------------------------------------------- /cuda/parallel_linear.cuh: -------------------------------------------------------------------------------- 1 | #include "stream_manager.h" 2 | #include "utils/cublas_wrapper.h" 3 | 4 | 5 | /* 6 | This function is to be called with one block per each column 7 | */ 8 | template 9 | __global__ 10 | void column_reduce(const scalar_t * matrix, scalar_t * result, 11 | int m /* lines */, int n /* columns*/) { 12 | 13 | // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory 14 | extern __shared__ unsigned char my_smem[]; 15 | scalar_t *sdata = reinterpret_cast(my_smem); 16 | 17 | // normal tid 18 | int tid = threadIdx.x + threadIdx.y * blockDim.x; 19 | 20 | // transposed tid for shared memory 21 | int new_tid = threadIdx.y + threadIdx.x * blockDim.y; 22 | 23 | // true x value in the matrix 24 | int real_x = threadIdx.x + blockDim.x * blockIdx.x; 25 | 26 | int i = real_x + n * threadIdx.y; 27 | const int it = n*blockDim.y; 28 | int offset = it; 29 | float accumulator = 0; 30 | 31 | if (threadIdx.y < m && real_x < n) { 32 | // store all the values from this column in a warped way 33 | accumulator = matrix[i]; 34 | while (i + offset < n*m) { 35 | accumulator += matrix[i + offset]; 36 | offset += it; 37 | } 38 | } 39 | 40 | // save column reduction data in a transposed way 41 | sdata[new_tid] = accumulator; 42 | __syncthreads(); 43 | 44 | for (size_t t= 16; t > 0; t>>=1) { 45 | if (tid < 32 * 32 - 16) 46 | sdata[tid] += sdata[tid + t]; 47 | __syncthreads(); 48 | } 49 | 50 | if (threadIdx.y == 0 && real_x < n) 51 | result[real_x] = sdata[new_tid]; 52 | 53 | } 54 | 55 | template 56 | void fmoe_cuda_linear_forward_impl( 57 | const scalar_t* input_buf, 58 | const scalar_t* weight, 59 | const long* expert_count, 60 | scalar_t* output_buf, 61 | const bool has_bias, 62 | const size_t in_feat, 63 | const size_t out_feat, 64 | const size_t num_expert, 65 | CudaStreamManager* smgr) { 66 | scalar_t alpha = 1, beta = has_bias ? 
1 : 0; 67 | 68 | smgr->syncTorch(); 69 | for (int i = 0, ptr = 0; i < num_expert; ++i) { 70 | if (expert_count[i] == 0) { 71 | continue; 72 | } 73 | // Use T(B) x T(A) = T(C) to produce row-major C 74 | checkCudaErrors(cublasXgemm( 75 | smgr->handle(i), 76 | CUBLAS_OP_T, 77 | CUBLAS_OP_N, 78 | out_feat, expert_count[i], in_feat, 79 | &alpha, 80 | weight + i * in_feat * out_feat, in_feat, 81 | input_buf + ptr * in_feat, in_feat, 82 | &beta, 83 | output_buf + out_feat * ptr, out_feat 84 | )); 85 | 86 | ptr += expert_count[i]; 87 | } 88 | smgr->sync(num_expert); 89 | } 90 | 91 | template 92 | void fmoe_cuda_linear_backward_impl( 93 | const scalar_t* grad_output_buf, 94 | const scalar_t* input_buf, 95 | const scalar_t* weight, 96 | const long* expert_count, 97 | scalar_t* grad_input_buf, 98 | scalar_t* grad_weight, 99 | scalar_t* grad_bias, 100 | const bool has_bias, 101 | const size_t batch_size, 102 | const size_t in_feat, 103 | const size_t out_feat, 104 | const size_t num_expert, 105 | CudaStreamManager* smgr) { 106 | smgr->syncTorch(); 107 | scalar_t alpha = 1, beta = 0; 108 | 109 | // bias 110 | dim3 block_threads(32, 32); 111 | dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1); 112 | 113 | 114 | for (int i = 0, ptr = 0; i < num_expert; ++i) { 115 | if (expert_count[i] == 0) { 116 | cudaMemset(grad_weight + i * in_feat * out_feat, 0, 117 | sizeof(scalar_t) * in_feat * out_feat); 118 | cudaMemset(grad_bias + i * out_feat, 0, sizeof(scalar_t) * out_feat); 119 | continue; 120 | } 121 | // Use T(B) x T(A) = T(C) to produce row-major C 122 | 123 | // Backward input: g_i = w @ g_o 124 | checkCudaErrors(cublasXgemm( 125 | smgr->handle(i), 126 | CUBLAS_OP_N, 127 | CUBLAS_OP_N, 128 | in_feat, expert_count[i], out_feat, 129 | &alpha, 130 | weight + i * in_feat * out_feat, in_feat, 131 | grad_output_buf + ptr * out_feat, out_feat, 132 | &beta, 133 | grad_input_buf + in_feat * ptr, in_feat 134 | )); 135 | 136 | // Backward weight: g_w = i @ g_o 137 | checkCudaErrors(cublasXgemm( 138 | smgr->handle(i), 139 | CUBLAS_OP_N, 140 | CUBLAS_OP_T, 141 | in_feat, out_feat, expert_count[i], 142 | &alpha, 143 | input_buf + in_feat * ptr, in_feat, 144 | grad_output_buf + ptr * out_feat, out_feat, 145 | &beta, 146 | grad_weight + i * in_feat * out_feat, in_feat 147 | )); 148 | 149 | if (has_bias) { 150 | column_reduce 151 | <<stream(i)>>> 152 | ( 153 | grad_output_buf + ptr * out_feat, 154 | grad_bias + i * out_feat, 155 | expert_count[i], 156 | out_feat 157 | ); 158 | } 159 | 160 | ptr += expert_count[i]; 161 | } 162 | smgr->sync(num_expert); 163 | } 164 | 165 | -------------------------------------------------------------------------------- /cuda/stream_manager.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "fastermoe/status.h" 10 | #include "stream_manager.h" 11 | 12 | #define SMGR_N_STREAMS 16 13 | 14 | 15 | cudaStream_t CudaStreamManager::stream(size_t idx) { 16 | if (this->use_default) { 17 | return c10::cuda::getCurrentCUDAStream().stream(); 18 | } 19 | return this->streams[idx % SMGR_N_STREAMS]; 20 | } 21 | 22 | cudaStream_t CudaStreamManager::torchStream() { 23 | return c10::cuda::getCurrentCUDAStream().stream(); 24 | } 25 | 26 | cublasHandle_t CudaStreamManager::handle(size_t idx) { 27 | if (this->use_default) { 28 | return at::cuda::getCurrentCUDABlasHandle(); 29 | } 30 | return this->handles[idx % SMGR_N_STREAMS]; 31 | } 32 | 33 | 34 | void 
CudaStreamManager::syncTorch() { 35 | cudaStreamSynchronize(this->torchStream()); 36 | } 37 | 38 | void CudaStreamManager::sync(int idx) { 39 | if (this->use_default) { 40 | return; 41 | } 42 | for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) { 43 | cudaStreamSynchronize(streams[i]); 44 | } 45 | } 46 | 47 | void CudaStreamManager::setup(const int device) { 48 | #ifdef FMOE_USE_NCCL 49 | this->ncclgood = 0; 50 | #endif 51 | this->device = device; 52 | checkCudaErrors(cudaSetDevice(device)); 53 | streams = new cudaStream_t[SMGR_N_STREAMS]; 54 | handles = new cublasHandle_t[SMGR_N_STREAMS]; 55 | for (size_t i = 0; i < SMGR_N_STREAMS; ++i) { 56 | // SHOULD NOT USE: cudaStreamCreate(...) 57 | // more details in 58 | // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html 59 | checkCudaErrors(cudaStreamCreateWithFlags(streams + i, 60 | cudaStreamNonBlocking)); 61 | checkCudaErrors(cublasCreate(handles + i)); 62 | cublasSetStream(handles[i], streams[i]); 63 | } 64 | } 65 | 66 | void CudaStreamManager::destroy() { 67 | for (size_t i = 0; i < SMGR_N_STREAMS; ++i) { 68 | checkCudaErrors(cudaStreamDestroy(streams[i])); 69 | checkCudaErrors(cublasDestroy(handles[i])); 70 | } 71 | delete[] streams; 72 | delete[] handles; 73 | } 74 | 75 | std::unordered_map smgrs; 76 | std::mutex smgr_mtx; 77 | 78 | CudaStreamManager* getCudaStreamManager(const int device) { 79 | auto it = smgrs.find(device); 80 | if (it == smgrs.end()) { 81 | smgr_mtx.lock(); 82 | it = smgrs.find(device); 83 | if (it == smgrs.end()) { 84 | auto smgr = new CudaStreamManager(device); 85 | smgrs.insert(std::pair(device, smgr)); 86 | smgr_mtx.unlock(); 87 | return smgr; 88 | } else { 89 | smgr_mtx.unlock(); 90 | } 91 | } 92 | return it->second; 93 | } 94 | 95 | -------------------------------------------------------------------------------- /cuda/stream_manager.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_STREAM_MANAGER_H 2 | #define CUDA_STREAM_MANAGER_H 3 | 4 | #include "utils/helper_cuda.h" 5 | 6 | #ifdef FMOE_USE_NCCL 7 | #include 8 | 9 | #define NCCL_SAFE_CALL(__fn__) { \ 10 | auto __res__ = __fn__; \ 11 | if (__res__ != ncclSuccess) { \ 12 | fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \ 13 | exit(-1); \ 14 | } \ 15 | } 16 | 17 | #endif 18 | 19 | class CudaStreamManager { 20 | public: 21 | int device; 22 | cublasHandle_t* handles; 23 | cudaStream_t* streams; 24 | bool use_default; 25 | #ifdef FMOE_USE_NCCL 26 | char ncclgood; 27 | ncclComm_t ncclcomm; 28 | #endif 29 | 30 | public: 31 | CudaStreamManager(int device_): device(device_), use_default(false) { 32 | this->setup(device); 33 | } 34 | 35 | void setup(int); 36 | void sync(int=0); 37 | void syncTorch(); 38 | void destroy(); 39 | 40 | cudaStream_t torchStream(); 41 | cudaStream_t stream(size_t=0); 42 | cublasHandle_t handle(size_t=0); 43 | 44 | ~CudaStreamManager() { 45 | this->destroy(); 46 | } 47 | }; 48 | 49 | CudaStreamManager* getCudaStreamManager(const int device); 50 | 51 | #endif // CUDA_STREAM_MANAGER 52 | -------------------------------------------------------------------------------- /cuda/tests/.gitignore: -------------------------------------------------------------------------------- 1 | test_* 2 | -------------------------------------------------------------------------------- /cuda/tests/Makefile: -------------------------------------------------------------------------------- 1 | default : test_prune_gate test_limit test_assign test_counting 2 | 3 | 
test_% : %.cu 4 | nvcc $< ../stream_manager.cpp -lcublas -o $@ 5 | -------------------------------------------------------------------------------- /cuda/tests/assign.cu: -------------------------------------------------------------------------------- 1 | #include "../local_exchange.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int main(int argc, char* args[]) { 9 | int n_worker = atoi(args[1]); 10 | int n_expert = atoi(args[2]); 11 | int batch_size = atoi(args[3]); 12 | int topk = atoi(args[4]); 13 | int tot_expert = n_worker * n_expert; 14 | 15 | long* gate_idx = new long[batch_size * topk]; 16 | long* n_gate_idx = new long[batch_size * topk]; 17 | 18 | int* lec = new int[tot_expert]; 19 | memset(lec, 0, sizeof(int) * tot_expert); 20 | for (int i = 0; i < batch_size * topk; ++i) { 21 | if (rand() % 10) { 22 | gate_idx[i] = rand() % tot_expert; 23 | ++lec[gate_idx[i]]; 24 | } else { 25 | gate_idx[i] = -1; 26 | } 27 | } 28 | for (int i = 1; i < tot_expert; ++i) { 29 | lec[i] += lec[i - 1]; 30 | } 31 | 32 | puts("gate idx"); 33 | for (int i = 0; i < batch_size * topk; ++i) { 34 | printf("%d ", gate_idx[i]); 35 | } 36 | putchar(10); 37 | int nlec = lec[tot_expert - 1]; 38 | 39 | int* g_lec; 40 | cudaMalloc(&g_lec, sizeof(int) * tot_expert); 41 | cudaMemcpy(g_lec, lec, sizeof(int) * tot_expert, cudaMemcpyHostToDevice); 42 | long* g_gate_idx; 43 | cudaMalloc(&g_gate_idx, sizeof(long) * batch_size * topk); 44 | cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size * topk, 45 | cudaMemcpyHostToDevice); 46 | long* g_pos; 47 | cudaMalloc(&g_pos, sizeof(long) * nlec); 48 | // cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * nlec, cudaMemcpyHostToDevice); 49 | 50 | auto smgr = getCudaStreamManager(0); 51 | fmoe_cuda_assign_pos_impl(g_lec, g_gate_idx, g_pos, batch_size * topk, 52 | topk, smgr); 53 | 54 | long* pos = new long[nlec]; 55 | cudaMemcpy(pos, g_pos, sizeof(long) * nlec, cudaMemcpyDeviceToHost); 56 | 57 | puts("pos"); 58 | for (int i = 0; i < nlec; ++i) { 59 | printf("%d ", pos[i]); 60 | } 61 | putchar(10); 62 | } 63 | 64 | -------------------------------------------------------------------------------- /cuda/tests/counting.cu: -------------------------------------------------------------------------------- 1 | #include "../local_exchange.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int main(int argc, char* args[]) { 9 | int batch_size = atoi(args[1]); 10 | int n_expert = atoi(args[2]); 11 | 12 | long* gate_idx = new long[batch_size]; 13 | long* n_gate_idx = new long[batch_size]; 14 | int* ref_lec = new int[n_expert]; 15 | memset(ref_lec, 0, sizeof(int) * n_expert); 16 | 17 | for (int i = 0; i < batch_size; ++i) { 18 | gate_idx[i] = rand() % (n_expert + 1) - 1; 19 | if (gate_idx[i] != -1) { 20 | ref_lec[gate_idx[i]] += 1; 21 | } 22 | } 23 | 24 | puts("ref lec"); 25 | for (int i = 0; i < n_expert; ++i) { 26 | printf("%d ", ref_lec[i]); 27 | } 28 | putchar(10); 29 | 30 | int* g_lec; 31 | cudaMalloc(&g_lec, sizeof(int) * n_expert); 32 | cudaMemset(g_lec, 0, sizeof(int) * n_expert); 33 | long* g_gate_idx; 34 | cudaMalloc(&g_gate_idx, sizeof(long) * batch_size); 35 | cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size, 36 | cudaMemcpyHostToDevice); 37 | 38 | auto smgr = getCudaStreamManager(0); 39 | fmoe_cuda_expert_count_impl(g_gate_idx, g_lec, batch_size, n_expert, smgr); 40 | 41 | int* lec = new int[n_expert]; 42 | cudaMemcpy(lec, g_lec, sizeof(int) * n_expert, cudaMemcpyDeviceToHost); 43 | 44 | puts("lec"); 45 | for 
(int i = 0; i < n_expert; ++i) { 46 | printf("%d ", lec[i]); 47 | } 48 | putchar(10); 49 | } 50 | 51 | -------------------------------------------------------------------------------- /cuda/tests/limit.cu: -------------------------------------------------------------------------------- 1 | #include "../balancing.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int main(int argc, char* args[]) { 9 | int n_worker = atoi(args[1]); 10 | int n_expert = atoi(args[2]); 11 | int cap_v = atoi(args[3]); 12 | int tot_expert = n_worker * n_expert; 13 | 14 | long* lec = new long[tot_expert]; 15 | for (int i = 0; i < tot_expert; ++i) { 16 | lec[i] = i; 17 | } 18 | long* g_lec; 19 | cudaMalloc(&g_lec, sizeof(long) * tot_expert); 20 | cudaMemcpy(g_lec, lec, sizeof(long) * tot_expert, cudaMemcpyHostToDevice); 21 | 22 | int* cap = new int[n_expert]; 23 | for (int i = 0; i < n_expert; ++i) { 24 | cap[i] = cap_v; 25 | } 26 | int* g_cap; 27 | cudaMalloc(&g_cap, sizeof(int) * n_expert); 28 | cudaMemcpy(g_cap, cap, sizeof(int) * n_expert, cudaMemcpyHostToDevice); 29 | 30 | long* eca = new long[tot_expert]; 31 | long* g_eca; 32 | cudaMalloc(&g_eca, sizeof(long) * tot_expert); 33 | 34 | auto smgr = getCudaStreamManager(0); 35 | fmoe_cuda_limit_by_capacity_impl(g_lec, g_cap, g_eca, n_expert, n_worker, smgr); 36 | 37 | cudaMemcpy(cap, g_cap, sizeof(int) * n_expert, cudaMemcpyDeviceToHost); 38 | cudaMemcpy(eca, g_eca, sizeof(long) * tot_expert, cudaMemcpyDeviceToHost); 39 | 40 | printf("%d\n", cap[0]); 41 | for (int i = 0; i < tot_expert; ++i) { 42 | printf("%ld %ld\n", lec[i], eca[i]); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /cuda/tests/prune_gate.cu: -------------------------------------------------------------------------------- 1 | #include "../balancing.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int main(int argc, char* args[]) { 9 | int n_worker = atoi(args[1]); 10 | int n_expert = atoi(args[2]); 11 | int batch_size = atoi(args[3]); 12 | int tot_expert = n_worker * n_expert; 13 | 14 | long* gate_idx = new long[batch_size]; 15 | long* n_gate_idx = new long[batch_size]; 16 | 17 | long* lec = new long[tot_expert]; 18 | memset(lec, 0, sizeof(long) * tot_expert); 19 | 20 | for (int i = 0; i < batch_size; ++i) { 21 | gate_idx[i] = rand() % tot_expert; 22 | ++lec[gate_idx[i]]; 23 | } 24 | for (int i = 0; i < tot_expert; ++i) { 25 | lec[i] >>= 1; 26 | } 27 | long* g_lec; 28 | cudaMalloc(&g_lec, sizeof(long) * tot_expert); 29 | cudaMemcpy(g_lec, lec, sizeof(long) * tot_expert, cudaMemcpyHostToDevice); 30 | 31 | int* g_new_lec; 32 | cudaMalloc(&g_new_lec, sizeof(int) * tot_expert); 33 | 34 | long* g_gate_idx; 35 | cudaMalloc(&g_gate_idx, sizeof(long) * batch_size); 36 | cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size, cudaMemcpyHostToDevice); 37 | 38 | auto smgr = getCudaStreamManager(0); 39 | fmoe_cuda_prune_gate_by_capacity_impl(g_gate_idx, g_lec, g_new_lec, 40 | batch_size, n_expert, n_worker, smgr); 41 | cudaMemcpy(n_gate_idx, g_gate_idx, sizeof(long) * batch_size, cudaMemcpyDeviceToHost); 42 | 43 | for (int i = 0; i < batch_size; ++i) { 44 | printf("%ld %ld (%d)\n", gate_idx[i], n_gate_idx[i], lec[gate_idx[i]]); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /cuda/utils/fmoe_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef FMOE_UTILS_H 2 | #define FMOE_UTILS_H 3 | 4 | #define 
CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1) 9 | 10 | #endif // FMOE_UTILS_H 11 | -------------------------------------------------------------------------------- /cuda/utils/timer.hh: -------------------------------------------------------------------------------- 1 | #ifndef TIMER_HH 2 | #define TIMER_HH 3 | 4 | /* 5 | * This part of code is not used. 6 | #include <chrono> 7 | 8 | inline double getDuration(std::chrono::time_point<std::chrono::system_clock> a, 9 | std::chrono::time_point<std::chrono::system_clock> b) { 10 | return std::chrono::duration<double>(b - a).count(); 11 | } 12 | 13 | #define timestamp(__var__) auto __var__ = std::chrono::system_clock::now(); 14 | */ 15 | 16 | #endif // TIMER_HH 17 | 18 | -------------------------------------------------------------------------------- /doc/fastermoe/README.md: -------------------------------------------------------------------------------- 1 | Boost the Performance by FasterMoE 2 | === 3 | 4 | A Chinese version is available in [this blog post](https://laekov.com.cn/view/181401#howto). 5 | 6 | There are three main optimizations in the PPoPP'22 paper _FasterMoE: Modeling 7 | and Optimizing Training of Large-scale Dynamic Pre-trained Models_. Thanks to 8 | the contributions of the authors of the paper, their optimizations are now 9 | integrated into FastMoE, and can be enabled via environment-variable 10 | switches. These optimizations can greatly increase the training efficiency of 11 | FastMoE. 12 | 13 | ## Smart Scheduling 14 | 15 | Recall that in an MoE layer, two `all-to-all`s are performed with the experts' 16 | computation in-between. In FasterMoE, the `all-to-all`s are broken down using 17 | a _group-wise exchange_ algorithm. An expert can then start 18 | its computation as soon as a part of its input, e.g. the tokens from one other worker, is 19 | ready. 20 | 21 | Its effectiveness is shown in the following timeline. `S` and `R` stand for 22 | the send and receive components of the `all-to-all`s, and `C` stands for computation of the 23 | expert. 24 | 25 | ![](smartsch.png) 26 | 27 | In FastMoE, to enable smart scheduling, set the environment variable ` 28 | FMOE_FASTER_SCHEDULE_ENABLE` to `1` or `ON`; it is off by default. 29 | 30 | Please note that there are a few constraints for smart scheduling in the 31 | current version of FastMoE. The input and output features have to be of 32 | the same length for the experts. This is because the developers of FasterMoE 33 | only implemented this case in their prototype, and they are looking for the 34 | community's efforts to have other cases supported. 35 | 36 | To fine-tune the performance of smart scheduling, the environment variable 37 | `FMOE_FASTER_GROUP_SIZE` stands for the size of worker groups in the 38 | _Group-wise Exchange_ algorithm. In other words, it is the granularity of the 39 | schedule. It should be set to a value that balances pipeline 40 | bubbles against inefficient, undersized computation granularity. 41 | 42 | ## Expert Shadowing 43 | 44 | According to observations when training real models, when no limitation is 45 | placed on expert selection, it follows a skewed distribution, which means a few 46 | experts are much more popular than others. This introduces a significant 47 | load-imbalance performance issue when using FastMoE's model parallel mode.
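To see this skew concretely, one can count how many tokens the gate assigns to each expert. Below is a minimal, self-contained sketch; the probability distribution in it is made up purely for illustration and is not produced by FastMoE:

```
import torch

# a made-up skewed distribution over 16 experts, purely for illustration
probs = torch.tensor([0.3, 0.2] + [0.5 / 14] * 14)
# top-1 expert index chosen by the gate for each of 4096 tokens
topk_idx = torch.multinomial(probs, 4096, replacement=True)
tokens_per_expert = torch.bincount(topk_idx, minlength=16)
print(tokens_per_expert)  # a few experts receive far more tokens than the others
```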
48 | 49 | The authors of FasterMoE propose that the parameters of the hot experts be 50 | broadcast to all workers, creating so-called shadows. With the shadows, 51 | computation of the hot experts can be performed locally on all workers, 52 | avoiding the bottleneck of sending so much workload to the workers that host 53 | the hot experts. Besides, a performance predictor, together with a shadow 54 | selection algorithm, is used to determine which experts should be shadowed before 55 | each iteration. 56 | 57 | In FastMoE, this feature is enabled by the environment variable 58 | `FMOE_FASTER_SHADOW_ENABLE`. For simplicity, this feature is only available 59 | when smart scheduling is enabled. Besides the constraints of smart scheduling, 60 | this feature requires the experts to be identical in structure, so that 61 | parameters can be copied between experts. 62 | 63 | A default shadow selection policy is located at 64 | `fmoe/fastermoe/shadow_policy.py`. If you want to alter the policy, edit that file 65 | and re-install FastMoE. For the default policy, we assume that the 66 | experts are two-layer MLPs. A few parameters of the policy can be specified by 67 | the following environment variables for better effectiveness of the shadowing 68 | mechanism. 69 | 70 | * `FMOE_FASTER_GLBPLC_NETBW` is the bandwidth of the interconnection between 71 | workers, measured in `GBps`. 72 | * `FMOE_FASTER_GLBPLC_GPUTP` is the GeMM throughput of the GPUs, measured in 73 | `FLOPs`, e.g. `13e12` for NVIDIA V100 PCIe GPUs using fp32. 74 | * `FMOE_FASTER_GLBPLC_ALPHA` is the ratio of the activation length in the 75 | middle of the MLP to the input and output feature length, commonly seen to be 76 | `2` or `4` in transformers. 77 | * `FMOE_FASTER_GLBPLC_DMODEL` is the feature length of the input and output of the 78 | experts. This parameter can be set automatically by FastMoE. 79 | 80 | ## Topology-aware Gate 81 | 82 | The two optimizations above do not change the behavior of the model, while this 83 | one does. To reduce network congestion when training in a distributed system 84 | with a hierarchical network topology, e.g. many GPUs on each of many nodes, the 85 | number of samples transmitted through the slower upper-level network is 86 | limited. The overflowing tokens select experts within the same lower-level 87 | network to reduce the communication overhead. 88 | 89 | The example topology-aware gate is implemented as `FasterGate` among FastMoE's 90 | gates. However, note that it may influence the accuracy of the model. For 91 | different training hardware, different topology-aware gates should be designed 92 | according to the specific case. 93 | 94 | The environment variable `FMOE_TOPO_GPUS_PER_NODE` represents the number of GPUs in 95 | each local network, e.g. each node, and `FMOE_TOPO_OUTGOING_FRACTION` controls 96 | the fraction of tokens that are allowed to be sent across the upper-level 97 | network. 98 | -------------------------------------------------------------------------------- /doc/fastermoe/smartsch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/doc/fastermoe/smartsch.png -------------------------------------------------------------------------------- /doc/installation-guide.md: -------------------------------------------------------------------------------- 1 | Step by step tutorial to install FastMoE on your local machine: 2 | 3 | 1.
First of all, you'll need to check your torch and NCCL versions, and make sure your CUDA version is compatible with the one torch was compiled with (in general, if you have the latest torch version it also works with the latest CUDA): 4 | ``` 5 | # in a terminal, run this command; the output should look like this: 6 | 7 | python -c 'import torch; print(torch.__version__); print(torch.cuda.nccl.version())' 8 | >>> 2.0.1+cu117 9 | >>> (2, 14, 3) # -> this means version 2.14.3 10 | 11 | # to check the CUDA version you can use either of these two options, with similar output; 12 | # the binary path (second option) might be needed for troubleshooting: 13 | 14 | nvcc --version 15 | >>> Cuda compilation tools, release 11.7, V11.7.99 16 | >>> Build cuda_11.7.r11.7/compiler.31442593_0 17 | 18 | which nvcc 19 | >>> /usr/local/cuda-11.7/bin/nvcc 20 | ``` 21 | 2. An extra NCCL developer package is needed to enable the distributed features of FastMoE; it can be downloaded from the following link: https://developer.nvidia.com/nccl/nccl-legacy-downloads. Make sure to follow these steps: 22 | ``` 23 | # following the previous example, I'll consider version 2.14.3 with a CUDA version <= the system CUDA version, on Ubuntu 20.04 24 | # the first command differs depending on the system and the version; just paste it from the site 25 | 26 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb 27 | sudo dpkg -i cuda-keyring_1.0-1_all.deb 28 | sudo apt-get update 29 | 30 | # don't forget to install the package; this command is easy to miss, as the site shows it at the end, 31 | # outside the code block, for each installation variant 32 | 33 | sudo apt install libnccl2=2.14.3-1+cuda11.7 libnccl-dev=2.14.3-1+cuda11.7 34 | ``` 35 | 36 | 3. Now you can clone the repository and enter the folder to launch the installation script as follows: 37 | ``` 38 | # clone the repo and move into the folder 39 | 40 | git clone https://github.com/laekov/fastmoe.git 41 | cd fastmoe 42 | 43 | # Option 1: disabling distributed features 44 | 45 | USE_NCCL=0 python setup.py install 46 | 47 | # Option 2: enabling distributed features 48 | 49 | python setup.py install 50 | ``` 51 | 52 | #### Troubleshooting 53 | 54 | If you get errors (warnings are OK) during compilation, make sure that the installer uses the correct flags; they can be seen in the error as `-I/path/to/xxx/bin` and `-L/path/to/xxx/lib`. These flags should point to the CUDA installation with which all the other packages (torch and NCCL) are compatible; if these paths are not correct, you'll have to tell the system explicitly which CUDA version you want to use.
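Before applying a fix, it can help to confirm which toolkit the build environment will actually pick up, compared with the CUDA version PyTorch was built against. A minimal check (the output shown in the comments is only an example):

```
python -c 'import torch, shutil; print("torch built with CUDA", torch.version.cuda); print("nvcc resolved from PATH:", shutil.which("nvcc"))'
# >>> torch built with CUDA 11.7
# >>> nvcc resolved from PATH: /usr/local/cuda-11.7/bin/nvcc
```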
Simple solutions could be the following:
55 | ```
56 | # (suggested) export the correct paths before compiling
57 | 
58 | export PATH="/usr/local/cuda-11.7/bin:$PATH"
59 | export LD_LIBRARY_PATH="/usr/local/cuda-11.7/lib:$LD_LIBRARY_PATH"
60 | python setup.py install
61 | 
62 | # optionally, add these lines to your ~/.bashrc so you don't have to export them again in the future
63 | 
64 | nano ~/.bashrc
65 | export PATH="/usr/local/cuda-11.7/bin:$PATH"
66 | export LD_LIBRARY_PATH="/usr/local/cuda-11.7/lib:$LD_LIBRARY_PATH"
67 | source ~/.bashrc
68 | ```
69 | -------------------------------------------------------------------------------- /doc/logo/rect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/doc/logo/rect.png -------------------------------------------------------------------------------- /doc/logo/square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/doc/logo/square.png -------------------------------------------------------------------------------- /doc/parallelism/README.md: -------------------------------------------------------------------------------- 1 | Multi-Dimensional Parallelism Supported by FastMoE
2 | ===
3 | 
4 | _这篇文档懒得写中文版了. 在获得来自社区的贡献前, 请自行谷歌翻译._
5 | 
6 | FastMoE now supports almost every popular way to train models in parallel, and any combination of them.
7 | The figure below shows all possible groups of processes that a process may be involved in.
8 | Users can enable them by simply assigning communication groups in either FastMoE or the external codebase that uses FastMoE.
9 | 
10 | ![](parallelism.png)
11 | 
12 | #### Data Parallel
13 | 
14 | In a group of data-parallel processes, models, including the experts, are replicated across the processes.
15 | To have the experts replicated, first pass `expert_dp_comm="dp"` to the `mark_parallel_comm` function of an `FMoE` instance.
16 | (The string `"dp"` can be replaced by another name if you wish).
17 | Then, wrap the MoE module with `fmoe.distributed.DistributedGroupedDataParallel`,
18 | and set `dp_group` in the constructor to the PyTorch process group over which you wish to perform data parallelism.
19 | By default, the parameters are initially synchronized, unless disabled by `need_sync=False`.
20 | Run `model.allreduce_params` every iteration after backward propagation.
21 | 
22 | ![](fastmoe_data_parallel.png)
23 | 
24 | #### Model Parallel
25 | 
26 | In typical model parallelism (sometimes called tensor model parallelism), every single expert is split up.
27 | FastMoE requires the external codebase to implement it by properly splitting the expert module that is provided to FastMoE.
28 | An official example using Megatron-LM can be seen in our adapter.
29 | The `hidden_hidden_size` of FastMoE's transformer module is divided by `k`, which denotes the number of model-parallel processes.
30 | In this way, each expert is split into `k` pieces.
31 | Then, an `all-reduce` is performed over the feature matrix externally in the adapter, so that the output of the experts is merged.
32 | 
33 | #### Expert Parallel (MoE Group and Slice Group)
34 | 
35 | In a group of expert-parallel processes, each process maintains different experts.
36 | Together, the processes in an MoE group contain all experts, and in `moe_group`, the input feature maps on the processes come from different samples.
37 | FastMoE performs an `all-to-all` to exchange them, i.e.
sending each feature vector to the processes that contain its selected experts.
38 | 
39 | ![](fastmoe_expert_parallel.png)
40 | 
41 | `slice_group` is a way to adapt typical model parallelism to expert parallelism.
42 | It assumes that the processes in the group have replicated input feature vectors.
43 | So, each process selects part of the feature vectors (a slice) as input to the `moe_group`,
44 | and performs an `all-gather` after the expert-parallel NN operations to produce replicated output.
45 | 
46 | #### Pipeline Parallel
47 | 
48 | An MoE layer can be part of any pipeline stage.
49 | The external codebase shall handle the communication across stages.
50 | Notice that the `gate` module is replicated across all the processes of the above three ways of intra-layer parallelism.
51 | So, for inter-layer parallelism, users should specify `gate_group` in `DistributedGroupedDataParallel` as all processes in the same stage.
52 | 
53 | #### Hybrid Parallel
54 | 
55 | Obviously, any combination of the above four ways of parallel training can be enabled by specifying proper communication groups for `FMoE` and `DistributedGroupedDataParallel`.
56 | Refer to our [ATC'23 paper](https://www.usenix.org/conference/atc23/presentation/zhai) for studies on the optimal selection of hybrid parallelism.
57 | -------------------------------------------------------------------------------- /doc/parallelism/fastmoe_data_parallel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/doc/parallelism/fastmoe_data_parallel.png -------------------------------------------------------------------------------- /doc/parallelism/fastmoe_expert_parallel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/doc/parallelism/fastmoe_expert_parallel.png -------------------------------------------------------------------------------- /doc/parallelism/parallelism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/doc/parallelism/parallelism.png -------------------------------------------------------------------------------- /doc/readme-cn.md: -------------------------------------------------------------------------------- 1 | FastMoE 系统
2 | ===
3 | 
4 | [版本更新记录](release-note.md)
5 | | [Slack 讨论组邀请链接](https://join.slack.com/t/fastmoe/shared_invite/zt-mz0ai6ol-ggov75D62YsgHfzShw8KYw)
6 | 
7 | ## 简介
8 | 
9 | FastMoE 是一个易用且高效的基于 PyTorch 的 MoE 模型训练系统.
10 | 
11 | ## 安装
12 | 
13 | ### 依赖
14 | 
15 | 启用了 CUDA 的 PyTorch 是必要的. 当前版本的 FastMoE 在 PyTorch v1.10.0 和 CUDA
16 | 11 的平台上经过了测试. 本系统从设计上也支持更旧或更新的 PyTorch 版本.
17 | 
18 | 已知最老的支持的版本是 PyTorch `1.7.0` 和 CUDA `10`,
19 | 但已知某些老版本可能需要修改 FastMoE 的代码以实现支持.
20 | 
21 | 如果需要使能 FastMoE 模型并行特性, 那么支持点对点通信的 NCCL 库 (即不旧于
22 | `2.7.5` 版本) 也是必需的.
23 | 
24 | ### 安装
25 | 
26 | FastMoE 包含一些定制的 PyTorch 算子, 包含一些 C 的组件. 用 `python setup.py install`
27 | 来简单地安装 FastMoE.
28 | 
29 | FastMoE 分布式模型并行特性默认是被启用的. 如果它需要被禁用,
30 | 则需要在运行上述命令时加入环境变量 `USE_NCCL=0`.
31 | 
32 | 注意, 由于 PyTorch 框架通常仅集成了 NCCL 的运行时组件, 额外的 NCCL
33 | 开发包需要被安装在编译环境中, 而且它的版本需要与 PyTorch 的版本相对应. 推荐使用
34 | [PyTorch 官方 Docker 镜像](https://hub.docker.com/r/pytorch/pytorch),
35 | 因为那里的环境较为干净. 如果您希望手工配置环境, 可以在 [NCCL
36 | 全部版本的下载链接](https://developer.nvidia.com/nccl/nccl-legacy-downloads)
37 | 下载合适版本的 NCCL 开发包. 
38 | 39 | ## 使用 40 | 41 | ### 将一个 Transformer 模型 FMoE 化 42 | 43 | Transformer 是当前最流行的可被 MoE 化的模型. FastMoE 可以一键将一个普通的 44 | Transformer 模型变为一个 MoE 的模型. 其使用方法如下. 45 | 46 | 例如在 [Megatron-LM](https://github.com/nvidia/megatron-lm) 中, 47 | 添加如下的代码即可将 Transformer 中的每个 MLP 层变为多个 MLP 层构成的 MoE 网络. 48 | 49 | ```python 50 | model = ... 51 | 52 | from fmoe.megatron import fmoefy 53 | model = fmoefy(model, fmoe_num_experts=) 54 | 55 | train(model, ...) 56 | ``` 57 | 58 | 一个更详细的在 Megatron-LM 中使用 `fmoefy` 函数的样例参见[此处](../examples/megatron). 59 | 60 | ### 将 FastMoE 作为一个网络模块使用 61 | 62 | 一个使用 FastMoE 的 Transformer 模型见[这个示例](../examples/transformer-xl). 63 | 最简单的使用方式是使用 `FMoE` 层来代替 `MLP` 层. 64 | 65 | ### 分布式地使用 FastMoE 66 | 67 | FastMoE 支持并行方式. 详见[并行方式详细说明](doc/parallelism). 68 | 以下简单介绍两种最容易使用的并行方式. 69 | 70 | #### 数据并行. 71 | 72 | 在 FastMoE 的数据并行模式下, 73 | 门网络(gate)和专家网络都被复制地放置在各个运算单元上. 74 | 下图展示了一个有三个专家的两路数据并行MoE模型进行前向计算的方式. 75 | 76 |

77 | 78 |

79 | 80 | 对于数据并行, 额外的代码是不需要的. FastMoE 与 PyTorch 的 `DataParallel` 和 81 | `DistributedDataParallel` 模块都可以无缝对接. 该方式唯一的问题是, 82 | 专家的数量受到单个计算单元(如GPU)的内存大小限制. 83 | 84 | #### 专家并行 (也曾被叫作模型并行) 85 | 86 | 在 FastMoE 的专家并行模式中, 门网络依然是复制地被放置在每个计算单元上的, 87 | 但是专家网络被独立地分别放置在各个计算单元上. 因此, 通过引入额外的通信操作, 88 | FastMoE 可以允许更多的专家网络们同时被训练, 89 | 而其数量限制与计算单元的数量是正相关的. 90 | 91 | 下图展示了一个有六个专家网络的模型被两路专家并行地训练. 92 | 注意专家1-3被放置在第一个计算单元上, 而专家4-6被放置在第二个计算单元上. 93 | 94 |

95 | 96 |

97 | 98 | FastMoE 的专家并行模式需要专门的并行策略, 而 PyTorch 和 Megatron-LM 99 | 都不支持这样的策略 (在我们创建 FastMoE 时). 因此, 需要使用 100 | `fmoe.DistributedGroupedDataParallel` 101 | 模块来代替 PyTorch 的 DDP 模块. 102 | 103 | ### 如何训练得更快 104 | 105 | 在 PPoPP'22 会议上有一篇论文: _FasterMoE: modeling and optimizing training of 106 | large-scale dynamic pre-trained models_. 我们将文中的技术集成到了 FastMoE 系统中, 107 | 从而提升其模型并行的效率. 108 | 109 | 这些新特性被命名为 **Faster Performance Features**, 并通过一些环境变量来控制是否 110 | 启用它们. 详见[这篇单独的文档](doc/fastermoe). 111 | 112 | ## 答疑 / 讨论 113 | 114 | 如果您在使用 FastMoE 的过程中有任何疑问, 或您有兴趣参与 FastMoE 的相关工作, 115 | 欢迎加入我们的 [Slack 讨论组](https://join.slack.com/t/fastmoe/shared_invite/zt-mz0ai6ol-ggov75D62YsgHfzShw8KYw). 116 | -------------------------------------------------------------------------------- /doc/release-note.md: -------------------------------------------------------------------------------- 1 | ## v1.1.0 2 | 3 | ### Performance 4 | 5 | * Smart schedule of FasterMoE is updated with correct stream management, and becomes faster. 6 | 7 | ### Testing 8 | 9 | * All unit tests are checked and they run correctly now. 10 | 11 | ### Adaption 12 | 13 | * Megatron-LM 3.2 supported. 14 | 15 | ### Documentation 16 | 17 | * README is updated with some bugs fixed. 18 | * A detailed [document for process groups](/doc/parallelism). 19 | 20 | 21 | ## v1.0.1 22 | 23 | ### Compatibility 24 | 25 | * PyTorch 2.0 supported. 26 | * Megatron-LM 2.5 supported. 27 | 28 | ### Documentation 29 | 30 | * A detailed [installation-guide](/doc/installation-guide.md) thanks to @santurini 31 | 32 | ### Performance related 33 | 34 | * Generalize FasterMoE's schedule to `n_expert > 1` and more bug fixes. 35 | * Synchronization reduction thanks to @Fragile-azalea 36 | 37 | ## v1.0.0 38 | 39 | ### FasterMoE 40 | 41 | * The new performance boosting features in the PPoPP'22 paper FasterMoE, detailed in the document. 42 | * Expert Shadowing. 43 | * Smart Scheduling. 44 | * Topology-aware gate. 45 | 46 | ### Bug fixes 47 | 48 | * Transformer-XL examples. 49 | * Compatibility to PyTorch versions. 50 | * Megatron-LM documents. 51 | * GShardGate. 52 | 53 | ## v0.3.0 54 | 55 | ### FMoE core 56 | 57 | * Previous `mp_group` is renamed to `slice_group`, indicating that all workers in the group receive the same input batch, and process a slice of the input. `mp_group` will be deprecated in our next release. 58 | * ROCm supported. 59 | * `FMoELinear` is moved to a stand-alone file. 60 | 61 | ### Groupped data parallel 62 | 63 | * Support any group name by their relative tag name. 64 | 65 | ### Load balancing 66 | 67 | * A brand new balancing strategy - SWIPE. Contributed by authors of a (currently unpublished) paper. 68 | * A property `has_loss` is added to each gate, in order to identify whether balance loss should be collected. 69 | 70 | ### Megatron-LM support 71 | 72 | * Experts are partitioned by tensor model parallelism in `mp_group`, instead of expert parallelism. 73 | * Support arbitrary customized gate in `MegatronMLP`. 74 | * Move the patches to a stand-alone file. 75 | 76 | ### Tests 77 | 78 | * Move util functions into `test_ddp.py`. 79 | 80 | ## v0.2.1 81 | 82 | ## Load balancing 83 | 84 | * Fix gradient for balance loss. 85 | 86 | ### Misc 87 | 88 | * Typos. 89 | * Update benchmark interface. 90 | * Remove some redundant code for performance improvement. 91 | * Enable `USE_NCCL` by default. 92 | * Compatibility for PyTorch `<1.8.0` and `>=1.8.0`. 93 | 94 | ### Megatron adaption 95 | 96 | * Patch for numerical correctness of gradient clipping. 
97 | * Support to pipeline parallelism. 98 | 99 | ## v0.2.0 100 | 101 | ## Load balancing 102 | 103 | * A brand new gate module with capacity-related utilities. 104 | * GShard's and Switch Transformer's balance strategies are implemented as integrated gates. 105 | * Balance loss is enabled. 106 | * Balance monitor is provided. 107 | 108 | ## Checkpointing 109 | 110 | * MoE models can be loaded and saved by fmoe's checkpointing module. 111 | 112 | ## Performance 113 | 114 | * FP16 training performance is improved. 115 | 116 | ## Misc 117 | 118 | * CUDA code directory is reconstructed. 119 | * More tests are added. 120 | 121 | ## v0.1.2 122 | 123 | ### Compilation 124 | 125 | - Remove dependency on the CUDA examples repository. 126 | 127 | ### Distributed 128 | 129 | - Fix a bug related to PyTorch v1.8.0. FastMoE can now operate on multiple GPUs 130 | on multiple nodes with PyTorch v1.8.0. 131 | 132 | ### Misc 133 | 134 | - Fix tons of typos. 135 | - Format the code. 136 | 137 | ## v0.1.1 138 | 139 | ### Distributed 140 | 141 | - Broadcast data-parallel parameters before training. 142 | 143 | ### Megatron adaption 144 | 145 | - Initialize `FMoELinear` parameters using different seed in model parallel even using the same random seed in megatron. 146 | - Use proper comm for mp and dp. 147 | 148 | ### Transformer-XL example 149 | 150 | - Improve scripts. 151 | 152 | ### Misc 153 | 154 | - Logo and slack workspace link. 155 | - Document in Chinese. 156 | - Figures to explain how FastMoE works. 157 | 158 | ## v0.1.0 159 | 160 | ### Functions 161 | 162 | - A model-injection-style easy-to-use user interface for Megatron-LM. 163 | - Support both data parallel and model parallel, and a hybrid of the two, 164 | - Provide a new customized DDP module to synchronize in different comm groups. 165 | - Support to customized `nn.Module` as an expert. 166 | 167 | ### Document and infrastructure 168 | 169 | - Use PyTest. 170 | - Setup PyLint. 171 | - Installation and usage guide. 172 | - Explanation of functions and code structure in code. 173 | 174 | ### Performance 175 | 176 | - A benchmark to compare FastMoE and old PyTorch impl. 177 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | transformer-xl/data 2 | transformer-xl/LM-TFM-enwik8 3 | data 4 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | Examples of FastMoE 2 | === 3 | 4 | As FastMoE supports both stand-alone training (or built-in data parallelism 5 | supported by PyTorch), and expert parallelism implemented by customized 6 | operators, we present two examples to show the usage of them separately. 7 | 8 | ### Transformer-XL 9 | 10 | This example contains a single-process version of transformer training code 11 | that uses PyTorch's DataParallel module to utilize multiple GPUs within one 12 | node. In this example, FastMoE works as a simple local module without involving 13 | any means of parallelism. 14 | 15 | ### Megatron-LM 16 | 17 | [Megatron-LM](https://github.com/nvidia/megatron-lm) is a transformer framework 18 | developed by NVIDIA. It supports diverse parallelisms, including data, model, 19 | and pipeline. It is scalable to up to thousands of GPUs, with one process 20 | binded to each GPU. 
21 | 
22 | FastMoE works with any combination of the parallelisms provided by Megatron-LM.
23 | In the example, the dimension of data parallelism is used as the communication
24 | group for expert parallelism, so that the GPU memory consumption is kept
25 | identical to the original non-MoE model while the model size is enlarged.
26 | -------------------------------------------------------------------------------- /examples/megatron/README.md: -------------------------------------------------------------------------------- 1 | FastMoE works with different versions of
2 | [Megatron-LM](https://github.com/nvidia/megatron-lm).
3 | See `fmoe/megatron/utils.py` for the arguments of FastMoE.
4 | 
5 | An example patch is provided for the `v2.2` release.
6 | The patch can be directly applied to add FastMoE support if you are using
7 | Megatron-LM v2.2.
8 | Otherwise, you may need to manually enable FastMoE in your codebase.
9 | The patch includes the following modifications.
10 | 
11 | ### Add arguments to Megatron's argparser
12 | 
13 | In `megatron/arguments.py`, add `_add_fmoe_args` to the parser.
14 | 
15 | ### Patch checkpoint
16 | 
17 | In `megatron/training.py`, replace `load_checkpoint` and `save_checkpoint` by
18 | the functions with the same names in `fmoe.megatron.checkpointing`.
19 | 
20 | ### Building the model in FastMoE style
21 | 
22 | In `megatron/training.py`, the `fmoe.megatron.fmoefy` function is used as an
23 | entry point to introduce, with a single call, FastMoE layers that replace the MLP layers in the
24 | transformer language models.
25 | 
26 | ```python
27 | from fmoe.megatron import fmoefy
28 | model = fmoefy(model, fmoe_num_experts=4)
29 | ```
30 | 
31 | Note that the `fmoefy` function currently only takes a standard Megatron-LM
32 | top-level raw model as input, i.e. the MLP layers should be available at
33 | `model.language_model.transformer.layers[i].mlp`.
34 | 
35 | ### Using FastMoE's model parallelization
36 | 
37 | In `megatron/training.py`, the `LocalDDP` module is replaced by the one in
38 | `fmoe.megatron` to enable the sophisticated data parallel strategies that can
39 | parallelize the experts across both the data parallel group and the (tensor)
40 | model parallel group.
41 | 
42 | ```python
43 | # from megatron.model import DistributedDataParallel as LocalDDP
44 | from fmoe.megatron import DistributedDataParallel as LocalDDP
45 | ```
46 | 
47 | ### Fix gradient clipping
48 | 
49 | Megatron-LM uses gradient normalization, which is incompatible with FastMoE.
50 | An incorrect norm of the gradients leads to inconsistent parameter updates.
51 | Apply `clip-grad-v2.2.patch` to fix the issue.
52 | 
53 | Note that only the 2-norm is implemented in the patch. If other norm methods are
54 | used, remember to implement them accordingly.
55 | 
56 | ### Train as usual
57 | 
58 | Start training with FastMoE by using the scripts provided by Megatron-LM.
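For reference, the snippet below pieces the `fmoefy` and `LocalDDP` changes together in one place. It is only a sketch: `model_provider` stands for your Megatron-LM model builder, and any extra constructor arguments that your Megatron version passes to `LocalDDP` are omitted.

```python
# Sketch of the code-level changes above; not a drop-in replacement for training.py.
from fmoe.megatron import fmoefy
from fmoe.megatron import DistributedDataParallel as LocalDDP

model = model_provider()                   # placeholder for your Megatron-LM model builder
model = fmoefy(model, fmoe_num_experts=4)  # replace each MLP with a 4-expert MoE layer
model = LocalDDP(model)                    # FastMoE-aware DDP; add your usual DDP arguments
```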
59 | -------------------------------------------------------------------------------- /examples/megatron/clip-grad-v2.2.patch: -------------------------------------------------------------------------------- 1 | diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py 2 | index e8d0d02..91c663e 100644 3 | --- a/megatron/optimizer/clip_grads.py 4 | +++ b/megatron/optimizer/clip_grads.py 5 | @@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): 6 | # - should not be a replica due to tensor model parallelism 7 | grads = [] 8 | grads_for_norm = [] 9 | + grads_in_moe = [] 10 | for param in parameters: 11 | grad_not_none = param.grad is not None 12 | is_not_shared = not hasattr(param, 'shared') or not param.shared 13 | @@ -63,7 +64,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): 14 | assert param.grad.type() == 'torch.cuda.FloatTensor' 15 | grads.append(grad) 16 | if grad_not_none and is_not_shared and is_not_tp_duplicate: 17 | - grads_for_norm.append(grad) 18 | + if hasattr(param, 'dp_comm') and param.dp_comm in ('none'): 19 | + grads_in_moe.append(grad) 20 | + else: 21 | + grads_for_norm.append(grad) 22 | 23 | # Norm parameters. 24 | max_norm = float(max_norm) 25 | @@ -72,6 +76,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): 26 | 27 | # Calculate norm. 28 | if norm_type == inf: 29 | + # TODO: moe 30 | total_norm = max(grad.abs().max() for grad in grads_for_norm) 31 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 32 | # Take max across all model-parallel GPUs. 33 | @@ -96,7 +101,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): 34 | # we need the pow(norm-type). 35 | total_norm = grad_norm ** norm_type 36 | 37 | + if grads_in_moe: 38 | + grad_norm, _ = multi_tensor_applier( 39 | + amp_C.multi_tensor_l2norm, 40 | + dummy_overflow_buf, 41 | + [grads_in_moe], 42 | + False # no per-parameter norm 43 | + ) 44 | + grad_norm = grad_norm ** norm_type 45 | + torch.distributed.all_reduce(grad_norm, 46 | + group=mpu.get_model_parallel_group()) 47 | + total_norm += grad_norm 48 | + 49 | else: 50 | + # TODO: moe 51 | for grad in grads_for_norm: 52 | grad_norm = torch.norm(grad, norm_type) 53 | total_norm += grad_norm ** norm_type -------------------------------------------------------------------------------- /examples/megatron/fmoefy-v2.2.patch: -------------------------------------------------------------------------------- 1 | diff --git a/megatron/arguments.py b/megatron/arguments.py 2 | index 26a7cec..0acfb22 100644 3 | --- a/megatron/arguments.py 4 | +++ b/megatron/arguments.py 5 | @@ -21,6 +21,8 @@ import os 6 | import torch 7 | from megatron import fused_kernels 8 | 9 | +from fmoe.megatron import add_fmoe_args as _add_fmoe_args 10 | + 11 | def parse_args(extra_args_provider=None, defaults={}, 12 | ignore_unknown_args=False): 13 | """Parse all arguments.""" 14 | @@ -40,6 +42,7 @@ def parse_args(extra_args_provider=None, defaults={}, 15 | parser = _add_data_args(parser) 16 | parser = _add_autoresume_args(parser) 17 | parser = _add_realm_args(parser) 18 | + parser = _add_fmoe_args(parser) 19 | 20 | # Custom arguments. 
21 | if extra_args_provider is not None: 22 | diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py 23 | index 9d42260..2583db2 100644 24 | --- a/megatron/optimizer/optimizer.py 25 | +++ b/megatron/optimizer/optimizer.py 26 | @@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer): 27 | param) 28 | if hasattr(param, 'shared'): 29 | main_param.shared = param.shared 30 | + if hasattr(param, 'dp_comm'): 31 | + main_param.dp_comm = param.dp_comm 32 | # Replace the optimizer params with the new fp32 copy. 33 | param_group['params'][i] = main_param 34 | fp32_from_fp16_params_this_group.append(main_param) 35 | diff --git a/megatron/training.py b/megatron/training.py 36 | index 56d1c7c..f825bf3 100644 37 | --- a/megatron/training.py 38 | +++ b/megatron/training.py 39 | @@ -35,20 +35,24 @@ from megatron import update_num_microbatches 40 | from megatron import mpu 41 | from megatron import print_rank_0 42 | from megatron import print_rank_last 43 | -from megatron.checkpointing import load_checkpoint 44 | -from megatron.checkpointing import save_checkpoint 45 | +# from megatron.checkpointing import load_checkpoint 46 | +from fmoe.megatron.checkpoint import load_checkpoint 47 | +# from megatron.checkpointing import save_checkpoint 48 | +from fmoe.megatron.checkpoint import save_checkpoint 49 | from megatron.model import FP16Module 50 | from megatron.optimizer import get_megatron_optimizer 51 | 52 | from megatron.initialize import initialize_megatron 53 | from megatron.initialize import write_args_to_tensorboard 54 | from megatron.learning_rates import AnnealingLR 55 | -from megatron.model import DistributedDataParallel as LocalDDP 56 | +# from megatron.model import DistributedDataParallel as LocalDDP 57 | from megatron.model.realm_model import ICTBertModel 58 | from megatron.utils import check_adlr_autoresume_termination 59 | from megatron.data.data_loaders import build_pretraining_data_loader 60 | from megatron.utils import report_memory 61 | 62 | +from fmoe.megatron import DistributedDataParallel as LocalDDP 63 | +from fmoe.megatron import add_balance_log 64 | 65 | def print_datetime(string): 66 | """Note that this call will sync across all ranks.""" 67 | @@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider, 68 | args = get_args() 69 | timers = get_timers() 70 | 71 | + # Initialize FastMoE 72 | + if args.fmoefy: 73 | + from fmoe.megatron import patch_forward_step, patch_model_provider 74 | + 75 | + forward_step_func = patch_forward_step(forward_step_func) 76 | + model_provider = patch_model_provider(model_provider) 77 | + 78 | # Model, optimizer, and learning rate. 
79 | timers('model and optimizer').start() 80 | model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) 81 | @@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator, 82 | 83 | 84 | def training_log(loss_dict, total_loss_dict, learning_rate, iteration, 85 | - loss_scale, report_memory_flag, skipped_iter): 86 | + loss_scale, report_memory_flag, skipped_iter, model): 87 | """Log training information such as losses, timing, ....""" 88 | args = get_args() 89 | timers = get_timers() 90 | @@ -725,6 +736,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, 91 | args.consumed_train_samples) 92 | timers.write(timers_to_log, writer, iteration, 93 | normalizer=total_iterations) 94 | + if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive': 95 | + add_balance_log(model, writer, iteration) 96 | 97 | if iteration % args.log_interval == 0: 98 | elapsed_time = timers('interval time').elapsed() 99 | @@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, 100 | report_memory_flag = training_log(loss_dict, total_loss_dict, 101 | optimizer.param_groups[0]['lr'], 102 | iteration, loss_scale, 103 | - report_memory_flag, skipped_iter) 104 | + report_memory_flag, skipped_iter, model) 105 | 106 | # Autoresume 107 | if args.adlr_autoresume and \ 108 | -------------------------------------------------------------------------------- /examples/transformer-xl/README.md: -------------------------------------------------------------------------------- 1 | This directory contains an example based on Zihang Dai, et.al's open-source 2 | transformer [implementation](https://github.com/kimiyoung/transformer-xl) to 3 | demostrate the usage of the usage of Fast MoE's layers. 4 | 5 | The code is released with Apache-2.0 license. Here, only the pytorch part of the 6 | code is used, with modification in the `mem_transformer.py` file to enable MoE 7 | training. 8 | 9 | ## Introduction 10 | 11 | This directory contains our pytorch implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our pytorch codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts: 12 | - `*large.sh` are for the SoTA setting with large models which might not be directly runnable on a local GPU machine. 13 | - `*base.sh` are for the base models which can be run on a few GPUs. 14 | 15 | The pytorch implementation produces similar results to the TF codebase under the same settings in our preliminary experiments. 
16 | 
17 | 
18 | ## Prerequisite
19 | 
20 | - Pytorch 0.4: `conda install pytorch torchvision -c pytorch`
21 | 
22 | 
23 | ## Data Preparation
24 | 
25 | `bash getdata.sh`
26 | 
27 | ## Training and Evaluation
28 | 
29 | #### Replicate the "bpc = 1.06" result on `enwik8` with a 12-layer Transformer-XL
30 | 
31 | - Make sure the machine has **4 GPUs**, each with **at least 11G memory**
32 | 
33 | - Training
34 | 
35 | `bash run_enwik8_base.sh train --work_dir PATH_TO_WORK_DIR`
36 | 
37 | - Evaluation
38 | 
39 | `bash run_enwik8_base.sh eval --work_dir PATH_TO_WORK_DIR`
40 | 
41 | 
42 | 
43 | #### Replicate the "PPL = 24.03" result on `wikitext-103` with Transformer-XL
44 | 
45 | - Make sure the machine has **4 GPUs**, each with **at least 11G memory**
46 | 
47 | - Training
48 | 
49 | `bash run_wt103_base.sh train --work_dir PATH_TO_WORK_DIR`
50 | 
51 | - Evaluation
52 | 
53 | `bash run_wt103_base.sh eval --work_dir PATH_TO_WORK_DIR`
54 | 
55 | 
56 | 
57 | #### Other options:
58 | 
59 | - `--batch_chunk`: this option allows one to trade speed for memory. For `batch_chunk > 1`, the program will split each training batch into `batch_chunk` sub-batches and perform forward and backward on each sub-batch sequentially, with the gradient accumulated and divided by `batch_chunk`. Hence, the memory usage will be proportionally lower, while the computation time will be correspondingly higher (see the sketch after this list).
60 | - `--div_val`: when using adaptive softmax and embedding, the embedding dimension is divided by `div_val` from bin $i$ to bin $i+1$. This saves both GPU memory and the parameter budget.
61 | - `--fp16` and `--dynamic-loss-scale`: Run in pseudo-fp16 mode (fp16 storage, fp32 math) with dynamic loss scaling.
62 | - Note: to explore the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/).
63 | - To see performance without the recurrence mechanism, simply use `mem_len=0` in all your scripts.
64 | - To see the performance of a standard Transformer without relative positional encodings or recurrence mechanisms, use `attn_type=2` and `mem_len=0`.
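To make the `--batch_chunk` behavior concrete, the snippet below is a minimal sketch of the accumulation pattern in plain PyTorch. It is not the actual loop from `train.py`: `model`, `data`, and `target` are placeholders, and the model is assumed to return a per-token loss tensor.

```python
# Sketch of the gradient accumulation performed by --batch_chunk: forward/backward
# each sub-batch sequentially and divide by batch_chunk so that the accumulated
# gradient matches a single full-batch step. Assumes data/target are [seq_len, batch] tensors.
def train_step(model, optimizer, data, target, batch_chunk):
    optimizer.zero_grad()
    data_chunks = data.chunk(batch_chunk, dim=1)      # split along the batch dimension
    target_chunks = target.chunk(batch_chunk, dim=1)
    for data_i, target_i in zip(data_chunks, target_chunks):
        loss = model(data_i, target_i).mean() / batch_chunk
        loss.backward()                               # gradients accumulate across sub-batches
    optimizer.step()
```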
65 | 66 | 67 | #### Other datasets: 68 | 69 | - `Text8` character-level language modeling: check out `run_text8_base.sh` 70 | - `lm1b` word-level language modeling: check out `run_lm1b_base.sh` 71 | -------------------------------------------------------------------------------- /examples/transformer-xl/eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | 7 | import torch 8 | 9 | from data_utils import get_lm_corpus 10 | from mem_transformer import MemTransformerLM 11 | from utils.exp_utils import get_logger 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 14 | parser.add_argument('--data', type=str, default='../data/wikitext-103', 15 | help='location of the data corpus') 16 | parser.add_argument('--dataset', type=str, default='wt103', 17 | choices=['wt103', 'lm1b', 'enwik8', 'text8'], 18 | help='dataset name') 19 | parser.add_argument('--split', type=str, default='all', 20 | choices=['all', 'valid', 'test'], 21 | help='which split to evaluate') 22 | parser.add_argument('--batch_size', type=int, default=10, 23 | help='batch size') 24 | parser.add_argument('--tgt_len', type=int, default=5, 25 | help='number of tokens to predict') 26 | parser.add_argument('--ext_len', type=int, default=0, 27 | help='length of the extended context') 28 | parser.add_argument('--mem_len', type=int, default=0, 29 | help='length of the retained previous heads') 30 | parser.add_argument('--clamp_len', type=int, default=-1, 31 | help='max positional embedding index') 32 | parser.add_argument('--cuda', action='store_true', 33 | help='use CUDA') 34 | parser.add_argument('--work_dir', type=str, required=True, 35 | help='path to the work_dir') 36 | parser.add_argument('--no_log', action='store_true', 37 | help='do not log the eval result') 38 | parser.add_argument('--same_length', action='store_true', 39 | help='set same length attention with masking') 40 | args = parser.parse_args() 41 | assert args.ext_len >= 0, 'extended context length must be non-negative' 42 | 43 | device = torch.device("cuda" if args.cuda else "cpu") 44 | 45 | # Get logger 46 | logging = get_logger(os.path.join(args.work_dir, 'log.txt'), 47 | log_=not args.no_log) 48 | 49 | # Load dataset 50 | corpus = get_lm_corpus(args.data, args.dataset) 51 | ntokens = len(corpus.vocab) 52 | 53 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, 54 | device=device, ext_len=args.ext_len) 55 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, 56 | device=device, ext_len=args.ext_len) 57 | 58 | # Load the best saved model. 59 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 60 | model = torch.load(f) 61 | model.backward_compatible() 62 | model = model.to(device) 63 | 64 | logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( 65 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) 66 | 67 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 68 | if args.clamp_len > 0: 69 | model.clamp_len = args.clamp_len 70 | if args.same_length: 71 | model.same_length = True 72 | 73 | ############################################################################### 74 | # Evaluation code 75 | ############################################################################### 76 | def evaluate(eval_iter): 77 | # Turn on evaluation mode which disables dropout. 
78 | model.eval() 79 | total_len, total_loss = 0, 0. 80 | start_time = time.time() 81 | with torch.no_grad(): 82 | mems = tuple() 83 | for idx, (data, target, seq_len) in enumerate(eval_iter): 84 | ret = model(data, target, *mems) 85 | loss, mems = ret[0], ret[1:] 86 | loss = loss.mean() 87 | total_loss += seq_len * loss.item() 88 | total_len += seq_len 89 | total_time = time.time() - start_time 90 | logging('Time : {:.2f}s, {:.2f}ms/segment'.format( 91 | total_time, 1000 * total_time / (idx+1))) 92 | return total_loss / total_len 93 | 94 | # Run on test data. 95 | if args.split == 'all': 96 | test_loss = evaluate(te_iter) 97 | valid_loss = evaluate(va_iter) 98 | elif args.split == 'valid': 99 | valid_loss = evaluate(va_iter) 100 | test_loss = None 101 | elif args.split == 'test': 102 | test_loss = evaluate(te_iter) 103 | valid_loss = None 104 | 105 | def format_log(loss, split): 106 | if args.dataset in ['enwik8', 'text8']: 107 | log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format( 108 | split, loss, loss / math.log(2)) 109 | else: 110 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( 111 | split, loss, math.exp(loss)) 112 | return log_str 113 | 114 | log_str = '' 115 | if valid_loss is not None: 116 | log_str += format_log(valid_loss, 'valid') 117 | if test_loss is not None: 118 | log_str += format_log(test_loss, 'test') 119 | 120 | logging('=' * 100) 121 | logging(log_str) 122 | logging('=' * 100) 123 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/getdata.sh: -------------------------------------------------------------------------------- 1 | echo "=== Acquiring datasets ===" 2 | echo "---" 3 | 4 | mkdir -p ../data 5 | cd ../data 6 | 7 | if [[ ! -d 'wikitext-2' ]]; then 8 | echo "- Downloading WikiText-2 (WT2)" 9 | wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip 10 | unzip -q wikitext-2-v1.zip 11 | cd wikitext-2 12 | mv wiki.train.tokens train.txt 13 | mv wiki.valid.tokens valid.txt 14 | mv wiki.test.tokens test.txt 15 | cd .. 16 | fi 17 | 18 | echo "- Downloading WikiText-103 (WT2)" 19 | if [[ ! -d 'wikitext-103' ]]; then 20 | wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip 21 | unzip -q wikitext-103-v1.zip 22 | cd wikitext-103 23 | mv wiki.train.tokens train.txt 24 | mv wiki.valid.tokens valid.txt 25 | mv wiki.test.tokens test.txt 26 | cd .. 27 | fi 28 | 29 | echo "- Downloading enwik8 (Character)" 30 | if [[ ! -d 'enwik8' ]]; then 31 | mkdir -p enwik8 32 | cd enwik8 33 | wget --continue http://mattmahoney.net/dc/enwik8.zip 34 | wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py 35 | python3 prep_enwik8.py 36 | cd .. 37 | fi 38 | 39 | echo "- Downloading text8 (Character)" 40 | if [[ ! -d 'text8' ]]; then 41 | mkdir -p text8 42 | cd text8 43 | wget --continue http://mattmahoney.net/dc/text8.zip 44 | python ../../prep_text8.py 45 | cd .. 46 | fi 47 | 48 | echo "- Downloading Penn Treebank (PTB)" 49 | if [[ ! -d 'penn' ]]; then 50 | wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz 51 | tar -xzf simple-examples.tgz 52 | 53 | mkdir -p penn 54 | cd penn 55 | mv ../simple-examples/data/ptb.train.txt train.txt 56 | mv ../simple-examples/data/ptb.test.txt test.txt 57 | mv ../simple-examples/data/ptb.valid.txt valid.txt 58 | cd .. 
59 | 60 | echo "- Downloading Penn Treebank (Character)" 61 | mkdir -p pennchar 62 | cd pennchar 63 | mv ../simple-examples/data/ptb.char.train.txt train.txt 64 | mv ../simple-examples/data/ptb.char.test.txt test.txt 65 | mv ../simple-examples/data/ptb.char.valid.txt valid.txt 66 | cd .. 67 | 68 | rm -rf simple-examples/ 69 | fi 70 | 71 | echo "- Downloading 1B words" 72 | 73 | if [[ ! -d 'one-billion-words' ]]; then 74 | mkdir -p one-billion-words 75 | cd one-billion-words 76 | 77 | wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz 78 | tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz 79 | 80 | path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/" 81 | cat ${path}/news.en.heldout-00000-of-00050 > valid.txt 82 | cat ${path}/news.en.heldout-00000-of-00050 > test.txt 83 | 84 | wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt 85 | 86 | cd .. 87 | fi 88 | 89 | echo "---" 90 | echo "Happy language modeling :)" 91 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_enwik8_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/enwik8/ \ 8 | --dataset enwik8 \ 9 | --n_layer 12 \ 10 | --d_model 512 \ 11 | --n_head 8 \ 12 | --d_head 64 \ 13 | --d_inner 2048 \ 14 | --dropout 0.1 \ 15 | --dropatt 0.0 \ 16 | --optim adam \ 17 | --lr 0.00025 \ 18 | --warmup_step 0 \ 19 | --max_step 400000 \ 20 | --tgt_len 512 \ 21 | --mem_len 512 \ 22 | --eval_tgt_len 128 \ 23 | --batch_size 22 \ 24 | --multi_gpu \ 25 | --gpu0_bsz 4 \ 26 | ${@:2} 27 | elif [[ $1 == 'eval' ]]; then 28 | echo 'Run evaluation...' 29 | python eval.py \ 30 | --cuda \ 31 | --data ../data/enwik8/ \ 32 | --dataset enwik8 \ 33 | --tgt_len 80 \ 34 | --mem_len 2100 \ 35 | --clamp_len 820 \ 36 | --same_length \ 37 | --split test \ 38 | ${@:2} 39 | else 40 | echo 'unknown argment 1' 41 | fi 42 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_enwik8_base_moe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/enwik8/ \ 8 | --dataset enwik8 \ 9 | --n_layer 12 \ 10 | --d_model 512 \ 11 | --n_head 8 \ 12 | --d_head 64 \ 13 | --d_inner 1024 \ 14 | --dropout 0.1 \ 15 | --dropatt 0.0 \ 16 | --optim adam \ 17 | --lr 0.00025 \ 18 | --warmup_step 0 \ 19 | --max_step 400000 \ 20 | --tgt_len 512 \ 21 | --mem_len 512 \ 22 | --eval_tgt_len 128 \ 23 | --batch_size 22 \ 24 | --multi_gpu \ 25 | --gpu0_bsz 4 \ 26 | --moe --moe-num-expert 64 --moe-top-k 2 \ 27 | ${@:2} 28 | elif [[ $1 == 'eval' ]]; then 29 | echo 'Run evaluation...' 
30 | python eval.py \ 31 | --cuda \ 32 | --data ../data/enwik8/ \ 33 | --dataset enwik8 \ 34 | --tgt_len 80 \ 35 | --mem_len 2100 \ 36 | --clamp_len 820 \ 37 | --same_length \ 38 | --split test \ 39 | ${@:2} 40 | else 41 | echo 'unknown argment 1' 42 | fi 43 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_enwik8_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/enwik8/ \ 8 | --dataset enwik8 \ 9 | --n_layer 24 \ 10 | --d_model 1024 \ 11 | --n_head 8 \ 12 | --d_head 128 \ 13 | --d_inner 3072 \ 14 | --dropout 0.15 \ 15 | --dropatt 0.15 \ 16 | --optim adam \ 17 | --lr 0.00025 \ 18 | --warmup_step 4000 \ 19 | --max_step 400000 \ 20 | --tgt_len 768 \ 21 | --mem_len 768 \ 22 | --eval_tgt_len 128 \ 23 | --batch_size 64 \ 24 | --multi_gpu \ 25 | --gpu0_bsz 0 \ 26 | ${@:2} 27 | elif [[ $1 == 'eval' ]]; then 28 | echo 'Run evaluation...' 29 | python eval.py \ 30 | --cuda \ 31 | --data ../data/enwik8/ \ 32 | --dataset enwik8 \ 33 | --tgt_len 128 \ 34 | --mem_len 3800 \ 35 | --clamp_len 1000 \ 36 | --same_length \ 37 | --split test \ 38 | ${@:2} 39 | else 40 | echo 'unknown argment 1' 41 | fi 42 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_lm1b_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/one-billion-words/ \ 8 | --dataset lm1b \ 9 | --adaptive \ 10 | --n_layer 18 \ 11 | --d_model 1024 \ 12 | --div_val 4 \ 13 | --n_head 8 \ 14 | --d_head 128 \ 15 | --d_inner 4096 \ 16 | --dropout 0.0 \ 17 | --dropatt 0.0 \ 18 | --optim adam \ 19 | --warmup_step 20000 \ 20 | --max_step 500000 \ 21 | --lr 0.00025 \ 22 | --tgt_len 32 \ 23 | --mem_len 32 \ 24 | --eval_tgt_len 32 \ 25 | --batch_size 224 \ 26 | --multi_gpu \ 27 | --gpu0_bsz 32 \ 28 | ${@:2} 29 | elif [[ $1 == 'eval' ]]; then 30 | echo 'Run evaluation...' 31 | python eval.py \ 32 | --cuda \ 33 | --data ../data/one-billion-words/ \ 34 | --dataset lm1b \ 35 | --batch_size 64 \ 36 | --tgt_len 32 \ 37 | --mem_len 128 \ 38 | --split test \ 39 | --same_length \ 40 | ${@:2} 41 | else 42 | echo 'unknown argment 1' 43 | fi 44 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_lm1b_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/one-billion-words/ \ 8 | --dataset lm1b \ 9 | --adaptive \ 10 | --div_val 4 \ 11 | --n_layer 24 \ 12 | --d_model 1280 \ 13 | --n_head 16 \ 14 | --d_head 80 \ 15 | --d_inner 8192 \ 16 | --dropout 0.05 \ 17 | --dropatt 0.05 \ 18 | --optim adam \ 19 | --warmup_step 30000 \ 20 | --max_step 1200000 \ 21 | --lr 0.00025 \ 22 | --tgt_len 32 \ 23 | --mem_len 32 \ 24 | --eval_tgt_len 32 \ 25 | --batch_size 512 \ 26 | --multi_gpu \ 27 | --gpu0_bsz 0 \ 28 | ${@:2} 29 | elif [[ $1 == 'eval' ]]; then 30 | echo 'Run evaluation...' 
31 | python eval.py \ 32 | --cuda \ 33 | --data ../data/one-billion-words/ \ 34 | --dataset lm1b \ 35 | --batch_size 8 \ 36 | --tgt_len 32 \ 37 | --mem_len 128 \ 38 | --split test \ 39 | --same_length \ 40 | ${@:2} 41 | else 42 | echo 'unknown argment 1' 43 | fi 44 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_text8_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/text8/ \ 8 | --dataset text8 \ 9 | --n_layer 12 \ 10 | --d_model 512 \ 11 | --n_head 8 \ 12 | --d_head 64 \ 13 | --d_inner 2048 \ 14 | --dropout 0.1 \ 15 | --dropatt 0.0 \ 16 | --optim adam \ 17 | --lr 0.00025 \ 18 | --warmup_step 0 \ 19 | --max_step 400000 \ 20 | --tgt_len 512 \ 21 | --mem_len 512 \ 22 | --eval_tgt_len 128 \ 23 | --batch_size 22 \ 24 | --multi_gpu \ 25 | --gpu0_bsz 4 \ 26 | ${@:2} 27 | elif [[ $1 == 'eval' ]]; then 28 | echo 'Run evaluation...' 29 | python eval.py \ 30 | --cuda \ 31 | --data ../data/text8/ \ 32 | --dataset text8 \ 33 | --tgt_len 80 \ 34 | --mem_len 2100 \ 35 | --clamp_len 820 \ 36 | --same_length \ 37 | --split test \ 38 | ${@:2} 39 | else 40 | echo 'unknown argment 1' 41 | fi 42 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_text8_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/text8/ \ 8 | --dataset text8 \ 9 | --n_layer 24 \ 10 | --d_model 1024 \ 11 | --n_head 8 \ 12 | --d_head 128 \ 13 | --d_inner 3072 \ 14 | --dropout 0.15 \ 15 | --dropatt 0.15 \ 16 | --optim adam \ 17 | --lr 0.00025 \ 18 | --tgt_len 768 \ 19 | --mem_len 768 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 64 \ 22 | --max_step 400000 \ 23 | ${@:2} 24 | elif [[ $1 == 'eval' ]]; then 25 | echo 'Run evaluation...' 26 | python eval.py \ 27 | --cuda \ 28 | --data ../data/text8/ \ 29 | --dataset text8 \ 30 | --tgt_len 128 \ 31 | --mem_len 3800 \ 32 | --clamp_len 1000 \ 33 | --same_length \ 34 | --split test \ 35 | ${@:2} 36 | else 37 | echo 'unknown argment 1' 38 | fi 39 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_wt103_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == 'train' ]]; then 4 | echo 'Run training...' 5 | python train.py \ 6 | --cuda \ 7 | --data ../data/wikitext-103/ \ 8 | --dataset wt103 \ 9 | --adaptive \ 10 | --n_layer 16 \ 11 | --d_model 410 \ 12 | --n_head 10 \ 13 | --d_head 41 \ 14 | --d_inner 2100 \ 15 | --dropout 0.1 \ 16 | --dropatt 0.0 \ 17 | --optim adam \ 18 | --lr 0.00025 \ 19 | --warmup_step 0 \ 20 | --max_step 200000 \ 21 | --tgt_len 150 \ 22 | --mem_len 150 \ 23 | --eval_tgt_len 150 \ 24 | --batch_size 60 \ 25 | --multi_gpu \ 26 | --gpu0_bsz 4 \ 27 | ${@:2} 28 | elif [[ $1 == 'eval' ]]; then 29 | echo 'Run evaluation...' 
30 | python eval.py \ 31 | --cuda \ 32 | --data ../data/wikitext-103/ \ 33 | --dataset wt103 \ 34 | --tgt_len 64 \ 35 | --mem_len 640 \ 36 | --clamp_len 400 \ 37 | --same_length \ 38 | --split test \ 39 | ${@:2} 40 | else 41 | echo 'unknown argment 1' 42 | fi 43 | -------------------------------------------------------------------------------- /examples/transformer-xl/scripts/run_wt103_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=$PWD/cuda/build/lib.linux-x86_64-3.7 3 | 4 | if [[ $1 == 'train' ]]; then 5 | echo 'Run training...' 6 | python3 train.py \ 7 | --cuda \ 8 | --data ../data/wikitext-103/ \ 9 | --dataset wt103 \ 10 | --adaptive \ 11 | --div_val 4 \ 12 | --n_layer 18 \ 13 | --d_model 1024 \ 14 | --n_head 16 \ 15 | --d_head 64 \ 16 | --d_inner 4096 \ 17 | --dropout 0.2 \ 18 | --dropatt 0.2 \ 19 | --optim adam \ 20 | --lr 0.00025 \ 21 | --warmup_step 16000 \ 22 | --max_step 4000000 \ 23 | --tgt_len 384 \ 24 | --mem_len 384 \ 25 | --eval_tgt_len 128 \ 26 | --batch_size 128 \ 27 | --multi_gpu \ 28 | --gpu0_bsz 0 \ 29 | ${@:2} 30 | elif [[ $1 == 'eval' ]]; then 31 | echo 'Run evaluation...' 32 | python eval.py \ 33 | --cuda \ 34 | --data ../data/wikitext-103/ \ 35 | --dataset wt103 \ 36 | --tgt_len 128 \ 37 | --mem_len 1600 \ 38 | --clamp_len 1000 \ 39 | --same_length \ 40 | --split test \ 41 | ${@:2} 42 | else 43 | echo 'unknown argment 1' 44 | fi 45 | -------------------------------------------------------------------------------- /examples/transformer-xl/utils/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AdaptiveLogSoftmax(nn.Module): 10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False): 11 | super(AdaptiveLogSoftmax, self).__init__() 12 | 13 | cutoffs = list(cutoffs) 14 | 15 | if (cutoffs != sorted(cutoffs)) \ 16 | or (min(cutoffs) <= 0) \ 17 | or (max(cutoffs) >= (n_classes - 1)) \ 18 | or (len(set(cutoffs)) != len(cutoffs)) \ 19 | or any([int(c) != c for c in cutoffs]): 20 | 21 | raise ValueError("cutoffs should be a sequence of unique, positive " 22 | "integers sorted in an increasing order, where " 23 | "each value is between 1 and n_classes-1") 24 | 25 | self.in_features = in_features 26 | self.n_classes = n_classes 27 | self.cutoffs = cutoffs + [n_classes] 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) 34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 35 | 36 | self.keep_order = keep_order 37 | 38 | 39 | def forward(self, hidden, target, weight, bias, keep_order=False): 40 | if hidden.size(0) != target.size(0): 41 | raise RuntimeError('Input and target should have the same size ' 42 | 'in the batch dimension.') 43 | 44 | head_weight = torch.cat( 45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0) 46 | head_bias = torch.cat( 47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0) 48 | 49 | head_logit = F.linear(hidden, head_weight, bias=head_bias) 50 | head_logprob = F.log_softmax(head_logit, dim=1) 51 | 52 | nll = torch.zeros_like(target, 53 | dtype=hidden.dtype, device=hidden.device) 54 | 55 | offset = 0 56 | cutoff_values = 
[0] + self.cutoffs 57 | for i in range(len(cutoff_values) - 1): 58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1] 59 | 60 | mask_i = (target >= l_idx) & (target < h_idx) 61 | indices_i = mask_i.nonzero().squeeze() 62 | 63 | if indices_i.numel() == 0: 64 | continue 65 | 66 | target_i = target.index_select(0, indices_i) - l_idx 67 | head_logprob_i = head_logprob.index_select(0, indices_i) 68 | 69 | if i == 0: 70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 71 | else: 72 | weight_i = weight[l_idx:h_idx] 73 | bias_i = bias[l_idx:h_idx] 74 | 75 | hidden_i = hidden.index_select(0, indices_i) 76 | 77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) 78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 79 | 80 | logprob_i = head_logprob_i[:, -i] \ 81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 82 | 83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 84 | nll.index_copy_(0, indices_i, -logprob_i) 85 | else: 86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 87 | 88 | offset += logprob_i.size(0) 89 | 90 | return nll 91 | -------------------------------------------------------------------------------- /examples/transformer-xl/utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). 
To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | 92 | -------------------------------------------------------------------------------- /examples/transformer-xl/utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | if print_: 11 | print(s) 12 | if log_: 13 | with open(log_path, 'a+') as f_log: 14 | f_log.write(s + '\n') 15 | 16 | def get_logger(log_path, **kwargs): 17 | return functools.partial(logging, log_path=log_path, **kwargs) 18 | 19 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 20 | if debug: 21 | print('Debug Mode : no experiment dir created') 22 | return functools.partial(logging, log_path=None, log_=False) 23 | 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | print('Experiment dir : {}'.format(dir_path)) 28 | if scripts_to_save is not None: 29 | script_path = os.path.join(dir_path, 'scripts') 30 | if not os.path.exists(script_path): 31 | os.makedirs(script_path) 32 | for script in scripts_to_save: 33 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 34 | shutil.copyfile(script, dst_file) 35 | 36 | return 
get_logger(log_path=os.path.join(dir_path, 'log.txt')) 37 | 38 | def save_checkpoint(model, optimizer, path, epoch): 39 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) 40 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) 41 | -------------------------------------------------------------------------------- /examples/transformer-xl/utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | class LogUniformSampler(object): 6 | def __init__(self, range_max, n_sample): 7 | """ 8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 10 | 11 | expected count can be approximated by 1 - (1 - p)^n 12 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 13 | 14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 15 | """ 16 | with torch.no_grad(): 17 | self.range_max = range_max 18 | log_indices = torch.arange(1., range_max+2., 1.).log_() 19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 20 | # print('P', self.dist.numpy().tolist()[-30:]) 21 | 22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 23 | 24 | self.n_sample = n_sample 25 | 26 | def sample(self, labels): 27 | """ 28 | labels: [b1, b2] 29 | Return 30 | true_log_probs: [b1, b2] 31 | samp_log_probs: [n_sample] 32 | neg_samples: [n_sample] 33 | """ 34 | 35 | # neg_samples = torch.empty(0).long() 36 | n_sample = self.n_sample 37 | n_tries = 2 * n_sample 38 | 39 | with torch.no_grad(): 40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 41 | device = labels.device 42 | neg_samples = neg_samples.to(device) 43 | true_log_probs = self.log_q[labels].to(device) 44 | samp_log_probs = self.log_q[neg_samples].to(device) 45 | return true_log_probs, samp_log_probs, neg_samples 46 | 47 | def sample_logits(embedding, bias, labels, inputs, sampler): 48 | """ 49 | embedding: an nn.Embedding layer 50 | bias: [n_vocab] 51 | labels: [b1, b2] 52 | inputs: [b1, b2, n_emb] 53 | sampler: you may use a LogUniformSampler 54 | Return 55 | logits: [b1, b2, 1 + n_sample] 56 | """ 57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 58 | n_sample = neg_samples.size(0) 59 | b1, b2 = labels.size(0), labels.size(1) 60 | all_ids = torch.cat([labels.view(-1), neg_samples]) 61 | all_w = embedding(all_ids) 62 | true_w = all_w[: -n_sample].view(b1, b2, -1) 63 | sample_w = all_w[- n_sample:].view(n_sample, -1) 64 | 65 | all_b = bias[all_ids] 66 | true_b = all_b[: -n_sample].view(b1, b2) 67 | sample_b = all_b[- n_sample:] 68 | 69 | hit = (labels[:, :, None] == neg_samples).detach() 70 | 71 | true_logits = torch.einsum('ijk,ijk->ij', 72 | [true_w, inputs]) + true_b - true_log_probs 73 | sample_logits = torch.einsum('lk,ijk->ijl', 74 | [sample_w, inputs]) + sample_b - samp_log_probs 75 | sample_logits.masked_fill_(hit, -1e30) 76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 77 | 78 | return logits 79 | 80 | 81 | # class LogUniformSampler(object): 82 | # def __init__(self, range_max, unique=False): 83 | # """ 84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 85 | # `P(class) = (log(class + 2) - 
log(class + 1)) / log(range_max + 1)` 86 | # """ 87 | # self.range_max = range_max 88 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 90 | 91 | # self.unique = unique 92 | 93 | # if self.unique: 94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 95 | 96 | # def sample(self, n_sample, labels): 97 | # pos_sample, new_labels = labels.unique(return_inverse=True) 98 | # n_pos_sample = pos_sample.size(0) 99 | # n_neg_sample = n_sample - n_pos_sample 100 | 101 | # if self.unique: 102 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 104 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 105 | # else: 106 | # sample_dist = self.dist 107 | 108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 109 | 110 | # sample = torch.cat([pos_sample, neg_sample]) 111 | # sample_prob = self.dist[sample] 112 | 113 | # return new_labels, sample, sample_prob 114 | 115 | 116 | if __name__ == '__main__': 117 | S, B = 3, 4 118 | n_vocab = 10000 119 | n_sample = 5 120 | H = 32 121 | 122 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 123 | 124 | # sampler = LogUniformSampler(n_vocab, unique=False) 125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 126 | 127 | sampler = LogUniformSampler(n_vocab, unique=True) 128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 129 | 130 | # print('true_probs', true_probs.numpy().tolist()) 131 | # print('samp_probs', samp_probs.numpy().tolist()) 132 | # print('neg_samples', neg_samples.numpy().tolist()) 133 | 134 | # print('sum', torch.sum(sampler.dist).item()) 135 | 136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 137 | 138 | embedding = nn.Embedding(n_vocab, H) 139 | bias = torch.zeros(n_vocab) 140 | inputs = torch.Tensor(S, B, H).normal_() 141 | 142 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 143 | print('logits', logits.detach().numpy().tolist()) 144 | print('logits shape', logits.size()) 145 | print('out_labels', out_labels.detach().numpy().tolist()) 146 | print('out_labels shape', out_labels.size()) 147 | 148 | -------------------------------------------------------------------------------- /examples/transformer-xl/utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, OrderedDict 3 | 4 | import torch 5 | 6 | class Vocab(object): 7 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, 8 | delimiter=None, vocab_file=None): 9 | self.counter = Counter() 10 | self.special = special 11 | self.min_freq = min_freq 12 | self.max_size = max_size 13 | self.lower_case = lower_case 14 | self.delimiter = delimiter 15 | self.vocab_file = vocab_file 16 | 17 | def tokenize(self, line, add_eos=False, add_double_eos=False): 18 | line = line.strip() 19 | # convert to lower case 20 | if self.lower_case: 21 | line = line.lower() 22 | 23 | # empty delimiter '' will evaluate False 24 | if self.delimiter == '': 25 | symbols = line 26 | else: 27 | symbols = line.split(self.delimiter) 28 | 29 | if add_double_eos: # lm1b 30 | return [''] + symbols + [''] 31 | elif add_eos: 32 | return symbols + [''] 33 | else: 34 | return symbols 35 | 36 | def count_file(self, path, verbose=False, add_eos=False): 37 | if verbose: print('counting file {} ...'.format(path)) 38 | 
assert os.path.exists(path) 39 | 40 | sents = [] 41 | with open(path, 'r', encoding='utf-8') as f: 42 | for idx, line in enumerate(f): 43 | if verbose and idx > 0 and idx % 500000 == 0: 44 | print(' line {}'.format(idx)) 45 | symbols = self.tokenize(line, add_eos=add_eos) 46 | self.counter.update(symbols) 47 | sents.append(symbols) 48 | 49 | return sents 50 | 51 | def count_sents(self, sents, verbose=False): 52 | """ 53 | sents : a list of sentences, each a list of tokenized symbols 54 | """ 55 | if verbose: print('counting {} sents ...'.format(len(sents))) 56 | for idx, symbols in enumerate(sents): 57 | if verbose and idx > 0 and idx % 500000 == 0: 58 | print(' line {}'.format(idx)) 59 | self.counter.update(symbols) 60 | 61 | def _build_from_file(self, vocab_file): 62 | self.idx2sym = [] 63 | self.sym2idx = OrderedDict() 64 | 65 | with open(vocab_file, 'r', encoding='utf-8') as f: 66 | for line in f: 67 | symb = line.strip().split()[0] 68 | self.add_symbol(symb) 69 | self.unk_idx = self.sym2idx[''] 70 | 71 | def build_vocab(self): 72 | if self.vocab_file: 73 | print('building vocab from {}'.format(self.vocab_file)) 74 | self._build_from_file(self.vocab_file) 75 | print('final vocab size {}'.format(len(self))) 76 | else: 77 | print('building vocab with min_freq={}, max_size={}'.format( 78 | self.min_freq, self.max_size)) 79 | self.idx2sym = [] 80 | self.sym2idx = OrderedDict() 81 | 82 | for sym in self.special: 83 | self.add_special(sym) 84 | 85 | for sym, cnt in self.counter.most_common(self.max_size): 86 | if cnt < self.min_freq: break 87 | self.add_symbol(sym) 88 | 89 | print('final vocab size {} from {} unique tokens'.format( 90 | len(self), len(self.counter))) 91 | 92 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 93 | add_double_eos=False): 94 | if verbose: print('encoding file {} ...'.format(path)) 95 | assert os.path.exists(path) 96 | encoded = [] 97 | with open(path, 'r', encoding='utf-8') as f: 98 | for idx, line in enumerate(f): 99 | if verbose and idx > 0 and idx % 500000 == 0: 100 | print(' line {}'.format(idx)) 101 | symbols = self.tokenize(line, add_eos=add_eos, 102 | add_double_eos=add_double_eos) 103 | encoded.append(self.convert_to_tensor(symbols)) 104 | 105 | if ordered: 106 | encoded = torch.cat(encoded) 107 | 108 | return encoded 109 | 110 | def encode_sents(self, sents, ordered=False, verbose=False): 111 | if verbose: print('encoding {} sents ...'.format(len(sents))) 112 | encoded = [] 113 | for idx, symbols in enumerate(sents): 114 | if verbose and idx > 0 and idx % 500000 == 0: 115 | print(' line {}'.format(idx)) 116 | encoded.append(self.convert_to_tensor(symbols)) 117 | 118 | if ordered: 119 | encoded = torch.cat(encoded) 120 | 121 | return encoded 122 | 123 | def add_special(self, sym): 124 | if sym not in self.sym2idx: 125 | self.idx2sym.append(sym) 126 | self.sym2idx[sym] = len(self.idx2sym) - 1 127 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 128 | 129 | def add_symbol(self, sym): 130 | if sym not in self.sym2idx: 131 | self.idx2sym.append(sym) 132 | self.sym2idx[sym] = len(self.idx2sym) - 1 133 | 134 | def get_sym(self, idx): 135 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 136 | return self.idx2sym[idx] 137 | 138 | def get_idx(self, sym): 139 | if sym in self.sym2idx: 140 | return self.sym2idx[sym] 141 | else: 142 | # print('encounter unk {}'.format(sym)) 143 | assert '' not in sym 144 | assert hasattr(self, 'unk_idx') 145 | return self.sym2idx.get(sym, self.unk_idx) 146 | 147 | def 
get_symbols(self, indices): 148 | return [self.get_sym(idx) for idx in indices] 149 | 150 | def get_indices(self, symbols): 151 | return [self.get_idx(sym) for sym in symbols] 152 | 153 | def convert_to_tensor(self, symbols): 154 | return torch.LongTensor(self.get_indices(symbols)) 155 | 156 | def convert_to_sent(self, indices, exclude=None): 157 | if exclude is None: 158 | return ' '.join([self.get_sym(idx) for idx in indices]) 159 | else: 160 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 161 | 162 | def __len__(self): 163 | return len(self.idx2sym) 164 | -------------------------------------------------------------------------------- /fmoe/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The fmoe package contains MoE Layers only. 3 | """ 4 | 5 | from .layers import FMoE 6 | from .linear import FMoELinear 7 | from .transformer import FMoETransformerMLP 8 | from .distributed import DistributedGroupedDataParallel 9 | -------------------------------------------------------------------------------- /fmoe/balance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | metrics = { 5 | "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e), 6 | "Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1), 7 | "Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e), 8 | } 9 | 10 | 11 | def reset_balance_profile(balance_dict, num_layers, balance_strategy): 12 | for key in metrics: 13 | balance_dict[key] = [None for _ in range(num_layers)] 14 | if balance_strategy: 15 | balance_dict[f"{balance_strategy}_loss"] = [None for _ in range(num_layers)] 16 | 17 | 18 | def update_balance_profile( 19 | balance_dict, 20 | gate_top_k_idx, 21 | _gate_score_top_k, 22 | gate_context, 23 | layer_idx, 24 | num_expert, 25 | balance_strategy, 26 | ): 27 | # Fill in this function to conduct balance-related jobs 28 | pass 29 | -------------------------------------------------------------------------------- /fmoe/distributed.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Supportive modules to conduct distributed training 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | from .utils import get_torch_default_comm, get_rank_0_in_comm 8 | 9 | 10 | class DistributedGroupedDataParallel(nn.Module): 11 | r""" 12 | A customized DDP module to support different all-reduce regions in the 13 | model. An all-reduce region is defined by an attribute `dp_comm` on the 14 | weight object. 15 | The gradients of the weights are reduced in different groups 16 | according to the weights' `dp_comm` attribute. 17 | If it is set to `dp`, the gradient is only reduced across the data-parallel 18 | group, which means that it is not synchronized within the model-parallel 19 | group. 20 | If it is set to `world`, the gradients are synchronized across all workers, 21 | regardless of their model or data parallel group. This is extremely useful for 22 | shared layers like the gate. 
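A minimal sketch of how parameters opt into an all-reduce region, assuming a toy model whose module names are hypothetical; the attribute name `dp_comm` and the group names `dp` and `world` follow the description above.

import torch.nn as nn

# Tag the gate's parameters so their gradients are reduced across all workers;
# parameters without the attribute fall back to the default "dp" group.
toy = nn.ModuleDict({
    "gate": nn.Linear(16, 4),
    "experts": nn.Linear(16, 16),
})
for p in toy["gate"].parameters():
    p.dp_comm = "world"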
23 | """ 24 | 25 | def __init__( 26 | self, 27 | module, 28 | auto_allreduce=False, 29 | need_sync=True, 30 | **kwargs 31 | ): 32 | assert not auto_allreduce, "Automatic all-reduce is not implemented yet" 33 | 34 | super().__init__() 35 | self.module = module 36 | 37 | self.comms = dict() 38 | for k in kwargs: 39 | if k.endswith('_group'): 40 | self.comms[k[:-6]] = kwargs[k] 41 | for k in ['dp', 'gate', 'moe', 'world']: 42 | if k not in self.comms: 43 | self.comms[k] = get_torch_default_comm() 44 | 45 | def allreduce_gradients(no_scale=False, 46 | reduce_after=False, fp32_allreduce=False): 47 | groups = dict() 48 | for p in self.module.parameters(): 49 | if not p.requires_grad or p.grad is None: 50 | continue 51 | if hasattr(p, "dp_comm"): 52 | dp_comm = p.dp_comm 53 | else: 54 | dp_comm = "dp" 55 | group_key = (dp_comm, p.dtype) 56 | if group_key not in groups: 57 | groups[group_key] = [p] 58 | else: 59 | groups[group_key].append(p) 60 | for (dp_comm, dtype), group in groups.items(): 61 | if dp_comm not in self.comms: 62 | continue 63 | comm = self.comms[dp_comm] 64 | grads = [p.grad.data for p in group] 65 | coalesced = _flatten_dense_tensors(grads) 66 | if fp32_allreduce and dtype != torch.float32: 67 | coalesced = coalesced.float() 68 | if not no_scale and not reduce_after: 69 | coalesced /= comm.size() 70 | torch.distributed.all_reduce(coalesced, group=comm) 71 | if not no_scale and reduce_after: 72 | coalesced /= comm.size() 73 | synced = _unflatten_dense_tensors(coalesced, grads) 74 | for g, s in zip(grads, synced): 75 | g.copy_(s) 76 | 77 | def allreduce_params(*args, **kwargs): 78 | return allreduce_gradients(*args, **kwargs) 79 | 80 | self.allreduce_gradients = allreduce_gradients 81 | self.allreduce_params = allreduce_params 82 | if need_sync: 83 | self._sync_params() 84 | 85 | def _sync_params(self): 86 | groups = dict() 87 | for p in self.module.parameters(): 88 | if hasattr(p, "dp_comm"): 89 | dp_comm = p.dp_comm 90 | else: 91 | dp_comm = "dp" 92 | group_key = (dp_comm, p.dtype) 93 | if group_key not in groups: 94 | groups[group_key] = [p] 95 | else: 96 | groups[group_key].append(p) 97 | for (dp_comm, _), group in groups.items(): 98 | if dp_comm not in self.comms: 99 | continue 100 | comm = self.comms[dp_comm] 101 | datas = [p.data for p in group] 102 | coalesced = _flatten_dense_tensors(datas) 103 | torch.distributed.broadcast(coalesced, 104 | get_rank_0_in_comm(comm), group=comm) 105 | torch.cuda.synchronize() 106 | synced = _unflatten_dense_tensors(coalesced, datas) 107 | for d, s in zip(datas, synced): 108 | d.copy_(s) 109 | 110 | def forward(self, *args, **kwargs): 111 | r""" 112 | Directly call the module's forward function. 
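A sketch of a training step with this wrapper, assuming `model`, `optimizer`, `loss_fn` and `loader` are placeholders and that `torch.distributed` has already been initialized; note that gradients are not hooked automatically and must be reduced explicitly.

from fmoe.distributed import DistributedGroupedDataParallel

ddp_model = DistributedGroupedDataParallel(model)
for batch, target in loader:
    optimizer.zero_grad()
    loss = loss_fn(ddp_model(batch), target)
    loss.backward()
    # Reduce gradients per dp_comm group before stepping the optimizer.
    ddp_model.allreduce_gradients()
    optimizer.step()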
113 | """ 114 | return self.module(*args, **kwargs) 115 | -------------------------------------------------------------------------------- /fmoe/fastermoe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laekov/fastmoe/55af4f98eee087cf5b3aac34318abf80c3bcbafd/fmoe/fastermoe/__init__.py -------------------------------------------------------------------------------- /fmoe/fastermoe/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def float_from_env(key, default=-1): 5 | if key in os.environ: 6 | return float(os.environ[key]) 7 | return default 8 | 9 | 10 | def switch_from_env(key, default=False): 11 | if key in os.environ: 12 | return os.environ[key] in ['1', 'ON'] 13 | return default 14 | -------------------------------------------------------------------------------- /fmoe/fastermoe/expert_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_expert_param_size(e, idx): 5 | e = e[idx] 6 | return sum(map(lambda x: x.numel(), e.parameters())) 7 | 8 | 9 | def get_expert_params(e, out, idx): 10 | e = e[idx] 11 | offset = 0 12 | for n, p in e.named_parameters(): 13 | seg = out[offset:offset + p.numel()] 14 | offset += p.numel() 15 | seg.copy_(p.data.flatten()) 16 | 17 | 18 | def stash_expert_params(e, params, idx): 19 | e = e[idx] 20 | if not hasattr(e, 'expert_param_stash'): 21 | setattr(e, 'expert_param_stash', dict()) 22 | setattr(e, 'expert_grad_stash', dict()) 23 | offset = 0 24 | for n, p in e.named_parameters(): 25 | if n not in e.expert_param_stash: 26 | e.expert_param_stash[n] = p.data.clone() 27 | e.expert_grad_stash[n] = p.grad.clone() if p.grad is not None else None 28 | with torch.no_grad(): 29 | seg = params[offset:offset + p.numel()] 30 | offset += p.numel() 31 | p.copy_(seg.reshape(p.shape)) 32 | p.grad = None 33 | 34 | 35 | def pop_expert_params(e, idx): 36 | e = e[idx] 37 | if not hasattr(e, 'expert_param_stash'): 38 | return 39 | if not e.expert_param_stash: 40 | return 41 | for n, p in e.named_parameters(): 42 | with torch.no_grad(): 43 | p.copy_(e.expert_param_stash[n]) 44 | if e.expert_grad_stash[n] is not None: 45 | p.grad = e.expert_grad_stash[n].clone() 46 | e.expert_param_stash.clear() 47 | e.expert_grad_stash.clear() 48 | 49 | 50 | def collect_expert_grads(e, grads, idx): 51 | e = e[idx] 52 | offset = 0 53 | for _, p in e.named_parameters(): 54 | seg = grads[offset:offset + p.numel()] 55 | offset += p.numel() 56 | if p.grad is not None: 57 | seg.copy_(p.grad.flatten()) 58 | p.grad = None 59 | else: 60 | seg.zero_() 61 | 62 | 63 | def set_grads(e, grads, idx): 64 | e = e[idx] 65 | offset = 0 66 | for n, p in e.named_parameters(): 67 | seg = grads[offset:offset + p.numel()] 68 | offset += p.numel() 69 | if p.grad is None: 70 | p.grad = seg.clone().reshape(p.shape) 71 | else: 72 | p.grad += seg.reshape(p.shape) 73 | -------------------------------------------------------------------------------- /fmoe/fastermoe/schedule.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The smart schedule proposed in FasterMoE. 
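As a quick illustration of the configuration helpers in fastermoe/config.py above: the environment variable names below appear in fastermoe/shadow_policy.py, while the concrete values here are arbitrary.

import os
from fmoe.fastermoe.config import float_from_env, switch_from_env

os.environ["FMOE_FASTER_SHADOW_ENABLE"] = "1"            # '1' or 'ON' switches a feature on
os.environ["FMOE_FASTER_GLBPLC_NETBW"] = str(100e9 / 8)  # network bandwidth in bytes per second

enabled = switch_from_env("FMOE_FASTER_SHADOW_ENABLE")               # True
bw_net = float_from_env("FMOE_FASTER_GLBPLC_NETBW", 50 * 1e9 / 8)    # falls back to the default if unset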
3 | """ 4 | import torch 5 | from torch.autograd.function import Function 6 | 7 | from fmoe.functions import prepare_forward, ensure_comm 8 | from fmoe.functions import _local_scatter, _local_gather 9 | import fmoe_cuda as fmoe_native 10 | from fmoe.fastermoe import expert_utils 11 | 12 | from .shadow_policy import get_shadow_policy 13 | 14 | 15 | class MoEForward(Function): 16 | @staticmethod 17 | def forward( 18 | ctx, 19 | expert_fn, 20 | experts, 21 | inp, # models, 22 | pos_s, pos_g, 23 | local_expert_count, global_expert_count, 24 | stored_models, 25 | fwd_batch_size, out_batch_size, 26 | num_expert, 27 | world_size): 28 | local_input_buf = _local_scatter(inp, pos_s) 29 | 30 | ctx.gibs = [None] * (world_size * num_expert * 2) 31 | ctx.gobs = [None] * (world_size * num_expert * 2) 32 | def _expert_forward(x, y, expert_idx, store_idx): 33 | nothing = lambda a: a 34 | x = x.data 35 | with torch.enable_grad(): 36 | x.requires_grad = True 37 | try: 38 | # To skip torch autograd's version check. 39 | with torch.autograd.graph.saved_tensors_hooks(nothing, nothing): 40 | y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx) 41 | except Exception as e: 42 | # Ignore the error and fall back for compatibility to older 43 | # versions of PyTorch 44 | y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx) 45 | ctx.gibs[store_idx] = x 46 | ctx.gobs[store_idx] = y0 47 | y.copy_(y0) 48 | 49 | ctx.experts = experts 50 | if stored_models.any(): 51 | ctx.expert_size = expert_utils.get_expert_param_size(experts, 0) 52 | for i in range(num_expert): 53 | assert ctx.expert_size == expert_utils.get_expert_param_size(experts, i), "report bug" 54 | else: 55 | ctx.expert_size = 0 56 | get_param_fn = lambda out, idx: expert_utils.get_expert_params(experts, out, idx) 57 | pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx) 58 | ctx.shadows = [None] * world_size * num_expert 59 | def stash_fn(params, store_idx, expert_idx): 60 | expert_utils.stash_expert_params(experts, params, expert_idx) 61 | ctx.shadows[store_idx] = params 62 | 63 | local_output_buf, gib = fmoe_native.smart_sch_forward( 64 | local_input_buf, 65 | local_expert_count, global_expert_count, 66 | stored_models, fwd_batch_size, ctx.expert_size, 67 | world_size, _expert_forward, get_param_fn, stash_fn, pop_fn) 68 | 69 | out = _local_gather(local_output_buf, pos_g, out_batch_size, 70 | maybe_overlap=False) 71 | 72 | # gib and local_input_buf are necessary, because ctx.gibs are created 73 | # based on their memory 74 | variables = (pos_s, pos_g, local_expert_count, global_expert_count, 75 | stored_models, gib, local_input_buf) 76 | 77 | ctx.moe_args = fwd_batch_size, inp.shape[0], num_expert, world_size 78 | ctx.save_for_backward(*variables) 79 | 80 | return out 81 | 82 | @staticmethod 83 | def backward(ctx, grad_out): 84 | (pos_s, pos_g, local_expert_count, global_expert_count, 85 | stored_models, _1, _2) = ctx.saved_tensors 86 | (fwd_batch_size, inp_batch_size, num_expert, world_size) = ctx.moe_args 87 | 88 | def _expert_backward(grad_y, grad_x, expert_idx, store_idx): 89 | y = ctx.gobs[store_idx] 90 | x = ctx.gibs[store_idx] 91 | torch.autograd.backward([y], [grad_y]) 92 | grad_x.copy_(x.grad) 93 | 94 | experts = ctx.experts 95 | def stash_fn(store_idx, expert_idx): 96 | expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx) 97 | pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx) 98 | def collect_fn(store_idx, root, expert_idx): 99 | grad = 
ctx.shadows[store_idx] 100 | expert_utils.collect_expert_grads(experts, grad, expert_idx) 101 | fmoe_native.reduce_grad(grad, root, ctx.expert_size) 102 | set_grad_fn = lambda store_idx, expert_idx: expert_utils.set_grads(experts, ctx.shadows[store_idx], expert_idx) 103 | 104 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g) 105 | grad_in_buf = fmoe_native.smart_sch_backward( 106 | grad_out_buf, 107 | local_expert_count, global_expert_count, 108 | stored_models, 109 | pos_s.shape[0], fwd_batch_size, 110 | world_size, 111 | _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn) 112 | grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size) 113 | 114 | return (None, None, grad_in, None, None, None, None, None, None, None, None, None) 115 | 116 | 117 | policy_fn = None 118 | 119 | 120 | def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None): 121 | # TODO: Using multiple tensors as input is to be supported. 122 | assert(isinstance(inp, torch.Tensor)) 123 | ( 124 | pos, 125 | local_expert_count, 126 | global_expert_count, 127 | fwd_expert_count, 128 | fwd_batch_size, 129 | ) = prepare_forward(gate, n_expert, world_size) 130 | 131 | global policy_fn 132 | if policy_fn is None: 133 | policy_fn = get_shadow_policy(d_model=inp.shape[-1]) 134 | 135 | if stored_models is None: 136 | stored_models = policy_fn(local_expert_count, global_expert_count, 137 | n_expert, world_size, inp.device) 138 | 139 | topk = 1 140 | if len(gate.shape) == 2: 141 | topk = gate.shape[1] 142 | out_batch_size = inp.shape[0] * topk 143 | 144 | return MoEForward.apply(expert_fn, experts, inp, 145 | torch.div(pos, topk, rounding_mode='floor'), pos, 146 | local_expert_count, global_expert_count, stored_models, 147 | fwd_batch_size, out_batch_size, n_expert, world_size) 148 | -------------------------------------------------------------------------------- /fmoe/fastermoe/shadow_policy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | from .config import float_from_env, switch_from_env 7 | from fmoe.functions import get_moe_group 8 | 9 | 10 | def global_policy(local_expert_count, _gec, num_expert, world_size, device): 11 | r""" 12 | This is the policy for two-layer MLPs, using the formula in the PPoPP paper. 13 | A few parameters are used in this policy. 14 | * `d_model`: feature length of the MLP input and output. 15 | * `alpha`: the ratio of the MLP's hidden size to `d_model`. 
16 | * `bw_net`: bandwidth of the network (GBps) 17 | * `bw_mm`: computation throughput of performing GeMM (FLOPs) 18 | """ 19 | bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8) 20 | bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12) 21 | alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2) 22 | d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048) 23 | 24 | moe_group = get_moe_group() 25 | local_expert_count = local_expert_count.to(device) 26 | agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)] 27 | dist.all_gather(agecs, local_expert_count, group=moe_group) 28 | all_global_expert_count = torch.stack(agecs) 29 | 30 | # TODO: data type other than float 31 | data_size = 4 32 | 33 | fwd_expert_counts = all_global_expert_count.sum(1).cpu() 34 | B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True) 35 | 36 | alphaH2 = alpha * (d_model ** 2) 37 | B_w = B_ws[0] 38 | 39 | comm = float('+inf') 40 | send_feature_time = d_model * data_size / bw_net 41 | send_model_time = 2 * alphaH2 * data_size / bw_net 42 | comp_time = 4 * alphaH2 / bw_mm 43 | lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w 44 | 45 | res = torch.zeros(world_size * num_expert, dtype=torch.bool) 46 | shadow_time = 0 47 | 48 | for i, index in enumerate(indices): 49 | if i + 1 == indices.numel(): 50 | break 51 | B_k = B_ws[i + 1] 52 | shadow_time += send_model_time 53 | lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time 54 | 55 | if lat_new < lat_base: 56 | lat_base = lat_new 57 | res[index] = True 58 | else: 59 | break 60 | return res 61 | 62 | 63 | def no_shadow_policy(_lec, _gec, num_expert, world_size, device): 64 | res = torch.zeros(world_size * num_expert, dtype=bool) 65 | return res 66 | 67 | 68 | def get_shadow_policy(d_model=None): 69 | if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ: 70 | os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model) 71 | if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'): 72 | return no_shadow_policy 73 | return global_policy 74 | -------------------------------------------------------------------------------- /fmoe/gates/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Different implementations of the Gate are located in separate files here. 
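Returning to the shadowing policy above, a back-of-the-envelope rendering of its latency comparison may help; the loads `B_w` and `B_k` below are made up, and the constants are the defaults used in `global_policy`.

# Shadowing the busiest expert pays off when the extra model transfer costs
# less than the compute and token traffic saved by serving the smaller load.
d_model, alpha, data_size = 2048, 2, 4        # feature size, hidden ratio, fp32 bytes
bw_net, bw_mm = 50 * 1e9 / 8, 11.5e12         # bytes/s and FLOP/s defaults
B_w, B_k = 8192, 2048                         # largest and second-largest expert loads (made up)

alphaH2 = alpha * d_model ** 2
send_feature_time = d_model * data_size / bw_net
send_model_time = 2 * alphaH2 * data_size / bw_net
comp_time = 4 * alphaH2 / bw_mm

lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + send_model_time
shadow_the_busiest_expert = lat_new < lat_base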
3 | """ 4 | from .zero_gate import ZeroGate 5 | from .naive_gate import NaiveGate 6 | from .noisy_gate import NoisyGate 7 | 8 | from .gshard_gate import GShardGate 9 | from .switch_gate import SwitchGate 10 | from .dc_gate import DCGate 11 | 12 | from .swipe_gate import SwipeGate 13 | -------------------------------------------------------------------------------- /fmoe/gates/base_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Base gate with standard interface 3 | """ 4 | import torch.nn as nn 5 | 6 | 7 | class BaseGate(nn.Module): 8 | def __init__(self, num_expert, world_size): 9 | super().__init__() 10 | self.world_size = world_size 11 | self.num_expert = num_expert 12 | self.tot_expert = world_size * num_expert 13 | self.loss = None 14 | 15 | def forward(self, x): 16 | raise NotImplementedError('Base gate cannot be directly used for fwd') 17 | 18 | def set_loss(self, loss): 19 | self.loss = loss 20 | 21 | def get_loss(self, clear=True): 22 | loss = self.loss 23 | if clear: 24 | self.loss = None 25 | return loss 26 | 27 | @property 28 | def has_loss(self): 29 | return self.loss is not None 30 | -------------------------------------------------------------------------------- /fmoe/gates/dc_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Distributed Capacity gate, extended from GShard gate. 3 | Instead of setting capacity based on local batch size and expert count, 4 | the global load of each experts are calculated, and then the experts make 5 | decisions of capacities on each worker. 6 | """ 7 | import math 8 | import torch 9 | import torch.nn.functional as F 10 | from .naive_gate import NaiveGate 11 | from .utils import limit_by_capacity 12 | 13 | 14 | class DCGate(NaiveGate): 15 | def __init__(self, d_model, num_expert, world_size, 16 | topk=2, capacity=(1.2, 2.4), random_routing=True, gate_bias=True): 17 | assert topk == 2, 'topk should be 2 in gshard' 18 | super().__init__(d_model, num_expert, world_size, top_k=2, gate_bias=gate_bias) 19 | self.capacity = capacity 20 | self.random_routing = random_routing 21 | 22 | def forward(self, x): 23 | naive_outs = super().forward(x, return_all_scores=True) 24 | topk_idx, topk_val, gate_score = naive_outs 25 | 26 | S = gate_score.shape[0] 27 | top1_idx = topk_idx.view((-1, self.top_k))[:, 0] 28 | c_e = torch.scatter_add( 29 | torch.zeros(self.tot_expert, device=top1_idx.device), 30 | 0, 31 | top1_idx, 32 | torch.ones_like(top1_idx, dtype=torch.float), 33 | ) / S 34 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 35 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 36 | self.set_loss(loss) 37 | 38 | cap_rate = self.capacity[0 if self.training else 1] 39 | capacity = math.ceil(cap_rate * x.shape[0]) 40 | _new_lec, _new_gec, topk_idx = limit_by_capacity( 41 | topk_idx, self.num_expert, self.world_size, capacity) 42 | 43 | if self.random_routing: 44 | rand_routing_prob = torch.rand(gate_score.size(0), device=x.device) 45 | mask = (2 * topk_val[:, 1] < rand_routing_prob) 46 | topk_idx[:, 1].masked_fill_(mask, -1) 47 | 48 | return topk_idx, topk_val 49 | -------------------------------------------------------------------------------- /fmoe/gates/faster_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The example topology-aware gate for two-layer tree-like topology, proposed by 3 | the PPoPP'22 paper, FasterMoE. 
Limited number of tokens are sent across the 4 | upper-level slow connection, and other ones are re-directed to experts in the 5 | local network. 6 | 7 | The number of GPUs to form such a local network is defined by an environment 8 | variable `FMOE_TOPO_GPUS_PER_NODE`, and it is by default `8`. 9 | 10 | The fraction of tokens that are allowed to be sent across nodes is defined by 11 | an environement variable `FMOE_TOPO_OUTGOING_FRACTION`, and it is by default 12 | `0.14`. Users are supposed to set the proper value in their own environemnt, 13 | guided by some performance model, to achieve maximum throughput. 14 | """ 15 | from .naive_gate import NaiveGate 16 | 17 | import os 18 | import sys 19 | import torch 20 | import torch.nn.functional as F 21 | from .utils import limit_by_capacity 22 | import fmoe_cuda 23 | from fmoe.functions import count_by_gate 24 | 25 | 26 | nw_per_node = 8 27 | try: 28 | nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE']) 29 | except Exception: 30 | pass 31 | 32 | 33 | class FasterGate(NaiveGate): 34 | def __init__(self, d_model, n_expert, world_size, node_rank, gate_bias=True): 35 | super().__init__(d_model, n_expert, world_size, top_k=2, gate_bias=gate_bias) 36 | self.ne_per_node = nw_per_node * n_expert 37 | self.ogn_ratio = .14 38 | try: 39 | self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION']) 40 | except Exception: 41 | pass 42 | self.node_rank = node_rank 43 | 44 | mask = [1] * world_size * n_expert 45 | for i in range(n_expert * world_size): 46 | if i // self.ne_per_node == self.node_rank: 47 | mask[i] = 0 48 | self.mask = torch.Tensor(mask).bool() 49 | self.policy_fn = None 50 | print('node rank {} mask {}'.format(node_rank, mask)) 51 | 52 | def forward(self, inp): 53 | if self.mask.device != inp.device: 54 | self.mask = self.mask.to(inp.device) 55 | 56 | gate_score = self.gate(inp) 57 | lim_mask = self.mask 58 | 59 | top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1) 60 | S = gate_score.shape[0] 61 | top_k = 2 62 | 63 | with torch.no_grad(): 64 | top1_idx = top2_idx.view((-1, top_k))[:, 0] 65 | top1_val = top2_val.view((-1, top_k))[:, 0] 66 | c_e = torch.scatter_add( 67 | torch.zeros(self.tot_expert, device=top1_idx.device), 68 | 0, 69 | top1_idx, 70 | torch.ones_like(top1_idx, dtype=torch.float), 71 | ) / S 72 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 73 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 74 | self.set_loss(loss) 75 | 76 | with torch.no_grad(): 77 | if self.policy_fn is None: 78 | stored_models = torch.zeros(self.num_expert * self.world_size, 79 | dtype=torch.bool) 80 | else: 81 | # TODO: Fix this after expert shadowing is ported 82 | _, lec, aec, gec, agec = count_by_gate(top2_idx, 83 | self.num_expert, self.world_size, require_pos=False) 84 | stored_models = self.policy_fn(aec, agec, 85 | self.num_expert, self.world_size, inp.shape[-1], True) 86 | lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device) 87 | 88 | ogn_mask = lim_mask[top1_idx] 89 | ogn_thres = int(inp.shape[0] * self.ogn_ratio) 90 | 91 | if ogn_mask.sum().item() < ogn_thres: 92 | topk_val, topk_idx = torch.topk(gate_score, k=self.top_k) 93 | topk_val = F.softmax(topk_val, dim=-1) 94 | return topk_idx, topk_val 95 | 96 | with torch.no_grad(): 97 | top1_val[~ogn_mask] = float('-inf') 98 | _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres) 99 | cand = gate_score.clone() 100 | cand[:, lim_mask] = float('-inf') 101 | _, topk_idx = torch.topk(cand, k=self.top_k) 102 | topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn] 
103 | 104 | idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2) 105 | topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k) 106 | 107 | topk_val = F.softmax(topk_val, dim=-1) 108 | 109 | return topk_idx, topk_val 110 | 111 | 112 | def gen_faster_gate(rank): 113 | def _gen(d_model, n_expert, world_size, top_k=2): 114 | assert top_k == 2 115 | return FasterGate(d_model, n_expert, world_size, rank // nw_per_node) 116 | return _gen 117 | -------------------------------------------------------------------------------- /fmoe/gates/gshard_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Balanced gate with GShard's policy (Google, 2020) 3 | """ 4 | import math 5 | import torch 6 | import torch.nn.functional as F 7 | from .naive_gate import NaiveGate 8 | from .utils import limit_by_capacity 9 | import fmoe_cuda as fmoe_native 10 | 11 | 12 | class GShardGate(NaiveGate): 13 | def __init__(self, d_model, num_expert, world_size, 14 | topk=2, capacity=(1.2, 2.4), random_routing=True, gate_bias=True): 15 | assert topk == 2, 'topk should be 2 in gshard' 16 | super().__init__(d_model, num_expert, world_size, top_k=2, gate_bias=gate_bias) 17 | self.capacity = capacity 18 | self.random_routing = random_routing 19 | 20 | def forward(self, x): 21 | naive_outs = super().forward(x, return_all_scores=True) 22 | topk_idx, topk_val, gate_score = naive_outs 23 | 24 | S = gate_score.shape[0] 25 | top1_idx = topk_idx.view((-1, self.top_k))[:, 0] 26 | c_e = torch.scatter_add( 27 | torch.zeros(self.tot_expert, device=top1_idx.device), 28 | 0, 29 | top1_idx, 30 | torch.ones_like(top1_idx, dtype=torch.float), 31 | ) / S 32 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 33 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 34 | self.set_loss(loss) 35 | 36 | cap_rate = self.capacity[0 if self.training else 1] 37 | capacity = math.ceil(cap_rate * x.shape[0]) 38 | capacity = capacity * self.top_k // (self.world_size * self.num_expert) 39 | capacity = torch.ones(self.num_expert * self.world_size, 40 | dtype=torch.int32, device=topk_idx.device) * capacity 41 | topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity, 42 | self.num_expert, self.world_size) 43 | 44 | if self.random_routing: 45 | rand_routing_prob = torch.rand(gate_score.size(0), device=x.device) 46 | mask = (2 * topk_val[:, 1] < rand_routing_prob) 47 | topk_idx[:, 1].masked_fill_(mask, -1) 48 | 49 | return topk_idx, topk_val 50 | -------------------------------------------------------------------------------- /fmoe/gates/naive_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Naive gate 3 | """ 4 | from .base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class NaiveGate(BaseGate): 12 | r""" 13 | A naive gate implementation that defines the standard behavior of the gate 14 | which determines which experts the tokens are going to. 15 | Both the indicies and the score, or confidence, are output to the parent 16 | module. 17 | The load-balance strategies are also designed to be implemented within the 18 | `Gate` module. 
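A minimal single-worker sketch of this gate interface, runnable on CPU because the naive gate is plain PyTorch; the sizes are arbitrary.

import torch
from fmoe.gates import NaiveGate

d_model, num_expert, world_size = 64, 4, 1
gate = NaiveGate(d_model, num_expert, world_size, top_k=2)

tokens = torch.randn(10, d_model)       # 10 token features
topk_idx, topk_score = gate(tokens)     # both of shape [10, 2]
aux_loss = gate.get_loss(clear=True)    # the naive gate sets a dummy zero loss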
19 | """ 20 | 21 | def __init__(self, d_model, num_expert, world_size, top_k=2, gate_bias=True): 22 | super().__init__(num_expert, world_size) 23 | self.gate = nn.Linear(d_model, self.tot_expert, bias = gate_bias) 24 | self.top_k = top_k 25 | 26 | def forward(self, inp, return_all_scores=False): 27 | r""" 28 | The naive implementation simply calculates the top-k of a linear layer's 29 | output. 30 | """ 31 | gate = self.gate(inp) 32 | gate_top_k_val, gate_top_k_idx = torch.topk( 33 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 34 | ) # [.. x top_k] 35 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 36 | 37 | # (BxL) x 1 x top_k 38 | gate_score = F.softmax(gate_top_k_val, dim=-1) 39 | 40 | # dummy loss 41 | self.set_loss(torch.zeros(1, requires_grad=True).to(inp.device)) 42 | 43 | if return_all_scores: 44 | return gate_top_k_idx, gate_score, gate 45 | return gate_top_k_idx, gate_score 46 | -------------------------------------------------------------------------------- /fmoe/gates/noisy_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Noisy gate for gshard and switch 3 | """ 4 | from .base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributions.normal import Normal 10 | import math 11 | 12 | 13 | class NoisyGate(BaseGate): 14 | def __init__(self, d_model, num_expert, world_size, top_k=2): 15 | super().__init__(num_expert, world_size) 16 | self.w_gate = nn.Parameter( 17 | torch.zeros(d_model, self.tot_expert), requires_grad=True 18 | ) 19 | self.w_noise = nn.Parameter( 20 | torch.zeros(d_model, self.tot_expert), requires_grad=True 21 | ) 22 | self.top_k = top_k 23 | self.softplus = nn.Softplus() 24 | self.softmax = nn.Softmax(1) 25 | 26 | self.noise_epsilon = 1e-2 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | # Approach is the same as in torch.nn.Linear 32 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 33 | 34 | torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5)) 35 | torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5)) 36 | 37 | 38 | def _gates_to_load(self, gates): 39 | """Compute the true load per expert, given the gates. 40 | The load is the number of examples for which the corresponding gate is >0. 41 | Args: 42 | gates: a `Tensor` of shape [batch_size, n] 43 | Returns: 44 | a float32 `Tensor` of shape [n] 45 | """ 46 | return (gates > 0).sum(0) 47 | 48 | def _prob_in_top_k( 49 | self, clean_values, noisy_values, noise_stddev, noisy_top_values 50 | ): 51 | """Helper function to NoisyTopKGating. 52 | Computes the probability that value is in top k, given different random noise. 53 | This gives us a way of backpropagating from a loss that balances the number 54 | of times each expert is in the top k experts per example. 55 | In the case of no noise, pass in None for noise_stddev, and the result will 56 | not be differentiable. 57 | Args: 58 | clean_values: a `Tensor` of shape [batch, n]. 59 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 60 | normally distributed noise with standard deviation noise_stddev. 61 | noise_stddev: a `Tensor` of shape [batch, n], or None 62 | noisy_top_values: a `Tensor` of shape [batch, m]. 63 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 64 | Returns: 65 | a `Tensor` of shape [batch, n]. 
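A compact restatement of how the differentiable load and the importance combine into the balance loss in forward() below; the per-expert numbers are made up.

import torch

def cv_squared(x, eps=1e-10):
    # squared coefficient of variation, as in the method below
    return x.float().var() / (x.float().mean() ** 2 + eps)

importance = torch.tensor([4.0, 3.0, 2.5, 0.5])  # per-expert sum of gate values (made up)
load = torch.tensor([5.0, 3.0, 2.0, 2.0])        # per-expert expected assignment count (made up)
loss = cv_squared(importance) + cv_squared(load)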
66 | """ 67 | 68 | batch = clean_values.size(0) 69 | m = noisy_top_values.size(1) 70 | top_values_flat = noisy_top_values.flatten() 71 | threshold_positions_if_in = ( 72 | torch.arange(batch, device=clean_values.device) * m + self.top_k 73 | ) 74 | threshold_if_in = torch.unsqueeze( 75 | torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 76 | ) 77 | is_in = torch.gt(noisy_values, threshold_if_in) 78 | threshold_positions_if_out = threshold_positions_if_in - 1 79 | threshold_if_out = torch.unsqueeze( 80 | torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 81 | ) 82 | # is each value currently in the top k. 83 | normal = Normal( 84 | torch.tensor([0.0], device=clean_values.device), 85 | torch.tensor([1.0], device=clean_values.device), 86 | ) 87 | prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev) 88 | prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev) 89 | prob = torch.where(is_in, prob_if_in, prob_if_out) 90 | return prob 91 | 92 | def cv_squared(self, x): 93 | """The squared coefficient of variation of a sample. 94 | Useful as a loss to encourage a positive distribution to be more uniform. 95 | Epsilons added for numerical stability. 96 | Returns 0 for an empty Tensor. 97 | Args: 98 | x: a `Tensor`. 99 | Returns: 100 | a `Scalar`. 101 | """ 102 | eps = 1e-10 103 | # if only num_expert = 1 104 | if x.shape[0] == 1: 105 | return torch.Tensor([0]) 106 | return x.float().var() / (x.float().mean() ** 2 + eps) 107 | 108 | def forward(self, inp): 109 | clean_logits = inp @ self.w_gate 110 | raw_noise_stddev = inp @ self.w_noise 111 | noise_stddev = ( 112 | self.softplus(raw_noise_stddev) + self.noise_epsilon 113 | ) * self.training 114 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 115 | logits = noisy_logits 116 | 117 | # calculate topk + 1 that will be needed for the noisy gates 118 | top_logits, top_indices = logits.topk( 119 | min(self.top_k + 1, self.tot_expert), dim=1 120 | ) 121 | top_k_logits = top_logits[:, : self.top_k] 122 | top_k_indices = top_indices[:, : self.top_k] 123 | top_k_gates = self.softmax(top_k_logits) 124 | 125 | zeros = torch.zeros_like(logits, requires_grad=True) 126 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 127 | 128 | if self.top_k < self.tot_expert: 129 | load = ( 130 | self._prob_in_top_k( 131 | clean_logits, noisy_logits, noise_stddev, top_logits 132 | ) 133 | ).sum(0) 134 | else: 135 | load = self._gates_to_load(gates) 136 | 137 | importance = gates.sum(0) 138 | loss = self.cv_squared(importance) + self.cv_squared(load) 139 | self.set_loss(loss) 140 | 141 | return ( 142 | top_k_indices.contiguous().view(-1), 143 | top_k_gates.contiguous().unsqueeze(1), 144 | ) 145 | -------------------------------------------------------------------------------- /fmoe/gates/swipe_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Balanced gate using SWIPE algorithm 3 | """ 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .naive_gate import NaiveGate 10 | 11 | from fmoe.functions import count_by_gate 12 | import fmoe_cuda as fmoe_native 13 | 14 | 15 | class SwipeGate(NaiveGate): 16 | def __init__(self, d_model, num_expert, world_size, top_k=2, gate_bias=True): 17 | super().__init__(d_model, num_expert, world_size, top_k, gate_bias) 18 | 19 | def swipe_once(self, idx, capacity, bias): 20 | with torch.no_grad(): 21 | idx_new, 
capacity = fmoe_native.swipe_once(idx, capacity, 22 | self.num_expert, self.world_size, bias) 23 | idx_new = idx_new.to(idx.device) 24 | return idx_new, capacity 25 | 26 | 27 | def forward(self, inp): 28 | score = self.gate(inp) 29 | orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1) 30 | 31 | if not self.training: 32 | topk_val = F.softmax(orig_score, dim=-1) 33 | return orig_idx, topk_val 34 | 35 | capacity = torch.scalar_tensor(inp.shape[0] * self.top_k, 36 | dtype=torch.long) 37 | 38 | topk_idxs = [] 39 | topk_vals = [] 40 | idx_x = torch.arange(inp.shape[0], device=inp.device) 41 | for k in range(self.top_k): 42 | idx, capacity = self.swipe_once(orig_idx[:, k], capacity, 43 | k % self.num_expert) 44 | topk_vals.append(score[idx_x, idx]) 45 | topk_idxs.append(idx) 46 | topk_idx = torch.stack(topk_idxs).transpose(0, 1) 47 | topk_val = torch.stack(topk_vals).transpose(0, 1) 48 | topk_val = F.softmax(topk_val, dim=-1) 49 | return topk_idx, topk_val 50 | -------------------------------------------------------------------------------- /fmoe/gates/switch_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Balanced gate with Switch Transformer's policy (Google, 2021) 3 | """ 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from .naive_gate import NaiveGate 9 | from .utils import limit_by_capacity 10 | 11 | 12 | class SwitchGate(NaiveGate): 13 | r""" 14 | A switch gate implementation 15 | """ 16 | 17 | def __init__(self, d_model, num_expert, world_size, topk=1, 18 | switch_eps=.1, capacity=(1.2, 2.4), gate_bias=True): 19 | assert topk == 1, 'topk should be 1 in switch' 20 | super().__init__(d_model, num_expert, world_size, top_k=1, gate_bias=gate_bias) 21 | self.switch_eps = switch_eps 22 | self.capacity = capacity 23 | 24 | def forward(self, inp): 25 | r""" 26 | The switch firstly conduct softmax and then calculates the top-1 27 | """ 28 | score = self.gate(inp) 29 | 30 | if self.training: 31 | # random uniform number from [1-eps, 1+eps] 32 | noise = torch.rand_like(score) 33 | noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps 34 | score += noise 35 | 36 | # fp32 softmax for numerical stability 37 | score = F.softmax(score.float(), dim=-1) 38 | 39 | top1_score, top1_idx = torch.topk( 40 | score, k=1, dim=-1, largest=True 41 | ) # [.. 
x top_k] 42 | top1_score = top1_score.to(dtype=inp.dtype) 43 | 44 | cap_rate = self.capacity[0 if self.training else 1] 45 | capacity = math.ceil(cap_rate * inp.shape[0] / self.num_expert) 46 | _new_lec, _new_gec, top1_idx = limit_by_capacity( 47 | top1_idx, self.num_expert, self.world_size, capacity) 48 | 49 | valid_idx = top1_idx[top1_idx > -1] 50 | fraction_expert = torch.scatter_add( 51 | torch.zeros(self.tot_expert, device=valid_idx.device), 52 | 0, 53 | valid_idx, 54 | torch.ones_like(valid_idx, dtype=torch.float), 55 | ) / valid_idx.numel() 56 | prob_expert = score.sum(dim=0) / valid_idx.numel() 57 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 58 | self.set_loss(loss) 59 | return top1_idx, top1_score 60 | -------------------------------------------------------------------------------- /fmoe/gates/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Utilities that may be used in the gates 3 | """ 4 | import torch 5 | from fmoe.functions import count_by_gate 6 | import fmoe_cuda as fmoe_native 7 | 8 | 9 | def limit_by_capacity(topk_idx, num_expert, world_size, capacity): 10 | with torch.no_grad(): 11 | capacity = torch.ones(num_expert, dtype=torch.int32, 12 | device=topk_idx.device) * capacity 13 | 14 | pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size, 15 | require_pos=False) 16 | new_gec = fmoe_native.limit_by_capacity(gec, capacity, 17 | num_expert, world_size) 18 | if world_size > 1: 19 | new_lec = fmoe_native.expert_exchange(new_gec, num_expert, 20 | world_size) 21 | else: 22 | new_lec = new_gec 23 | 24 | topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, 25 | new_lec.to(torch.int32), num_expert, world_size) 26 | return new_lec, new_gec, topk_idx 27 | -------------------------------------------------------------------------------- /fmoe/gates/zero_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Zero gate that direct all input to gate 0 3 | """ 4 | from .base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class ZeroGate(BaseGate): 12 | r""" 13 | Guide all input samples to gate 0. 14 | """ 15 | 16 | def __init__(self, _1, num_expert, world_size, top_k=2): 17 | super().__init__(num_expert, world_size) 18 | self.top_k = top_k 19 | 20 | def forward(self, inp): 21 | r""" 22 | All output to expert 1 23 | """ 24 | idx = torch.zeros( 25 | inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device 26 | ) 27 | gate_score = ( 28 | torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k 29 | ) 30 | return idx, gate_score.reshape(-1, 1, self.top_k) 31 | -------------------------------------------------------------------------------- /fmoe/linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | FMoE's parallel linear layer 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | import math 8 | 9 | import fmoe_cuda 10 | 11 | 12 | class MOELinear(Function): 13 | r""" 14 | Computes linear operators within one GPU on different experts simutaneously. 
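A usage sketch of the expert-parallel linear layer defined below; it needs the compiled fmoe_cuda extension and a CUDA device, so treat the shapes and device handling as illustrative.

import torch
from fmoe.linear import FMoELinear

num_expert, d_in, d_out = 4, 64, 128
layer = FMoELinear(num_expert, d_in, d_out).cuda()

# Rows must already be grouped by expert: 10 tokens for expert 0, 6 for expert 1, ...
fwd_expert_count = torch.tensor([10, 6, 0, 4], dtype=torch.int64)
inp = torch.randn(int(fwd_expert_count.sum()), d_in, device="cuda")
out = layer(inp, fwd_expert_count)   # shape [20, d_out]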
15 | """ 16 | 17 | @staticmethod 18 | def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None): 19 | global_output_buf = fmoe_cuda.linear_forward( 20 | global_input_buf, fwd_expert_count, weight, bias 21 | ) 22 | variables = (global_input_buf, fwd_expert_count, weight, bias) 23 | ctx.save_for_backward(*variables) 24 | return global_output_buf 25 | 26 | @staticmethod 27 | def backward(ctx, grad_out): 28 | (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors 29 | grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward( 30 | grad_out, input_buf, fwd_expert_count, weight, bias 31 | ) 32 | 33 | if not torch.is_tensor(bias): 34 | grad_bias = None 35 | 36 | return grad_inp_buf, None, grad_weight, grad_bias 37 | 38 | 39 | 40 | class FMoELinear(nn.Module): 41 | r""" 42 | A linear layer that contains multiple experts. 43 | As multiple experts can be placed on the same worker, the computation can be 44 | performed in parallel to increase the performance. 45 | The FMoELinear module provides such function. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_expert: int, 51 | in_feat: int, 52 | out_feat: int, 53 | bias: bool = True, 54 | rank: int = 0, 55 | ): 56 | super().__init__() 57 | self.num_expert = num_expert 58 | self.in_feat = in_feat 59 | self.out_feat = out_feat 60 | self.rank = rank 61 | self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat)) 62 | if bias: 63 | self.bias = nn.Parameter(torch.zeros(num_expert, out_feat)) 64 | else: 65 | self.register_parameter("bias", None) 66 | 67 | self.reset_parameters() 68 | 69 | def forward(self, inp, fwd_expert_count): 70 | r""" 71 | Call MOE function 72 | """ 73 | x = MOELinear.apply(inp.type_as(self.weight), fwd_expert_count, self.weight, self.bias) 74 | return x 75 | 76 | def extra_repr(self) -> str: 77 | return "num_expert={}, in_features={}, \ 78 | out_features={}, bias={}, rank={}".format( 79 | self.num_expert, 80 | self.in_feat, 81 | self.out_feat, 82 | self.bias is not None, 83 | self.rank, 84 | ) 85 | 86 | def reset_parameters(self): 87 | # Approach is the same as in torch.nn.Linear 88 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 89 | # bias is left to zero, similar as megatron 90 | 91 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 92 | 93 | -------------------------------------------------------------------------------- /fmoe/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | A set of modules to plugin into Megatron-LM with FastMoE 3 | """ 4 | from .utils import add_fmoe_args 5 | 6 | from .layers import MegatronMLP 7 | from .layers import fmoefy 8 | 9 | from .checkpoint import save_checkpoint 10 | from .checkpoint import load_checkpoint 11 | 12 | from .distributed import DistributedDataParallel 13 | 14 | from .balance import reset_gate_hook 15 | from .balance import get_balance_profile 16 | from .balance import generate_megatron_gate_hook 17 | from .balance import add_balance_log 18 | 19 | from .patch import patch_forward_step 20 | from .patch import patch_model_provider 21 | -------------------------------------------------------------------------------- /fmoe/megatron/balance.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Support for monitoring loss in Megatron 3 | """ 4 | import torch 5 | from fmoe.balance import reset_balance_profile 6 | from fmoe.balance import update_balance_profile 7 | from fmoe.utils import 
get_torch_default_comm 8 | 9 | 10 | balance_dict = {} 11 | num_layers = 0 12 | 13 | 14 | def reset_gate_hook(_num_layers=None): 15 | from megatron import get_args 16 | 17 | global balance_dict, num_layers 18 | if _num_layers is not None: 19 | num_layers = _num_layers 20 | reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy) 21 | 22 | 23 | def get_balance_profile(): 24 | global balance_dict 25 | return balance_dict 26 | 27 | 28 | def generate_megatron_gate_hook(layer_idx, num_expert_global): 29 | from megatron import get_args 30 | 31 | balance_strategy = get_args().balance_strategy 32 | 33 | def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context): 34 | global balance_dict 35 | update_balance_profile( 36 | balance_dict, 37 | gate_top_k_idx, 38 | gate_score_top_k, 39 | gate_context, 40 | layer_idx, 41 | num_expert_global, 42 | balance_strategy, 43 | ) 44 | 45 | return megatron_gate_hook 46 | 47 | 48 | def add_balance_log(model, writer, iteration): 49 | r""" 50 | Note that this function does not work with pipeline parallelism 51 | """ 52 | from megatron import is_last_rank 53 | 54 | while hasattr(model, 'module'): 55 | model = model.module 56 | 57 | losses = [l.mlp.gate.get_loss(clear=True) 58 | for l in model.language_model.transformer.layers 59 | if l.mlp.gate.has_loss] 60 | if len(losses) == 0: 61 | return 62 | balance_dict_tensor = torch.vstack(losses).detach() 63 | world_group = get_torch_default_comm() 64 | world_size = torch.distributed.get_world_size(group=world_group) 65 | torch.distributed.all_reduce(balance_dict_tensor, group=world_group) 66 | balance_dict_tensor /= world_size 67 | 68 | if writer and is_last_rank(): 69 | for idx, metric_name in enumerate(balance_dict): 70 | for layer_id, val in enumerate(balance_dict_tensor[idx]): 71 | writer.add_scalar( 72 | f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration 73 | ) 74 | writer.add_scalar( 75 | f"balance-{metric_name}/all", 76 | balance_dict_tensor[idx].mean().item(), 77 | iteration, 78 | ) 79 | -------------------------------------------------------------------------------- /fmoe/megatron/distributed.py: -------------------------------------------------------------------------------- 1 | r""" 2 | distributed support for Megatron 3 | """ 4 | import torch 5 | 6 | from fmoe.distributed import DistributedGroupedDataParallel 7 | 8 | 9 | _groups = None 10 | 11 | 12 | def _set_groups(**kwargs): 13 | global _groups 14 | _groups = kwargs 15 | 16 | 17 | def get_moe_group(): 18 | return _groups["moe_group"] 19 | 20 | 21 | def _init(): 22 | from megatron import get_args 23 | from megatron import mpu 24 | args = get_args() 25 | 26 | # Create a comm prependicular to the pipeline group as gate group 27 | stage_size = args.world_size // args.pipeline_model_parallel_size 28 | for i in range(0, args.world_size, stage_size): 29 | ranks = range(i, i + stage_size) 30 | group = torch.distributed.new_group(ranks) 31 | if args.rank in ranks: 32 | gate_group = group 33 | 34 | _set_groups( 35 | dp_group=mpu.get_data_parallel_group(), 36 | moe_group=mpu.get_data_parallel_group(), 37 | gate_group=gate_group) 38 | 39 | 40 | class DistributedDataParallel(DistributedGroupedDataParallel): 41 | r""" 42 | A wrapper that is used to replace the DDP module provided by Megatron, which 43 | is adapted to enable the sophiscated parallel and reduction strategies in 44 | Fast MoE. 
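A worked illustration of the gate-group construction in `_init` above: with 8 ranks and a pipeline-parallel size of 2, each gate group spans one pipeline stage, perpendicular to the pipeline dimension.

world_size, pipeline_model_parallel_size = 8, 2
stage_size = world_size // pipeline_model_parallel_size       # 4 ranks per stage
groups = [list(range(i, i + stage_size)) for i in range(0, world_size, stage_size)]
# groups == [[0, 1, 2, 3], [4, 5, 6, 7]]; the group containing this rank becomes gate_group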
45 | """ 46 | 47 | def __init__(self, module, accumulate_allreduce_grads_in_fp32=False, use_contiguous_buffers_in_ddp=False): 48 | assert not accumulate_allreduce_grads_in_fp32, "FastMoE does not support accumulate_allreduce_grads_in_fp32" 49 | assert not use_contiguous_buffers_in_ddp, "FastMoE does not support use_contiguous_buffers_in_ddp" 50 | 51 | if _groups is None: 52 | _init() 53 | super().__init__(module, **_groups) 54 | 55 | def set_input_tensor(self, *args, **kwargs): 56 | r""" 57 | Keep consistency with Megatron 58 | """ 59 | return self.module.set_input_tensor(*args, **kwargs) 60 | 61 | def state_dict(self, *args, **kwargs): 62 | r""" 63 | Keep consistency with Megatron 64 | """ 65 | return self.module.state_dict(*args, **kwargs) 66 | 67 | def state_dict_for_save_checkpoint(self, *args, **kwargs): 68 | r""" 69 | Keep consistency with Megatron 70 | """ 71 | return self.module.state_dict_for_save_checkpoint(*args, **kwargs) 72 | 73 | def load_state_dict(self, *args, **kwargs): 74 | r""" 75 | Keep consistency with Megatron 76 | """ 77 | return self.module.load_state_dict(*args, **kwargs) 78 | -------------------------------------------------------------------------------- /fmoe/megatron/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Utilities for Megatron 3 | """ 4 | 5 | import argparse 6 | 7 | def add_fmoe_args(parser): 8 | group = parser.add_argument_group(title="fastmoe") 9 | 10 | group.add_argument("--fmoefy", action="store_true") 11 | try: 12 | group.add_argument("--num-experts", type=int, default=None) 13 | except argparse.ArgumentError: 14 | group.add_argument("--fmoe-num-experts", type=int, default=None) 15 | group.add_argument("--top-k", type=int, default=2) 16 | group.add_argument("--balance-loss-weight", type=float, default=1) 17 | group.add_argument("--balance-strategy", type=str, default=None) 18 | group.add_argument("--hidden-hidden-size", type=int, default=None) 19 | 20 | return parser 21 | -------------------------------------------------------------------------------- /fmoe/transformer.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Adaptation to act as the MLP layer in a transformer, using an MoE MLP layer. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from .layers import FMoE 7 | from .linear import FMoELinear 8 | from .fastermoe.config import switch_from_env 9 | 10 | 11 | class _Expert(nn.Module): 12 | r""" 13 | An expert using 2 FMoELinear modules to speed up the computation of experts 14 | within one worker. 15 | """ 16 | 17 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): 18 | super().__init__() 19 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank) 20 | self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank) 21 | self.activation = activation 22 | 23 | def forward(self, inp, fwd_expert_count): 24 | r""" 25 | First expand the input to 4h (the hidden size is variable, but is called h4 26 | for convenience). Then apply the activation. Finally shrink back to h. 27 | """ 28 | x = self.htoh4(inp, fwd_expert_count) 29 | x = self.activation(x) 30 | x = self.h4toh(x, fwd_expert_count) 31 | return x 32 | 33 | 34 | class FMoETransformerMLP(FMoE): 35 | r""" 36 | A complete MoE MLP module in a Transformer block. 37 | * `activation` is the activation function to be used in the MLP of each expert. 38 | * `d_hidden` is the hidden dimension of the MLP layer. 
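A drop-in usage sketch on a single GPU, assuming the fmoe_cuda extension is compiled; the shapes are arbitrary and the constructor arguments follow the signature below.

import torch
from fmoe import FMoETransformerMLP

mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024,
                         activation=torch.nn.GELU()).cuda()
x = torch.randn(8, 32, 256, device="cuda")   # (batch, sequence, d_model)
y = mlp(x)                                   # same shape as the input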
39 | """ 40 | 41 | def __init__( 42 | self, 43 | num_expert=32, 44 | d_model=1024, 45 | d_hidden=4096, 46 | activation=torch.nn.GELU(), 47 | expert_dp_comm="none", 48 | expert_rank=0, 49 | **kwargs 50 | ): 51 | def one_expert(d_model): 52 | return _Expert(1, d_model, d_hidden, activation, rank=0) 53 | 54 | expert = one_expert 55 | super().__init__(num_expert=num_expert, d_model=d_model, expert=expert, **kwargs) 56 | self.mark_parallel_comm(expert_dp_comm) 57 | 58 | def forward(self, inp: torch.Tensor): 59 | r""" 60 | This forward function wraps the FMoE forward with a reshape: the input is 61 | flattened to (tokens, d_model) and the output is restored to the original shape. 62 | """ 63 | original_shape = inp.shape 64 | inp = inp.reshape(-1, self.d_model) 65 | output = super().forward(inp) 66 | return output.reshape(original_shape) 67 | -------------------------------------------------------------------------------- /fmoe/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Utils to play with PyTorch. 3 | """ 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | # pylint: disable=broad-except 9 | # pylint: disable=protected-access 10 | def get_torch_default_comm(): 11 | r""" 12 | The NCCL communicator is needed so that Fast MoE can perform customized 13 | communication operators in the C code. However, it is not a publicly 14 | available variable. Therefore, a hacking class of the `ProcessGroupNCCL` 15 | in Fast MoE's C code takes the `_default_pg` and tries to dig the 16 | communicator out from the object. As PyTorch's private interface varies from 17 | time to time, different hacking techniques are tried one-by-one to be 18 | compatible with various versions of PyTorch. 19 | """ 20 | try: 21 | comm = dist.distributed_c10d._get_default_group() 22 | return comm 23 | except Exception as _: 24 | pass 25 | try: 26 | comm = dist.distributed_c10d._default_pg 27 | if comm is not None: 28 | return comm 29 | except Exception as _: 30 | pass 31 | raise RuntimeError("Unsupported PyTorch version") 32 | 33 | 34 | def get_rank_0_in_comm(comm): 35 | world_size = dist.get_world_size(comm) 36 | x = torch.tensor([dist.get_rank()], dtype=torch.int64, device='cuda') 37 | ys = [torch.empty_like(x) for _ in range(world_size)] 38 | dist.all_gather(ys, x, group=comm) 39 | root_rank = ys[0].item() 40 | return root_rank 41 | 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | ninja 4 | dm-tree 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | import os 4 | import torch 5 | 6 | cxx_flags = [] 7 | ext_libs = [] 8 | 9 | authors = [ 10 | 'Jiaao He', 11 | 'Jiezhong Qiu', 12 | 'Aohan Zeng', 13 | 'Tiago Antunes', 14 | 'Jinjun Peng', 15 | 'Qin Li', 16 | 'Mingshu Zhai' 17 | ] 18 | 19 | is_rocm_pytorch = False 20 | if torch.__version__ >= '1.5': 21 | from torch.utils.cpp_extension import ROCM_HOME 22 | is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False 23 | 24 | if os.environ.get('USE_NCCL', '1') == '1': 25 | cxx_flags.append('-DFMOE_USE_NCCL') 26 | cxx_flags.append('-DUSE_C10D_NCCL') 27 | if is_rocm_pytorch: 28 | ext_libs.append('rccl') 29 | else: 30 | ext_libs.append('nccl') 31 | 32 | if
os.environ.get('MOE_DEBUG', '0') == '1': 33 | cxx_flags.append('-DMOE_DEBUG') 34 | 35 | if is_rocm_pytorch: 36 | define_macros=[('FMOE_USE_HIP', None)] 37 | else: 38 | define_macros=[] 39 | 40 | 41 | if __name__ == '__main__': 42 | setuptools.setup( 43 | name='fastmoe', 44 | version='1.1.0', 45 | description='An efficient Mixture-of-Experts system for PyTorch', 46 | author=', '.join(authors), 47 | author_email='hja20@mails.tsinghua.edu.cn', 48 | license='Apache-2', 49 | url='https://github.com/laekov/fastmoe', 50 | packages=['fmoe', 'fmoe.megatron', 'fmoe.gates', 'fmoe.fastermoe'], 51 | ext_modules=[ 52 | CUDAExtension( 53 | name='fmoe_cuda', 54 | sources=[ 55 | 'cuda/stream_manager.cpp', 56 | 'cuda/local_exchange.cu', 57 | 'cuda/balancing.cu', 58 | 'cuda/global_exchange.cpp', 59 | 'cuda/parallel_linear.cu', 60 | 'cuda/fmoe_cuda.cpp', 61 | 'cuda/fastermoe/smart_schedule.cpp', 62 | ], 63 | define_macros=define_macros, 64 | extra_compile_args={ 65 | 'cxx': cxx_flags, 66 | 'nvcc': cxx_flags 67 | }, 68 | libraries=ext_libs 69 | ) 70 | ], 71 | cmdclass={ 72 | 'build_ext': BuildExtension 73 | }) 74 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | FastMoE test 2 | === 3 | 4 | To run the unit tests, run `pytest` directly in this directory. 5 | 6 | `test.sh` is a wrapper script to execute single tests without pytest for 7 | debugging purposes. 8 | -------------------------------------------------------------------------------- /tests/benchmark_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fmoe import FMoETransformerMLP 4 | from fmoe.gates import NaiveGate 5 | from moe import BruteForceMoELinear 6 | import time 7 | import sys 8 | import os 9 | 10 | 11 | rank = None 12 | world_size = None 13 | dev_name_default = "cuda:0" 14 | 15 | 16 | class BruteForceMoE(nn.Module): 17 | def __init__( 18 | self, 19 | num_expert=32, 20 | d_model=1024, 21 | d_hidden=4096, 22 | world_size=1, 23 | mp_group=None, 24 | activation=torch.nn.functional.gelu, 25 | gate=NaiveGate, 26 | top_k=1, 27 | pre_lnorm=False, 28 | ): 29 | assert world_size == 1, "Distributed brute force is not supported" 30 | super().__init__() 31 | self.mlp = BruteForceMoELinear( 32 | activation, num_expert, d_model, d_hidden, 1, top_k 33 | ) 34 | self.top_k = top_k 35 | self.gate = gate(d_model, num_expert, world_size, top_k) 36 | self.pre_lnorm = pre_lnorm 37 | self.layer_norm = nn.LayerNorm(d_model) 38 | self.d_model = d_model 39 | 40 | def forward(self, inp): 41 | if self.pre_lnorm: 42 | inp = self.layer_norm(inp) 43 | gate_top_k_idx, gate_score = self.gate(inp) 44 | inp = inp.repeat_interleave(repeats=self.top_k, dim=0) 45 | x = self.mlp(inp, gate_top_k_idx, gate_score) 46 | if not self.pre_lnorm: 47 | x = self.layer_norm(x) 48 | return x 49 | 50 | 51 | def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k): 52 | torch.manual_seed(42 + rank) 53 | torch.cuda.manual_seed(42 + rank) 54 | if rank == 0: 55 | print( 56 | "Performance test of {} mm size {} {}x{} experts {}x{} topk {}".format( 57 | MOELayer.__name__, 58 | batch_size, 59 | in_feat, 60 | hidden_feat, 61 | world_size, 62 | num_expert, 63 | top_k, 64 | ) 65 | ) 66 | if world_size > 1: 67 | dev_name = "cuda" 68 | else: 69 | dev_name = dev_name_default 70 | 71 | inp = torch.rand(batch_size, in_feat).cuda(dev_name) 72 | inp.requires_grad = True 73 
| 74 | moe = MOELayer( 75 | num_expert=num_expert, 76 | d_model=in_feat, 77 | d_hidden=hidden_feat, 78 | world_size=world_size, 79 | top_k=top_k, 80 | ).cuda(dev_name) 81 | moe.train() 82 | 83 | # warm up 84 | for _ in range(4): 85 | _ = moe(inp) 86 | 87 | n_runs = 16 88 | tott = 0.0 89 | backt = 0.0 90 | maxt = 0.0 91 | sqtot = 0.0 92 | for i in range(n_runs): 93 | ts = time.time() 94 | o = moe(inp) 95 | te = time.time() 96 | 97 | loss = o.sum() 98 | 99 | bts = time.time() 100 | loss.backward() 101 | bte = time.time() 102 | 103 | tott += te - ts 104 | sqtot += (te - ts) ** 2 105 | maxt = max(maxt, te - ts) 106 | backt += bte - bts 107 | 108 | gflops = ( 109 | 2e-9 110 | * n_runs 111 | * ( 112 | in_feat * hidden_feat * batch_size * top_k * 2 113 | + batch_size * in_feat * num_expert 114 | ) 115 | / tott 116 | ) 117 | print( 118 | "Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs".format( 119 | tott * 1e3 / n_runs, 120 | maxt * 1e3, 121 | (sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 * top_k / n_runs, 122 | backt * 1e3 / n_runs, 123 | gflops, 124 | ) 125 | ) 126 | 127 | 128 | if __name__ == "__main__": 129 | if int(os.environ["WORLD_SIZE"]) > 1: 130 | torch.distributed.init_process_group(backend="nccl") 131 | rank = torch.distributed.get_rank() 132 | world_size = torch.distributed.get_world_size() 133 | else: 134 | rank = 0 135 | world_size = 1 136 | batch_size = int(os.environ.get("BATCH_SIZE", "4096")) 137 | d_model = int(os.environ.get("D_MODEL", "1024")) 138 | d_hidden = int(os.environ.get("D_HIDDEN", "4096")) 139 | num_expert = int(os.environ.get("NUM_EXPERT", "64")) 140 | top_k = int(os.environ.get("TOP_K", "2")) 141 | benchmark_mlp(FMoETransformerMLP, batch_size, d_model, d_hidden, num_expert, top_k) 142 | if world_size == 1: 143 | benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert, top_k) 144 | -------------------------------------------------------------------------------- /tests/moe.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch 4 | 5 | 6 | class BruteForceMoELinear(nn.Module): 7 | def __init__( 8 | self, 9 | activation, 10 | num_expert=32, 11 | d_model=1024, 12 | d_hidden=2048, 13 | world_size=1, 14 | top_k=2, 15 | ): 16 | super(BruteForceMoELinear, self).__init__() 17 | self.num_expert = num_expert 18 | self.d_model = d_model 19 | self.activation = activation 20 | self.weight_htoh4 = nn.Parameter( 21 | torch.Tensor(num_expert * world_size, d_hidden, d_model) 22 | ) 23 | self.bias_htoh4 = nn.Parameter(torch.Tensor(num_expert * world_size, d_hidden)) 24 | self.weight_h4toh = nn.Parameter( 25 | torch.Tensor(num_expert * world_size, d_model, d_hidden) 26 | ) 27 | self.bias_h4toh = nn.Parameter(torch.Tensor(num_expert * world_size, d_model)) 28 | self.top_k = top_k 29 | 30 | def forward(self, inp, gate_idx, gate_score): 31 | inp = inp.repeat_interleave(repeats=self.top_k, dim=0) 32 | gate_long = gate_idx.long().view(-1) 33 | batch_size = inp.size(0) 34 | o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device) 35 | for i in range(self.weight_htoh4.shape[0]): 36 | idx = gate_long == i 37 | x = inp[idx] 38 | x = x @ self.weight_htoh4[i].t() 39 | x = x + self.bias_htoh4[i] 40 | x = self.activation(x) 41 | x = x @ self.weight_h4toh[i].t() 42 | x = x + self.bias_h4toh[i] 43 | o[idx] = x 44 | gate_score = gate_score.unsqueeze(1) 45 | 46 | x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape( 47 | -1, self.d_model 
48 | ) 49 | return x 50 | 51 | 52 | class BruteForceMoE(nn.Module): 53 | def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2): 54 | super(BruteForceMoE, self).__init__() 55 | self.num_expert = num_expert 56 | self.d_model = d_model 57 | self.top_k = top_k 58 | if type(expert) is list: 59 | self.experts = [e(d_model) for e in expert] 60 | self.num_expert = num_expert = len(expert) 61 | else: 62 | self.experts = [expert(d_model) for _ in range(num_expert * world_size)] 63 | 64 | def forward(self, inp, gate_idx, gate_score): 65 | inp = inp.repeat_interleave(repeats=self.top_k, dim=0) 66 | gate_long = gate_idx.long().view(-1) 67 | batch_size = inp.size(0) 68 | x = inp.new_zeros((batch_size, self.d_model)) 69 | for i in range(batch_size): 70 | x[i] = self.experts[gate_long[i]](inp[i]) 71 | gate_score = gate_score.unsqueeze(1) 72 | x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape( 73 | -1, self.d_model 74 | ) 75 | return x 76 | 77 | 78 | class NaiveExpert(nn.Module): 79 | def __init__(self, d_model): 80 | super(NaiveExpert, self).__init__() 81 | self.linear = nn.Linear(d_model, d_model) 82 | 83 | def forward(self, x, fec=None): 84 | return self.linear(x) 85 | 86 | 87 | class LinearExpert(nn.Module): 88 | def __init__(self, d_model): 89 | super(LinearExpert, self).__init__() 90 | self.model = nn.Sequential( 91 | nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model), 92 | ) 93 | 94 | def forward(self, x, fec=None): 95 | return self.model(x) 96 | -------------------------------------------------------------------------------- /tests/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z $MASTER_ADDR ] 3 | then 4 | if [ -z $SLURM_JOB_ID ] 5 | then 6 | export MASTER_ADDR=localhost 7 | else 8 | export MASTER_ADDR=$(scontrol show JobId=$SLURM_JOB_ID | grep BatchHost | tr '=' ' ' | awk '{print $2}') 9 | fi 10 | fi 11 | if [ -z $MASTER_PORT ] 12 | then 13 | export MASTER_PORT=12215 14 | fi 15 | 16 | if [ ! -z $OMPI_COMM_WORLD_RANK ] 17 | then 18 | RANK=$OMPI_COMM_WORLD_RANK 19 | localrank=$OMPI_COMM_WORLD_LOCAL_RANK 20 | elif [ ! 
-z $SLURM_PROCID ] 21 | then 22 | export RANK=$SLURM_PROCID 23 | export WORLD_SIZE=$SLURM_NPROCS 24 | localrank=$SLURM_LOCALID 25 | else 26 | RANK=0 27 | localrank=0 28 | WORLD_SIZE=1 29 | fi 30 | 31 | export CUDA_VISIBLE_DEVICES=$localrank 32 | 33 | exec $@ 2>&1 | tee $RANK.log 34 | -------------------------------------------------------------------------------- /tests/test_comm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from fmoe.functions import ensure_comm 12 | 13 | from test_ddp import _ensure_initialized, _run_distributed 14 | 15 | 16 | @pytest.mark.parametrize("n", [1, 2]) 17 | def test_ensure(n): 18 | _run_distributed('_test_ensure', 19 | n, dict(), 20 | script=__file__ 21 | ) 22 | 23 | 24 | def _test_ensure(): 25 | _ensure_initialized() 26 | rank = torch.distributed.get_rank() 27 | x = torch.rand(10).cuda() 28 | ensure_comm(x, None) 29 | 30 | 31 | if __name__ == '__main__': 32 | if len(sys.argv) >= 3: 33 | args = json.loads(sys.argv[2]) 34 | locals()[sys.argv[1]](**args) 35 | else: 36 | _ensure_initialized() 37 | _test_ensure() 38 | -------------------------------------------------------------------------------- /tests/test_dp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | from fmoe.gates import NaiveGate 7 | from fmoe.layers import FMoE 8 | from fmoe.transformer import _Expert 9 | 10 | n_devices = int(os.environ.get("N_GPUS", "2")) 11 | 12 | 13 | class MyMoE(FMoE): 14 | def __init__(self, num_expert, d_model, d_hidden, top_k, activation): 15 | super().__init__( 16 | num_expert=num_expert, 17 | d_model=d_model, 18 | gate=NaiveGate, 19 | world_size=1, 20 | mp_group=None, 21 | top_k=top_k, 22 | ) 23 | self.experts = _Expert(num_expert, d_model, d_hidden, activation) 24 | 25 | 26 | @pytest.mark.parametrize("num_expert", [4, 8]) 27 | @pytest.mark.parametrize("top_k", [2, 3]) 28 | @pytest.mark.parametrize("batch_size", [4]) 29 | @pytest.mark.parametrize("d_model", [16]) 30 | @pytest.mark.parametrize("d_hidden", [32]) 31 | def test_fmoe_dp( 32 | num_expert, 33 | top_k, 34 | batch_size, 35 | d_model, 36 | d_hidden, 37 | activation=torch.nn.functional.gelu, 38 | ): 39 | torch.manual_seed(42) 40 | torch.cuda.manual_seed(42) 41 | 42 | moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda() 43 | moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices))) 44 | 45 | for i in range(5): 46 | output = moe_dp(torch.rand(batch_size, d_model).cuda()) 47 | 48 | 49 | if __name__ == "__main__": 50 | test_fmoe_dp(4, 2, 4, 16, 32) 51 | -------------------------------------------------------------------------------- /tests/test_faster_gate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from fmoe.gates.faster_gate import FasterGate 12 | from test_ddp import _ensure_initialized, _run_distributed 13 | 14 | 15 | @pytest.mark.parametrize("n_process", [8]) 16 | @pytest.mark.parametrize("d_model", [1024]) 17 | @pytest.mark.parametrize("batch_size", [16]) 18 | @pytest.mark.parametrize("n_expert", [1, 4]) 19 | @pytest.mark.parametrize("gpu_per_node", [2, 
4, 8]) 20 | @pytest.mark.parametrize("frac", [.2]) 21 | def test_faster_gate(n_process, d_model, batch_size, n_expert, gpu_per_node, frac): 22 | _run_distributed('_test_faster_gate', 23 | n_process, 24 | { 25 | 'd_model': d_model, 26 | 'batch_size': batch_size, 27 | 'n_expert': n_expert, 28 | 'gpu_per_node': gpu_per_node, 29 | 'frac': frac 30 | }, 31 | script=__file__, 32 | env=dict( 33 | FMOE_TOPO_GPUS_PER_NODE=str(gpu_per_node), 34 | FMOE_TOPO_OUTGOING_FRACTION=str(frac) 35 | ) 36 | ) 37 | 38 | 39 | def _test_faster_gate(d_model, batch_size, n_expert, gpu_per_node, frac): 40 | _ensure_initialized() 41 | rank = dist.get_rank() 42 | node_rank = rank // gpu_per_node 43 | 44 | gate = FasterGate(d_model, n_expert, dist.get_world_size(), node_rank).cuda() 45 | x = torch.rand(batch_size, d_model).cuda() 46 | topk_idx, topk_val = gate(x) 47 | 48 | cnto = 0 49 | idxs = topk_idx[:, 0].cpu().view(-1).numpy() 50 | for v in idxs: 51 | assert(v != -1) 52 | if v // n_expert // gpu_per_node != rank // gpu_per_node: 53 | cnto += 1 54 | assert(cnto <= math.ceil(batch_size * frac)) 55 | 56 | 57 | if __name__ == '__main__': 58 | if len(sys.argv) >= 3: 59 | args = json.loads(sys.argv[2]) 60 | locals()[sys.argv[1]](**args) 61 | else: 62 | test_faster_gate(8, 1024, 16, 1, 2, .2) 63 | -------------------------------------------------------------------------------- /tests/test_faster_schedule.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from fmoe.functions import ensure_comm 12 | from test_ddp import _ensure_initialized, _run_distributed 13 | from test_numerical import _assert_numerical 14 | from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd 15 | from fmoe.layers import _fmoe_general_global_forward as naive_fwd 16 | 17 | 18 | @pytest.mark.parametrize("n_process", [8]) 19 | @pytest.mark.parametrize("d_model", [1024]) 20 | @pytest.mark.parametrize("batch_size", [16]) 21 | @pytest.mark.parametrize("n_expert", [1, 4]) 22 | @pytest.mark.parametrize("group_sz", [1, 2, 4]) 23 | def test_faster_schedule(n_process, d_model, batch_size, n_expert, group_sz): 24 | _run_distributed('_test_faster_schedule', 25 | n_process, 26 | { 27 | 'd_model': d_model, 28 | 'batch_size': batch_size, 29 | 'n_expert': n_expert 30 | }, 31 | script=__file__, 32 | env=dict( 33 | FMOE_FASTER_GROUP_SIZE=str(group_sz) 34 | ) 35 | ) 36 | 37 | 38 | def _test_faster_schedule(d_model, batch_size, n_expert): 39 | _ensure_initialized() 40 | rank = dist.get_rank() 41 | world_size = dist.get_world_size() 42 | 43 | x1 = torch.rand(batch_size, d_model).cuda() 44 | x1.requires_grad = True 45 | x2 = x1.data.clone() 46 | x2.requires_grad = True 47 | topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda() 48 | m1s = [torch.nn.Linear(d_model, d_model).cuda() for _ in range(n_expert)] 49 | m2s = [torch.nn.Linear(d_model, d_model).cuda() for _ in range(n_expert)] 50 | with torch.no_grad(): 51 | for m1, m2 in zip(m1s, m2s): 52 | m2.weight.copy_(m1.weight) 53 | m2.bias.copy_(m1.bias) 54 | 55 | def ef1(x, fec, eidx): 56 | return m1s[eidx](x) 57 | 58 | def ef2(x, fec): 59 | o = 0 60 | ys = [] 61 | for m, i in zip(m2s, fec): 62 | if i > 0: 63 | ys.append(m(x[o:o + i])) 64 | o += i 65 | y = torch.cat(ys) 66 | return y 67 | 68 | ensure_comm(x1, None) 69 | y1 = smart_fwd(x1, topk_idx, ef1, n_expert, 
world_size, experts=m1s) 70 | y1.sum().backward() 71 | 72 | y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2s) 73 | y2.sum().backward() 74 | _assert_numerical(['out', 'grad_in'], 75 | [y1, x1.grad], 76 | [y2, x2.grad], rank) 77 | for i in range(n_expert): 78 | _assert_numerical([f'grad_bias_{i}', f'grad_weight_{i}'], 79 | [m1s[i].bias.grad, m1s[i].weight.grad], 80 | [m2s[i].bias.grad, m2s[i].weight.grad], rank) 81 | 82 | 83 | if __name__ == '__main__': 84 | if len(sys.argv) >= 3: 85 | args = json.loads(sys.argv[2]) 86 | locals()[sys.argv[1]](**args) 87 | else: 88 | # test_faster_schedule(8, 16, 16, 1, 2) 89 | _test_faster_schedule(4, 2, 4) 90 | -------------------------------------------------------------------------------- /tests/test_faster_shadow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from fmoe.functions import ensure_comm 12 | from test_ddp import _ensure_initialized, _run_distributed 13 | from test_numerical import _assert_numerical 14 | from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd 15 | from fmoe.layers import _fmoe_general_global_forward as naive_fwd 16 | 17 | 18 | @pytest.mark.parametrize("n_process", [8]) 19 | @pytest.mark.parametrize("d_model", [1024]) 20 | @pytest.mark.parametrize("batch_size", [16, 512]) 21 | @pytest.mark.parametrize("n_expert", [1]) 22 | @pytest.mark.parametrize("group_sz", [1, 2, 4]) 23 | @pytest.mark.parametrize("pass_stored", [True, False]) 24 | def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz, pass_stored): 25 | _run_distributed('_test_faster_shadow', 26 | n_process, 27 | { 28 | 'd_model': d_model, 29 | 'batch_size': batch_size, 30 | 'n_expert': n_expert, 31 | 'pass_stored': pass_stored 32 | }, 33 | script=__file__, 34 | env=dict( 35 | FMOE_FASTER_GROUP_SIZE=str(group_sz), 36 | FMOE_FASTER_SHADOW_ENABLE='ON' 37 | ) 38 | ) 39 | 40 | 41 | def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored): 42 | _ensure_initialized() 43 | rank = dist.get_rank() 44 | world_size = dist.get_world_size() 45 | 46 | x1 = torch.rand(batch_size, d_model).cuda() 47 | x1.requires_grad = True 48 | x2 = x1.data.clone() 49 | x2.requires_grad = True 50 | topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda() 51 | m1 = torch.nn.Linear(d_model, d_model).cuda() 52 | m2 = torch.nn.Linear(d_model, d_model).cuda() 53 | with torch.no_grad(): 54 | m2.weight.copy_(m1.weight) 55 | m2.bias.copy_(m1.bias) 56 | 57 | def ef1(x, fec, eidx): 58 | y = m1(x) 59 | return y 60 | def ef2(x, fec): 61 | y = m2(x) 62 | return y 63 | 64 | if pass_stored: 65 | stored_models = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda() 66 | while stored_models.sum().item() == 0: 67 | stored_models = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda() 68 | stored_models[-1] = True 69 | dist.broadcast(stored_models, 0) 70 | stored_models = stored_models.cpu() 71 | print(stored_models) 72 | 73 | ensure_comm(x1, None) 74 | if pass_stored: 75 | y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=[m1], 76 | stored_models=stored_models) 77 | else: 78 | y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=[m1]) 79 | y1.sum().backward() 80 | 81 | y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=[m2]) 82 | y2.sum().backward() 83 | 
_assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'], 84 | [y1, x1.grad, m1.bias.grad, m1.weight.grad], 85 | [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank) 86 | 87 | 88 | if __name__ == '__main__': 89 | if len(sys.argv) >= 3: 90 | args = json.loads(sys.argv[2]) 91 | locals()[sys.argv[1]](**args) 92 | else: 93 | # test_faster_shadow(8, 16, 16, 1, 2) 94 | _test_faster_shadow(1024, 16, 1, True) 95 | -------------------------------------------------------------------------------- /tests/test_gates.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from fmoe.gates import GShardGate, SwitchGate 12 | from fmoe.functions import ensure_comm 13 | from test_ddp import _ensure_initialized, _run_distributed 14 | 15 | 16 | @pytest.mark.parametrize("d_model", [1024]) 17 | @pytest.mark.parametrize("batch_size", [16]) 18 | @pytest.mark.parametrize("n_expert", [1, 4]) 19 | @pytest.mark.parametrize("cap", [.1, 1.1]) 20 | def test_gshard_gate(d_model, batch_size, n_expert, cap): 21 | if 1 * n_expert < 2: 22 | pytest.skip("No enough experts") 23 | _run_distributed('_test_gshard_gate', 24 | 1, 25 | { 26 | 'd_model': d_model, 27 | 'batch_size': batch_size, 28 | 'n_expert': n_expert, 29 | 'cap': cap 30 | }, 31 | script=__file__ 32 | ) 33 | 34 | 35 | def _test_gshard_gate(d_model, batch_size, n_expert, cap): 36 | _ensure_initialized() 37 | rank = torch.distributed.get_rank() 38 | gate = GShardGate(d_model, n_expert, dist.get_world_size(), 39 | capacity=(cap, cap)).cuda() 40 | x = torch.rand(batch_size, d_model).cuda() 41 | ensure_comm(x, None) 42 | topk_idx, topk_val = gate(x) 43 | counts = [0 for _ in range(n_expert * dist.get_world_size())] 44 | for v in topk_idx.cpu().view(-1).numpy(): 45 | if v != -1: 46 | counts[v] += 1 47 | real_cap = math.ceil(cap * batch_size) 48 | for i in counts: 49 | assert(i <= real_cap) 50 | 51 | gate_score = gate.gate(x) 52 | gate_top_k_val, gate_top_k_idx = torch.topk( 53 | gate_score, k=gate.top_k, dim=-1, largest=True, sorted=False 54 | ) 55 | gate_top_k_val = gate_top_k_val.view(-1, gate.top_k) 56 | gate_score = F.softmax(gate_top_k_val, dim=-1) 57 | 58 | for i in range(batch_size): 59 | for j in range(gate.top_k): 60 | v = topk_idx[i, j] 61 | if v != -1: 62 | assert topk_val[i, j] == gate_score[i, j] 63 | 64 | 65 | @pytest.mark.parametrize("d_model", [1024]) 66 | @pytest.mark.parametrize("batch_size", [4096]) 67 | @pytest.mark.parametrize("n_expert", [1, 16]) 68 | @pytest.mark.parametrize("cap", [.1, .8]) 69 | def test_switch_gate(d_model, batch_size, n_expert, cap): 70 | _run_distributed('_test_switch_gate', 71 | 1, 72 | { 73 | 'd_model': d_model, 74 | 'batch_size': batch_size, 75 | 'n_expert': n_expert, 76 | 'cap': cap 77 | }, 78 | script=__file__ 79 | ) 80 | 81 | 82 | def _test_switch_gate(d_model, batch_size, n_expert, cap): 83 | _ensure_initialized() 84 | gate = SwitchGate(d_model, n_expert, dist.get_world_size(), 85 | capacity=(cap, cap)).cuda() 86 | x = torch.rand(batch_size, d_model).cuda() 87 | rng = torch.cuda.get_rng_state() # save rng state 88 | topk_idx, topk_val = gate(x) 89 | counts = [0 for _ in range(n_expert * dist.get_world_size())] 90 | for v in topk_idx.cpu().view(-1).numpy(): 91 | if v != -1: 92 | counts[v] += 1 93 | real_cap = math.ceil(cap * batch_size) 94 | for i in counts: 95 | assert(i <= real_cap) 96 | 97 | score = 
gate.gate(x) 98 | 99 | if gate.training: 100 | # reset rng state to make sure noise is the same as in gate.forward() 101 | torch.cuda.set_rng_state(rng) 102 | # random uniform number from [1-eps, 1+eps] 103 | noise = torch.rand_like(score) 104 | noise = noise * 2 * gate.switch_eps + 1.0 - gate.switch_eps 105 | score += noise 106 | 107 | # fp32 softmax for numerical stability 108 | score = F.softmax(score.float(), dim=-1) 109 | 110 | for i in range(batch_size): 111 | v = topk_idx[i] 112 | if v != -1: 113 | assert topk_val[i] == score[i, topk_idx[i]] 114 | 115 | 116 | if __name__ == '__main__': 117 | if len(sys.argv) >= 3: 118 | args = json.loads(sys.argv[2]) 119 | locals()[sys.argv[1]](**args) 120 | else: 121 | # _ensure_initialized() 122 | # test_gshard_gate(4096, 1024, 4, .2) 123 | _test_gshard_gate(8, 16, 4, .1) 124 | # test_switch_gate(4096, 1024, 4, .2) 125 | -------------------------------------------------------------------------------- /tests/test_local_exchange.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | from typing import List, Type, Union 4 | 5 | import pytest 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | from copy import deepcopy 11 | from fmoe.functions import MOEGather, MOEScatter, count_by_gate 12 | 13 | from test_numerical import _assert_numerical 14 | 15 | @pytest.mark.parametrize("n_expert", [1, 4, 8]) 16 | @pytest.mark.parametrize("topk", [1, 2]) 17 | @pytest.mark.parametrize("batch_size", [12]) 18 | @pytest.mark.parametrize("d_model", [6]) 19 | @pytest.mark.parametrize("world_size", [1]) 20 | def test_scatter(n_expert, topk, batch_size, d_model, world_size): 21 | gate_idx = torch.randint(n_expert + 1, (batch_size, topk)) - 1 22 | gate_idx = gate_idx.long().cuda() 23 | pos, lec, gec = count_by_gate(gate_idx, n_expert, world_size) 24 | fbs = int(gec.sum().item()) 25 | inp = torch.rand(batch_size, d_model).cuda() 26 | inp.requires_grad = True 27 | out = MOEScatter.apply(inp, pos % batch_size, lec, gec, fbs, world_size) 28 | out.sum().backward() 29 | 30 | inp_raw = inp.data.clone() 31 | out_raw = torch.empty(pos.shape[0], d_model, 32 | device=inp.device, dtype=inp.dtype) 33 | # out_raw.sum().backward() 34 | for i, f in enumerate(pos.cpu()): 35 | out_raw[i] = inp[f % batch_size] 36 | _assert_numerical(['out'], [out], [out_raw], 0) 37 | # TODO: check grad 38 | 39 | if __name__ == '__main__': 40 | test_scatter(4, 2, 8, 6, 1) 41 | -------------------------------------------------------------------------------- /tests/test_mimo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from fmoe.gates import NaiveGate 9 | from fmoe.layers import FMoE 10 | from fmoe.linear import FMoELinear 11 | from fmoe.megatron.layers import _megatron_init_method 12 | 13 | 14 | def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3): 15 | for name, mo, ro in zip(names, moe_out_list, raw_out_list): 16 | err = (mo - ro).abs().max() 17 | print("Rank {} {} abs err {}".format(rank, name, err)) 18 | if err > precision: 19 | sys.stderr.write(f"=========== {name} moe out ==============\n") 20 | sys.stderr.write("{}\n".format(mo)) 21 | sys.stderr.write(f"=========== {name} raw out ==============\n") 22 | sys.stderr.write("{}\n".format(ro)) 23 | sys.stderr.write(f"=========== {name} diff ==============\n") 24 | 
sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err)) 25 | assert False 26 | 27 | 28 | class MyExpert(nn.Module): 29 | r""" 30 | An expert using 2 FMoELinear modules to speed up the computation of experts 31 | within one worker. 32 | """ 33 | 34 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): 35 | super().__init__() 36 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank) 37 | self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank) 38 | self.activation = activation 39 | 40 | def forward(self, inp, fwd_expert_count): 41 | r""" 42 | First expand input to 4h (the hidden size is variable, but is called h4 43 | for convenience). Then perform activation. Finally shirink back to h. 44 | """ 45 | if type(inp) == dict: 46 | x = inp["x"] 47 | y = inp["y"] 48 | elif type(inp) == list: 49 | x = inp[0] 50 | y = inp[1] 51 | else: 52 | raise NotImplementedError 53 | x = self.htoh4(x, fwd_expert_count) 54 | x = self.activation(x) 55 | x = self.h4toh(x, fwd_expert_count) 56 | y = self.htoh4(y, fwd_expert_count) 57 | y = self.activation(y) 58 | y = self.h4toh(y, fwd_expert_count) 59 | if type(inp) == dict: 60 | ret = {"x": x, "y": y} 61 | elif type(inp) == list: 62 | ret = [x, y] 63 | 64 | return ret 65 | 66 | 67 | class MyGate(NaiveGate): 68 | def __init__(self, d_model, num_expert, world_size, top_k=2, gate_bias=True): 69 | super().__init__(d_model, num_expert, world_size, top_k, gate_bias=gate_bias) 70 | 71 | def forward(self, inp, return_all_scores=False): 72 | if type(inp) == dict: 73 | x = inp["x"] 74 | elif type(inp) == list: 75 | x = inp[0] 76 | else: 77 | raise NotImplementedError 78 | return super().forward(x, return_all_scores) 79 | 80 | 81 | class MyMoE(FMoE): 82 | def __init__( 83 | self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation 84 | ): 85 | super().__init__( 86 | num_expert=num_expert, 87 | d_model=d_model, 88 | gate=MyGate, 89 | world_size=world_size, 90 | mp_group=mp_group, 91 | top_k=top_k, 92 | ) 93 | self.experts = MyExpert(num_expert, d_model, d_hidden, activation) 94 | 95 | rng = np.random.default_rng(1234) 96 | _megatron_init_method(self.experts.htoh4, rng, 1.0) 97 | _megatron_init_method(self.experts.h4toh, rng, 1.0) 98 | 99 | 100 | @pytest.mark.parametrize("num_expert", [4, 8]) 101 | @pytest.mark.parametrize("top_k", [2, 3]) 102 | @pytest.mark.parametrize("batch_size", [4]) 103 | @pytest.mark.parametrize("d_model", [16]) 104 | @pytest.mark.parametrize("d_hidden", [32]) 105 | @pytest.mark.parametrize("rank", [0]) 106 | @pytest.mark.parametrize("world_size", [1]) 107 | @pytest.mark.parametrize("mp_group", [None]) 108 | @pytest.mark.parametrize("dp_group", [None]) 109 | @pytest.mark.parametrize("world_group", [None]) 110 | @pytest.mark.parametrize( 111 | "data_type", [torch.float32, torch.float16, torch.bfloat16, torch.double] 112 | ) 113 | @pytest.mark.parametrize("list_input", [False, True]) 114 | def test_fmoe_mimo_linear( 115 | num_expert, 116 | top_k, 117 | batch_size, 118 | d_model, 119 | d_hidden, 120 | rank, 121 | world_size, 122 | mp_group, 123 | dp_group, 124 | world_group, 125 | data_type, 126 | list_input, 127 | activation=torch.nn.functional.gelu, 128 | ): 129 | 130 | torch.manual_seed(42 + rank) 131 | torch.cuda.manual_seed(42 + rank) 132 | 133 | moe = MyMoE( 134 | num_expert=num_expert, 135 | d_model=d_model, 136 | d_hidden=4 * d_model, 137 | world_size=world_size, 138 | mp_group=mp_group, 139 | top_k=top_k, 140 | activation=activation, 141 | ).cuda().to(data_type) 142 
| 143 | x = torch.rand(batch_size, d_model).cuda().to(data_type) 144 | inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()} 145 | moe_out = moe(inp) 146 | 147 | if list_input: 148 | _assert_numerical(["x"], [moe_out[0]], [moe_out[1]], rank) 149 | else: 150 | _assert_numerical(["x"], [moe_out["x"]], [moe_out["y"]], rank) 151 | 152 | 153 | if __name__ == "__main__": 154 | test_fmoe_mimo_linear( 155 | batch_size=2, 156 | num_expert=2, 157 | d_model=2, 158 | top_k=2, 159 | d_hidden=16, 160 | rank=0, 161 | world_size=1, 162 | mp_group=None, 163 | dp_group=None, 164 | world_group=None, 165 | data_type=torch.bfloat16, 166 | list_input=True 167 | ) 168 | -------------------------------------------------------------------------------- /tests/test_swipe.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from fmoe.functions import ensure_comm 12 | from fmoe.gates.swipe_gate import SwipeGate 13 | from test_ddp import _ensure_initialized, _run_distributed 14 | 15 | 16 | 17 | @pytest.mark.parametrize("d_model", [1024]) 18 | @pytest.mark.parametrize("batch_size", [16]) 19 | @pytest.mark.parametrize("n_expert", [1, 4]) 20 | @pytest.mark.parametrize("top_k", [2, 4]) 21 | @pytest.mark.parametrize("world_size", [2, 4, 8]) 22 | def test_swipe_gate(world_size, d_model, batch_size, n_expert, top_k): 23 | if world_size * n_expert < 2: 24 | pytest.skip("No enough experts") 25 | _run_distributed('_test_swipe_gate', 26 | world_size, 27 | { 28 | 'd_model': d_model, 29 | 'batch_size': batch_size, 30 | 'n_expert': n_expert, 31 | 'top_k': top_k 32 | }, 33 | script=__file__ 34 | ) 35 | 36 | 37 | def _test_swipe_gate(d_model, batch_size, n_expert, top_k): 38 | _ensure_initialized() 39 | gate = SwipeGate(d_model, n_expert, dist.get_world_size()).cuda() 40 | x = torch.rand(batch_size, d_model).cuda() 41 | ensure_comm(x, None) 42 | topk_idx, topk_val = gate(x) 43 | 44 | 45 | @pytest.mark.parametrize("batch_size", [16]) 46 | @pytest.mark.parametrize("n_expert", [1, 4]) 47 | @pytest.mark.parametrize("world_size", [2, 4, 8]) 48 | def test_swipe_once(world_size, batch_size, n_expert): 49 | if world_size * n_expert < 2: 50 | pytest.skip("No enough experts") 51 | _run_distributed('_test_swipe_once', 52 | world_size, 53 | { 54 | 'batch_size': batch_size, 55 | 'n_expert': n_expert 56 | }, 57 | script=__file__ 58 | ) 59 | 60 | 61 | def _test_swipe_once(batch_size, n_expert): 62 | _ensure_initialized() 63 | rank = dist.get_rank() 64 | world_size = dist.get_world_size() 65 | gate = SwipeGate(4, n_expert, dist.get_world_size()).cuda() 66 | idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda() 67 | capacity = torch.scalar_tensor(batch_size * 2, dtype=torch.long) 68 | ensure_comm(idx, None) 69 | new_idx, new_cap = gate.swipe_once(idx, capacity, 0) 70 | idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda() 71 | new_idx, new_cap = gate.swipe_once(idx, new_cap, 0) 72 | 73 | if __name__ == '__main__': 74 | if len(sys.argv) >= 3: 75 | args = json.loads(sys.argv[2]) 76 | locals()[sys.argv[1]](**args) 77 | else: 78 | test_swipe_gate(8, 4, 8, 4, 2) 79 | # test_swipe_once(8, 800, 4) 80 | -------------------------------------------------------------------------------- /tests/test_zero.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
sys 3 | import json 4 | import torch 5 | from fmoe.layers import _fmoe_general_global_forward 6 | from fmoe import FMoETransformerMLP 7 | 8 | from test_ddp import _run_distributed 9 | 10 | 11 | class ConstantGate(torch.nn.Module): 12 | def __init__(self, d_model, num_expert, world_size, top_k=1): 13 | super().__init__() 14 | self.top_k = top_k 15 | 16 | def forward(self, inp): 17 | idx = torch.zeros((inp.shape[0], self.top_k), dtype=torch.int64, 18 | device=inp.device) 19 | score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2 20 | return idx, score 21 | 22 | 23 | def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1): 24 | _run_distributed('_test_zero_fwd', 25 | 1, 26 | { 27 | 'num_expert': num_expert, 28 | 'batch_size': batch_size, 29 | 'd_hidden': d_hidden 30 | }, 31 | script=__file__ 32 | ) 33 | 34 | def _test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1): 35 | inp = torch.rand(batch_size, d_hidden).cuda() 36 | gate = torch.zeros(batch_size, dtype=torch.int64).cuda() 37 | x = _fmoe_general_global_forward(inp, gate, lambda x, y: x, num_expert, 38 | world_size) 39 | 40 | 41 | def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1): 42 | _run_distributed('_test_zero_transformer', 43 | 1, 44 | { 45 | 'num_expert': num_expert, 46 | 'batch_size': batch_size, 47 | 'd_hidden': d_hidden 48 | }, 49 | script=__file__ 50 | ) 51 | 52 | def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1): 53 | inp = torch.rand(batch_size, d_hidden).cuda() 54 | mask = torch.zeros(inp.shape[0], dtype=torch.long) 55 | mask[1] = 1 56 | mask_dict = { 57 | 1: torch.zeros(d_hidden).cuda() 58 | } 59 | model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, 60 | world_size=world_size, gate=ConstantGate, mask=mask, 61 | mask_dict=mask_dict).cuda() 62 | oup = model(inp) 63 | 64 | 65 | if __name__ == '__main__': 66 | if len(sys.argv) >= 3: 67 | args = json.loads(sys.argv[2]) 68 | os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0") 69 | os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1") 70 | os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"] 71 | torch.distributed.init_process_group(backend="nccl") 72 | args['world_size'] = torch.distributed.get_world_size() 73 | locals()[sys.argv[1]](**args) 74 | else: 75 | # test_zero_fwd(world_size=torch.distributed.get_world_size()) 76 | test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024, 77 | world_size=1) 78 | print('done') 79 | --------------------------------------------------------------------------------