├── fastermoe ├── __init__.py ├── config.py ├── expert_utils.py ├── shadow_policy.py └── schedule.py ├── fastmoe.egg-info ├── dependency_links.txt ├── top_level.txt ├── SOURCES.txt └── PKG-INFO ├── requirements.txt ├── .gitignore ├── gates ├── __init__.py ├── base_gate.py ├── zero_gate.py ├── utils.py ├── naive_gate.py ├── swipe_gate.py ├── gshard_gate.py ├── switch_gate.py ├── faster_gate.py └── noisy_gate.py ├── scripts ├── smoe-s.sh ├── glam-m.sh ├── smoe-m.sh ├── smoe-l.sh ├── smoe-mom-s.sh ├── glam-adam-m.sh ├── glam-mom-m.sh ├── smoe-adam-m.sh ├── smoe-mom-m.sh └── smoe-mom-l.sh ├── custom_utils.py ├── .github └── workflows │ └── typecheck.yaml ├── README.md ├── setup.py ├── custom_transformer.py ├── data.py ├── trainer.py ├── custom_functions.py ├── custom_gates.py ├── finetune_trainer.py ├── utils.py ├── finetune_train.py ├── train.py ├── config.py ├── LICENSE.txt ├── custom_layers.py ├── vocabulary.py ├── finetune_data.py └── custom_layers_opt.py /fastermoe/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys -------------------------------------------------------------------------------- /fastmoe.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fastmoe.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | fmoe_cuda 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # torch==2.2.0 2 | # numpy 3 | ninja 4 | dm-tree 5 | tqdm 6 | torchmetrics==1.3.1 7 | -------------------------------------------------------------------------------- /fastmoe.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | fastmoe.egg-info/PKG-INFO 4 | fastmoe.egg-info/SOURCES.txt 5 | fastmoe.egg-info/dependency_links.txt 6 | fastmoe.egg-info/top_level.txt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | 3 | # Compiled and optimized Python code and libraries 4 | __pycache__/ 5 | *.py[cod] 6 | 7 | # Temporary editor files 8 | *~ 9 | [._]*.sw[nop] 10 | 11 | # Temporary macOS files 12 | .DS_Store 13 | -------------------------------------------------------------------------------- /gates/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from .zero_gate import ZeroGate 3 | from .naive_gate import NaiveGate 4 | from .noisy_gate import NoisyGate 5 | 6 | from .gshard_gate import GShardGate 7 | from .switch_gate import SwitchGate 8 | 9 | from .swipe_gate import SwipeGate -------------------------------------------------------------------------------- /fastmoe.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: fastmoe 3 | Version: 1.1.0 4 | Summary: An efficient Mixture-of-Experts system for PyTorch 5 | Author: Jiaao He, Jiezhong Qiu, Aohan Zeng, Tiago Antunes, Jinjun Peng, Qin Li, Mingshu Zhai 6 | Author-email: hja20@mails.tsinghua.edu.cn 7 | License: Apache-2 8 | -------------------------------------------------------------------------------- 
/fastermoe/config.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | def float_from_env(key, default=-1): 4 | if key in os.environ: 5 | return float(os.environ[key]) 6 | return default 7 | 8 | def switch_from_env(key, default=False): 9 | if key in os.environ: 10 | return os.environ[key] in ['1', 'ON'] 11 | return default -------------------------------------------------------------------------------- /gates/base_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch.nn as nn 3 | 4 | class BaseGate(nn.Module): 5 | def __init__(self, num_expert, world_size): 6 | super().__init__() 7 | self.world_size = world_size 8 | self.num_expert = num_expert 9 | self.tot_expert = world_size * num_expert 10 | self.loss = None 11 | 12 | def forward(self, x): 13 | raise NotImplementedError('Base gate cannot be directly used for fwd') 14 | 15 | def set_loss(self, loss): 16 | self.loss = loss 17 | 18 | def get_loss(self, clear=True): 19 | loss = self.loss 20 | if clear: 21 | self.loss = None 22 | return loss 23 | 24 | @property 25 | def has_loss(self): 26 | return self.loss is not None -------------------------------------------------------------------------------- /gates/zero_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from .base_gate import BaseGate 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class ZeroGate(BaseGate): 9 | r""" 10 | Guide all input samples to gate 0. 11 | """ 12 | 13 | def __init__(self, _1, num_expert, world_size, top_k=2): 14 | super().__init__(num_expert, world_size) 15 | self.top_k = top_k 16 | 17 | def forward(self, inp): 18 | r""" 19 | All output to expert 1 20 | """ 21 | idx = torch.zeros( 22 | inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device 23 | ) 24 | gate_score = ( 25 | torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k 26 | ) 27 | return idx, gate_score.reshape(-1, 1, self.top_k) -------------------------------------------------------------------------------- /scripts/smoe-s.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture sgsgsg \ 7 | --gate_name smoe \ 8 | --nlayers 3 \ 9 | --hid-sz 128 \ 10 | --inner-hid-sz 128 \ 11 | --nheads 8 \ 12 | --block-sz 256 \ 13 | --attn-span 256 \ 14 | --dropout 0.7 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 3000 \ 19 | --niter 60 \ 20 | --batch-sz 96 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 25 | " 26 | 27 | echo "Training ..." 28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 29 | 30 | echo "Evaluation ..." 
31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode -------------------------------------------------------------------------------- /gates/utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | from fmoe.functions import count_by_gate 4 | import fmoe_cuda as fmoe_native 5 | 6 | def limit_by_capacity(topk_idx, num_expert, world_size, capacity): 7 | with torch.no_grad(): 8 | capacity = torch.ones(num_expert, dtype=torch.int32, 9 | device=topk_idx.device) * capacity 10 | 11 | pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size, 12 | require_pos=False) 13 | new_gec = fmoe_native.limit_by_capacity(gec, capacity, 14 | num_expert, world_size) 15 | if world_size > 1: 16 | new_lec = fmoe_native.expert_exchange(new_gec, num_expert, 17 | world_size) 18 | else: 19 | new_lec = new_gec 20 | 21 | topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, 22 | new_lec.to(torch.int32), num_expert, world_size) 23 | return new_lec, new_gec, topk_idx -------------------------------------------------------------------------------- /scripts/glam-m.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch glam \ 6 | --architecture sgsfsgsfsgsfsgsfsgsfsgsf \ 7 | --gate_name smoe \ 8 | --nlayers 6 \ 9 | --hid-sz 352 \ 10 | --inner-hid-sz 352 \ 11 | --nheads 8 \ 12 | --block-sz 512 \ 13 | --attn-span 2048 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.00007 \ 18 | --lr-warmup 4000 \ 19 | --niter 120 \ 20 | --batch-sz 48 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 25 | " 26 | 27 | echo "Training ..." 28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 29 | 30 | echo "Evaluation ..." 31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode -------------------------------------------------------------------------------- /scripts/smoe-m.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture sgsgsgsgsgsg \ 7 | --gate_name smoe \ 8 | --nlayers 6 \ 9 | --hid-sz 352 \ 10 | --inner-hid-sz 352 \ 11 | --nheads 8 \ 12 | --block-sz 512 \ 13 | --attn-span 1024 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 4000 \ 19 | --niter 80 \ 20 | --batch-sz 48 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 25 | " 26 | 27 | echo "Training ..." 28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 29 | 30 | echo "Evaluation ..." 
31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode 32 | -------------------------------------------------------------------------------- /scripts/smoe-l.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture sgsgsgsgsgsgsgsgsgsgsgsg \ 7 | --gate_name smoe \ 8 | --nlayers 12 \ 9 | --hid-sz 512 \ 10 | --inner-hid-sz 512 \ 11 | --nheads 8 \ 12 | --block-sz 1024 \ 13 | --attn-span 2048 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 5000 \ 19 | --niter 80 \ 20 | --batch-sz 24 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 25 | " 26 | 27 | echo "Training ..." 28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 29 | 30 | echo "Evaluation ..." 31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode 32 | -------------------------------------------------------------------------------- /scripts/smoe-mom-s.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture smsmsm \ 7 | --gate_name smoe \ 8 | --nlayers 3 \ 9 | --hid-sz 128 \ 10 | --inner-hid-sz 128 \ 11 | --nheads 8 \ 12 | --block-sz 256 \ 13 | --attn-span 256 \ 14 | --dropout 0.7 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 3000 \ 19 | --niter 60 \ 20 | --batch-sz 96 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --gamma1 0.8 \ 25 | --gamma2 1.0 \ 26 | --mu 0.7 \ 27 | --beta1 0.9 \ 28 | --beta2 0.999 \ 29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 30 | " 31 | 32 | echo "Training ..." 33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 34 | 35 | echo "Evaluation ..." 36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode -------------------------------------------------------------------------------- /scripts/glam-adam-m.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch glam \ 6 | --architecture sasfsasfsasfsasfsasfsasf \ 7 | --gate_name smoe \ 8 | --nlayers 6 \ 9 | --hid-sz 352 \ 10 | --inner-hid-sz 352 \ 11 | --nheads 8 \ 12 | --block-sz 512 \ 13 | --attn-span 2048 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.00007 \ 18 | --lr-warmup 4000 \ 19 | --niter 120 \ 20 | --batch-sz 48 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --gamma1 0.8 \ 25 | --gamma2 1.0 \ 26 | --mu 0.7 \ 27 | --beta1 0.9 \ 28 | --beta2 0.999 \ 29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 30 | " 31 | 32 | echo "Training ..." 
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 34 | 35 | echo "Evaluation ..." 36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode -------------------------------------------------------------------------------- /scripts/glam-mom-m.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch glam \ 6 | --architecture smsfsmsfsmsfsmsfsmsfsmsf \ 7 | --gate_name smoe \ 8 | --nlayers 6 \ 9 | --hid-sz 352 \ 10 | --inner-hid-sz 352 \ 11 | --nheads 8 \ 12 | --block-sz 512 \ 13 | --attn-span 2048 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.00007 \ 18 | --lr-warmup 4000 \ 19 | --niter 120 \ 20 | --batch-sz 48 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --gamma1 0.8 \ 25 | --gamma2 1.0 \ 26 | --mu 0.7 \ 27 | --beta1 0.9 \ 28 | --beta2 0.999 \ 29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 30 | " 31 | 32 | echo "Training ..." 33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 34 | 35 | echo "Evaluation ..." 36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode -------------------------------------------------------------------------------- /scripts/smoe-adam-m.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture sasasasasasa \ 7 | --gate_name smoe \ 8 | --nlayers 6 \ 9 | --hid-sz 352 \ 10 | --inner-hid-sz 352 \ 11 | --nheads 8 \ 12 | --block-sz 512 \ 13 | --attn-span 1024 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 4000 \ 19 | --niter 80 \ 20 | --batch-sz 48 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --gamma1 1.0 \ 25 | --gamma2 1.0 \ 26 | --mu 0.7 \ 27 | --beta1 0.9 \ 28 | --beta2 0.999 \ 29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 30 | " 31 | 32 | echo "Training ..." 33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 34 | 35 | echo "Evaluation ..." 
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode 37 | -------------------------------------------------------------------------------- /scripts/smoe-mom-m.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture smsmsmsmsmsm \ 7 | --gate_name smoe \ 8 | --nlayers 6 \ 9 | --hid-sz 352 \ 10 | --inner-hid-sz 352 \ 11 | --nheads 8 \ 12 | --block-sz 512 \ 13 | --attn-span 1024 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 4000 \ 19 | --niter 80 \ 20 | --batch-sz 48 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --gamma1 1.0 \ 25 | --gamma2 1.0 \ 26 | --mu 0.7 \ 27 | --beta1 0.9 \ 28 | --beta2 0.999 \ 29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 30 | " 31 | 32 | echo "Training ..." 33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 34 | 35 | echo "Evaluation ..." 36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode 37 | -------------------------------------------------------------------------------- /scripts/smoe-mom-l.sh: -------------------------------------------------------------------------------- 1 | mkdir -p /path/to/checkpoint/directory/ 2 | 3 | args=" 4 | --data /path/to/data/directory/wikitext-103/ \ 5 | --base_arch transformer \ 6 | --architecture smsmsmsmsmsmsmsmsmsmsmsm \ 7 | --gate_name smoe \ 8 | --nlayers 12 \ 9 | --hid-sz 512 \ 10 | --inner-hid-sz 512 \ 11 | --nheads 8 \ 12 | --block-sz 1024 \ 13 | --attn-span 2048 \ 14 | --dropout 0.1 \ 15 | --load_balance 0.01 \ 16 | --optim adam \ 17 | --lr 0.0007 \ 18 | --lr-warmup 5000 \ 19 | --niter 80 \ 20 | --batch-sz 24 \ 21 | --batch-split 2 \ 22 | --nbatches 1000 \ 23 | --distributed \ 24 | --gamma1 0.8 \ 25 | --gamma2 1.0 \ 26 | --mu 0.7 \ 27 | --beta1 0.9 \ 28 | --beta2 0.999 \ 29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \ 30 | " 31 | 32 | echo "Training ..." 33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args 34 | 35 | echo "Evaluation ..." 36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode 37 | -------------------------------------------------------------------------------- /custom_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | # pylint: disable=broad-except 9 | # pylint: disable=protected-access 10 | def get_torch_default_comm(): 11 | r""" 12 | The NCCL communicator is needed so that Fast MoE can perform customized 13 | communication operators in the C code. However, it is not a publicly 14 | available variable. Therefore, a hacking class of the `ProcessGroupNCCL` 15 | in Fast MoE's C code takes the `_default_pg` and tries to dig the 16 | communicator out from the object. 
As PyTorch's private interface varies from 17 | time to time, different hacking techniques are tried one-by-one to be 18 | compatible with various versions of PyTorch. 19 | """ 20 | try: 21 | comm = dist.distributed_c10d._get_default_group() 22 | return comm 23 | except Exception as _: 24 | pass 25 | try: 26 | comm = dist.distributed_c10d._default_pg 27 | if comm is not None: 28 | return comm 29 | except Exception as _: 30 | pass 31 | raise RuntimeError("Unsupported PyTorch version") 32 | -------------------------------------------------------------------------------- /gates/naive_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from .base_gate import BaseGate 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class NaiveGate(BaseGate): 9 | r""" 10 | A naive gate implementation that defines the standard behavior of the gate 11 | which determines which experts the tokens are going to. 12 | Both the indicies and the score, or confidence, are output to the parent 13 | module. 14 | The load-balance strategies are also designed to be implemented within the 15 | `Gate` module. 16 | """ 17 | 18 | def __init__(self, d_model, num_expert, world_size, top_k=2): 19 | super().__init__(num_expert, world_size) 20 | self.gate = nn.Linear(d_model, self.tot_expert) 21 | self.top_k = top_k 22 | 23 | def forward(self, inp, return_all_scores=False): 24 | r""" 25 | The naive implementation simply calculates the top-k of a linear layer's 26 | output. 27 | """ 28 | gate = self.gate(inp) 29 | gate_top_k_val, gate_top_k_idx = torch.topk( 30 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 31 | ) # [.. x top_k] 32 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 33 | 34 | # (BxL) x 1 x top_k 35 | gate_score = F.softmax(gate_top_k_val, dim=-1) 36 | 37 | if return_all_scores: 38 | return gate_top_k_idx, gate_score, gate 39 | return gate_top_k_idx, gate_score -------------------------------------------------------------------------------- /gates/swipe_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .naive_gate import NaiveGate 8 | 9 | from fmoe.functions import count_by_gate 10 | import fmoe_cuda as fmoe_native 11 | 12 | class SwipeGate(NaiveGate): 13 | def __init__(self, d_model, num_expert, world_size, top_k=2): 14 | super().__init__(d_model, num_expert, world_size, top_k) 15 | 16 | def swipe_once(self, idx, capacity, bias): 17 | with torch.no_grad(): 18 | idx_new, capacity = fmoe_native.swipe_once(idx, capacity, 19 | self.num_expert, self.world_size, bias) 20 | idx_new = idx_new.to(idx.device) 21 | return idx_new, capacity 22 | 23 | def forward(self, inp): 24 | score = self.gate(inp) 25 | orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1) 26 | 27 | if not self.training: 28 | topk_val = F.softmax(orig_score, dim=-1) 29 | return orig_idx, topk_val 30 | 31 | capacity = torch.scalar_tensor(inp.shape[0] * self.top_k, 32 | dtype=torch.long) 33 | 34 | topk_idxs = [] 35 | topk_vals = [] 36 | idx_x = torch.arange(inp.shape[0], device=inp.device) 37 | for k in range(self.top_k): 38 | idx, capacity = self.swipe_once(orig_idx[:, k], capacity, 39 | k % self.num_expert) 40 | topk_vals.append(score[idx_x, idx]) 41 | topk_idxs.append(idx) 42 | topk_idx = torch.stack(topk_idxs).transpose(0, 1) 43 | topk_val 
= torch.stack(topk_vals).transpose(0, 1) 44 | topk_val = F.softmax(topk_val, dim=-1) 45 | return topk_idx, topk_val -------------------------------------------------------------------------------- /fastermoe/expert_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | 4 | def get_expert_param_size(e): 5 | return sum(map(lambda x: x.numel(), e.parameters())) 6 | 7 | 8 | def get_expert_params(e, out): 9 | offset = 0 10 | for n, p in e.named_parameters(): 11 | seg = out[offset:offset + p.numel()] 12 | offset += p.numel() 13 | seg.copy_(p.data.flatten()) 14 | 15 | def stash_expert_params(e, params): 16 | if not hasattr(e, 'expert_param_stash'): 17 | setattr(e, 'expert_param_stash', dict()) 18 | offset = 0 19 | for n, p in e.named_parameters(): 20 | if n not in e.expert_param_stash: 21 | e.expert_param_stash[n] = p.data.clone() 22 | with torch.no_grad(): 23 | seg = params[offset:offset + p.numel()] 24 | offset += p.numel() 25 | p.copy_(seg.reshape(p.shape)) 26 | 27 | def pop_expert_params(e): 28 | if not hasattr(e, 'expert_param_stash'): 29 | return 30 | for n, p in e.named_parameters(): 31 | with torch.no_grad(): 32 | p.copy_(e.expert_param_stash[n]) 33 | e.expert_param_stash.clear() 34 | 35 | def collect_expert_grads(e, grads): 36 | offset = 0 37 | for _, p in e.named_parameters(): 38 | seg = grads[offset:offset + p.numel()] 39 | offset += p.numel() 40 | if p.grad is not None: 41 | seg.copy_(p.grad.flatten()) 42 | p.grad = None 43 | else: 44 | seg.zero_() 45 | 46 | def set_grads(e, grads): 47 | offset = 0 48 | for n, p in e.named_parameters(): 49 | seg = grads[offset:offset + p.numel()] 50 | offset += p.numel() 51 | if p.grad is None: 52 | p.grad = seg.clone() 53 | else: 54 | p.grad += seg.reshape(p.shape) -------------------------------------------------------------------------------- /gates/gshard_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from .naive_gate import NaiveGate 6 | from .utils import limit_by_capacity 7 | 8 | class GShardGate(NaiveGate): 9 | def __init__(self, d_model, num_expert, world_size, 10 | topk=2, capacity=(1.2, 2.4), random_routing=True): 11 | assert topk == 2, 'topk should be 2 in gshard' 12 | super().__init__(d_model, num_expert, world_size, top_k=2) 13 | self.capacity = capacity 14 | self.random_routing = random_routing 15 | 16 | def forward(self, x): 17 | naive_outs = super().forward(x, return_all_scores=True) 18 | topk_idx, topk_val, gate_score = naive_outs 19 | 20 | S = gate_score.shape[0] 21 | top_k = topk_idx.shape[0] // gate_score.shape[0] 22 | top1_idx = topk_idx.view((-1, top_k))[:, 0] 23 | c_e = torch.scatter_add( 24 | torch.zeros(self.tot_expert, device=top1_idx.device), 25 | 0, 26 | top1_idx, 27 | torch.ones_like(top1_idx, dtype=torch.float), 28 | ) / S 29 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 30 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 31 | self.set_loss(loss) 32 | 33 | cap_rate = self.capacity[0 if self.training else 1] 34 | capacity = math.ceil(cap_rate * x.shape[0]) 35 | _new_lec, _new_gec, topk_idx = limit_by_capacity( 36 | topk_idx, self.num_expert, self.world_size, capacity) 37 | 38 | if self.random_routing: 39 | rand_routing_prob = torch.rand(gate_score.size(0), device=x.device) 40 | mask = (2 * topk_val[:, 1] < rand_routing_prob) 41 | topk_idx[:, 1].masked_fill_(mask, -1) 42 | 43 | return topk_idx, 
topk_val -------------------------------------------------------------------------------- /.github/workflows/typecheck.yaml: -------------------------------------------------------------------------------- 1 | name: Typecheck 2 | 3 | # These checks will run if at least one file is outside of the `paths-ignore` 4 | # list, but will be skipped if *all* files are in the `paths-ignore` list. 5 | # 6 | # Fore more info, see: 7 | # https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#example-excluding-paths 8 | 9 | on: 10 | push: 11 | branches: 12 | - 'main' 13 | paths-ignore: 14 | - '**.md' 15 | 16 | pull_request: 17 | branches: 18 | - 'main' 19 | paths-ignore: 20 | - '**.md' 21 | 22 | jobs: 23 | test: 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | os: [ 'ubuntu-24.04' ] 28 | python: [ '3.10' ] 29 | 30 | runs-on: ${{ matrix.os }} 31 | name: Python ${{ matrix.python }} on ${{ matrix.os }} 32 | 33 | steps: 34 | - name: Checkout the repo 35 | uses: actions/checkout@v4 36 | 37 | - name: Setup Python 38 | uses: actions/setup-python@v5 39 | with: 40 | python-version: ${{ matrix.python }} 41 | cache: 'pip' 42 | 43 | - name: Update pip 44 | run: python -m pip install --upgrade pip 45 | 46 | - name: Install Python deps 47 | run: python -m pip install -r requirements.txt 48 | 49 | - name: Install Mypy 50 | run: python -m pip install --upgrade mypy 51 | 52 | - name: Check types with Mypy 53 | run: python -m mypy --python-version=${{ matrix.python }} . 54 | # TODO: fix the type checking errors and remove this line to make errors 55 | # obvious by failing the test. 56 | continue-on-error: true 57 | 58 | - name: Install PyType 59 | run: python -m pip install --upgrade pytype 60 | 61 | - name: Check types with PyType 62 | run: python -m pytype --python-version=${{ matrix.python }} -k . 63 | # TODO: fix the type checking errors and remove this line to make errors 64 | # obvious by failing the test. 65 | continue-on-error: true 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MomentumSMoE: Integrating Momentum into Sparse Mixture of Experts 2 | MomentumSMoE: Integrating Momentum into Sparse Mixture of Experts 3 | 4 | https://arxiv.org/abs/2410.14574 5 | 6 | ### Prerequisites 7 | 8 | - pytorch 9 | - fastmoe: https://github.com/laekov/fastmoe 10 | - The toolkit supports [Weights & Biases](https://docs.wandb.ai/) for monitoring jobs. If you use it, also install `wandb`. 11 | 12 | ### Usage 13 | 14 | 15 | #### Prepare WikiText-103 Datasets: 16 | 17 | - Download the WikiText-103 dataset from [here](https://github.com/laekov/fastmoe/blob/master/examples/transformer-xl/scripts/getdata.sh), then change bash scripts based on your local data paths. 
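For reference, the snippet below is one way to fetch the corpus with the linked `getdata.sh` helper. It is only a sketch: the clone location, the working directory, and the exact output path are assumptions, so adjust them to your environment.
```bash
# Sketch: download WikiText-103 via fastmoe's transformer-xl helper script
git clone https://github.com/laekov/fastmoe.git
cd fastmoe/examples/transformer-xl
bash scripts/getdata.sh   # typically leaves the corpus under ./data/wikitext-103/
```
Once downloaded, the data directory referenced by the scripts is expected to be laid out as: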
18 | ```bash 19 | data_directory/ 20 | └── wikitext-103 21 | ├── test.txt 22 | ├── train.txt 23 | └── valid.txt 24 | ``` 25 | 26 | #### Pretraining SMoE (SwitchTransformers) on WikiText-103: 27 | 28 | ``` # WikiText-103 dataset: 29 | bash scripts/smoe-s.sh 30 | bash scripts/smoe-m.sh 31 | bash scripts/smoe-l.sh 32 | ``` 33 | 34 | #### Pretraining *Momentum*SMoE on WikiText-103: 35 | 36 | ``` # WikiText-103 dataset: 37 | bash scripts/smoe-mom-s.sh 38 | bash scripts/smoe-mom-m.sh 39 | bash scripts/smoe-mom-l.sh 40 | ``` 41 | 42 | #### Pretraining *Adam*SMoE on WikiText-103: 43 | 44 | ``` # WikiText-103 dataset: 45 | bash scripts/smoe-adam-m.sh 46 | ``` 47 | 48 | #### Pretraining GLaM on WikiText-103: 49 | 50 | ``` # WikiText-103 dataset: 51 | bash scripts/glam-m.sh 52 | ``` 53 | 54 | #### Pretraining *Momentum*GLaM on WikiText-103: 55 | 56 | ``` # WikiText-103 dataset: 57 | bash scripts/glam-mom-m.sh 58 | ``` 59 | 60 | #### Pretraining *Adam*GLaM on WikiText-103: 61 | 62 | ``` # WikiText-103 dataset: 63 | bash scripts/glam-adam-m.sh 64 | ``` 65 | 66 | #### Wandb support: 67 | - Add these flags to bash script with your project and job name 68 | ``` # Wandb: 69 | --wandb 70 | --project-name test 71 | --job-name test 72 | ``` 73 | 74 | ### Checkpoints 75 | Please find our SMoE checkpoints at . 76 | 77 | 78 | -------------------------------------------------------------------------------- /gates/switch_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .naive_gate import NaiveGate 7 | from .utils import limit_by_capacity 8 | 9 | class SwitchGate(NaiveGate): 10 | r""" 11 | A switch gate implementation 12 | """ 13 | 14 | def __init__(self, d_model, num_expert, world_size, topk=1, 15 | switch_eps=.1, capacity=(1.2, 2.4)): 16 | assert topk == 1, 'topk should be 1 in switch' 17 | super().__init__(d_model, num_expert, world_size, top_k=1) 18 | self.switch_eps = switch_eps 19 | self.capacity = capacity 20 | 21 | def forward(self, inp): 22 | r""" 23 | The switch firstly conduct softmax and then calculates the top-1 24 | """ 25 | score = self.gate(inp) 26 | 27 | if self.training: 28 | # random uniform number from [1-eps, 1+eps] 29 | noise = torch.rand_like(score) 30 | noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps 31 | score += noise 32 | 33 | # fp32 softmax for numerical stability 34 | score = F.softmax(score.float(), dim=-1) 35 | 36 | top1_score, top1_idx = torch.topk( 37 | score, k=1, dim=-1, largest=True 38 | ) # [.. 
x top_k] 39 | top1_score = top1_score.to(dtype=inp.dtype) 40 | 41 | cap_rate = self.capacity[0 if self.training else 1] 42 | capacity = math.ceil(cap_rate * inp.shape[0]) 43 | _new_lec, _new_gec, top1_idx = limit_by_capacity( 44 | top1_idx, self.num_expert, self.world_size, capacity) 45 | 46 | valid_idx = top1_idx[top1_idx > -1] 47 | fraction_expert = torch.scatter_add( 48 | torch.zeros(self.tot_expert, device=valid_idx.device), 49 | 0, 50 | valid_idx, 51 | torch.ones_like(valid_idx, dtype=torch.float), 52 | ) / valid_idx.numel() 53 | prob_expert = score.sum(dim=0) / valid_idx.numel() 54 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 55 | self.set_loss(loss) 56 | return top1_idx, top1_score -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | import setuptools 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import os 5 | import torch 6 | 7 | cxx_flags = [] 8 | ext_libs = [] 9 | 10 | authors = [ 11 | 'Jiaao He', 12 | 'Jiezhong Qiu', 13 | 'Aohan Zeng', 14 | 'Tiago Antunes', 15 | 'Jinjun Peng', 16 | 'Qin Li', 17 | 'Mingshu Zhai' 18 | ] 19 | 20 | is_rocm_pytorch = False 21 | if torch.__version__ >= '1.5': 22 | from torch.utils.cpp_extension import ROCM_HOME 23 | is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False 24 | 25 | if os.environ.get('USE_NCCL', '1') == '1': 26 | cxx_flags.append('-DFMOE_USE_NCCL') 27 | cxx_flags.append('-DUSE_C10D_NCCL') 28 | if is_rocm_pytorch: 29 | ext_libs.append('rccl') 30 | else: 31 | ext_libs.append('nccl') 32 | 33 | if os.environ.get('MOE_DEBUG', '0') == '1': 34 | cxx_flags.append('-DMOE_DEBUG') 35 | 36 | if is_rocm_pytorch: 37 | define_macros=[('FMOE_USE_HIP', None)] 38 | else: 39 | define_macros=[] 40 | 41 | 42 | if __name__ == '__main__': 43 | setuptools.setup( 44 | name='fastmoe', 45 | version='1.1.0', 46 | description='An efficient Mixture-of-Experts system for PyTorch', 47 | author=', '.join(authors), 48 | author_email='hja20@mails.tsinghua.edu.cn', 49 | license='Apache-2', 50 | # url='https://github.com/laekov/fastmoe', 51 | # packages=['fmoe', 'fmoe.megatron', 'fmoe.gates', 'fmoe.fastermoe'], 52 | ext_modules=[ 53 | CUDAExtension( 54 | name='fmoe_cuda', 55 | sources=[ 56 | # 'cuda/stream_manager.cpp', 57 | # 'cuda/local_exchange.cu', 58 | # 'cuda/balancing.cu', 59 | # 'cuda/global_exchange.cpp', 60 | # 'cuda/parallel_linear.cu', 61 | 'cuda/fmoe_cuda.cpp', 62 | # 'cuda/fastermoe/smart_schedule.cpp', 63 | ], 64 | define_macros=define_macros, 65 | extra_compile_args={ 66 | 'cxx': cxx_flags, 67 | 'nvcc': cxx_flags 68 | }, 69 | libraries=ext_libs 70 | ) 71 | ], 72 | cmdclass={ 73 | 'build_ext': BuildExtension 74 | }) 75 | -------------------------------------------------------------------------------- /fastermoe/shadow_policy.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torch.distributed as dist 4 | 5 | from .config import float_from_env, switch_from_env 6 | from fmoe.functions import get_moe_group 7 | 8 | def global_policy(local_expert_count, _gec, num_expert, world_size): 9 | r""" 10 | This is the policy for two-layer MLPs, using the formula in the PPoPP paper. 11 | A few parameters are used in this policy. 12 | * `d_model`: feature length of the MLP input and output. 13 | * `alpha`: the ratio of the MLP's hidden size to `d_model`. 
14 | * `bw_net`: bandwidth of the network (GBps) 15 | * `bw_mm`: computation throughput of performing GeMM (FLOPs) 16 | """ 17 | bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8) 18 | bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12) 19 | alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2) 20 | d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048) 21 | 22 | moe_group = get_moe_group() 23 | local_expert_count = local_expert_count.cuda() 24 | agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)] 25 | dist.all_gather(agecs, local_expert_count, group=moe_group) 26 | all_global_expert_count = torch.stack(agecs) 27 | 28 | # TODO: data type other than float 29 | data_size = 4 30 | 31 | fwd_expert_counts = all_global_expert_count.sum(1).cpu() 32 | B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True) 33 | 34 | alphaH2 = alpha * (d_model ** 2) 35 | B_w = B_ws[0] 36 | 37 | comm = float('+inf') 38 | send_feature_time = d_model * data_size / bw_net 39 | send_model_time = 2 * alphaH2 * data_size / bw_net 40 | comp_time = 4 * alphaH2 / bw_mm 41 | lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w 42 | 43 | res = torch.zeros(world_size * num_expert, dtype=torch.bool) 44 | shadow_time = 0 45 | 46 | for i, index in enumerate(indices): 47 | if i + 1 == indices.numel(): 48 | break 49 | B_k = B_ws[i + 1] 50 | shadow_time += send_model_time 51 | lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time 52 | 53 | if lat_new < lat_base: 54 | lat_base = lat_new 55 | res[index] = True 56 | else: 57 | break 58 | return res 59 | 60 | def no_shadow_policy(_lec, _gec, num_expert, world_size): 61 | res = torch.zeros(world_size * num_expert, dtype=bool) 62 | return res 63 | 64 | def get_shadow_policy(d_model=None): 65 | if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ: 66 | os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model) 67 | if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'): 68 | return no_shadow_policy 69 | return global_policy -------------------------------------------------------------------------------- /gates/faster_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from .naive_gate import NaiveGate 3 | 4 | import os 5 | import sys 6 | import torch 7 | import torch.nn.functional as F 8 | from .utils import limit_by_capacity 9 | import fmoe_cuda 10 | from fmoe.functions import count_by_gate 11 | 12 | nw_per_node = 8 13 | try: 14 | nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE']) 15 | except Exception: 16 | pass 17 | 18 | class FasterGate(NaiveGate): 19 | def __init__(self, d_model, n_expert, world_size, node_rank): 20 | super().__init__(d_model, n_expert, world_size, top_k=2) 21 | self.ne_per_node = nw_per_node * n_expert 22 | self.ogn_ratio = .14 23 | try: 24 | self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION']) 25 | except Exception: 26 | pass 27 | self.node_rank = node_rank 28 | 29 | mask = [1] * world_size * n_expert 30 | for i in range(n_expert * world_size): 31 | if i // self.ne_per_node == self.node_rank: 32 | mask[i] = 0 33 | self.mask = torch.Tensor(mask).bool() 34 | self.policy_fn = None 35 | print('node rank {} mask {}'.format(node_rank, mask)) 36 | 37 | def forward(self, inp): 38 | if self.mask.device != inp.device: 39 | self.mask = self.mask.to(inp.device) 40 | 41 | gate_score = self.gate(inp) 42 | lim_mask = self.mask 43 | 44 | top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1) 45 | S =
gate_score.shape[0] 46 | top_k = 2 47 | 48 | with torch.no_grad(): 49 | top1_idx = top2_idx.view((-1, top_k))[:, 0] 50 | top1_val = top2_val.view((-1, top_k))[:, 0] 51 | c_e = torch.scatter_add( 52 | torch.zeros(self.tot_expert, device=top1_idx.device), 53 | 0, 54 | top1_idx, 55 | torch.ones_like(top1_idx, dtype=torch.float), 56 | ) / S 57 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 58 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 59 | self.set_loss(loss) 60 | 61 | with torch.no_grad(): 62 | if self.policy_fn is None: 63 | stored_models = torch.zeros(self.num_expert * self.world_size, 64 | dtype=torch.bool) 65 | else: 66 | # TODO: Fix this after expert shadowing is ported 67 | _, lec, aec, gec, agec = count_by_gate(top2_idx, 68 | self.num_expert, self.world_size, require_pos=False) 69 | stored_models = self.policy_fn(aec, agec, 70 | self.num_expert, self.world_size, inp.shape[-1], True) 71 | lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device) 72 | 73 | ogn_mask = lim_mask[top1_idx] 74 | ogn_thres = int(inp.shape[0] * self.ogn_ratio) 75 | 76 | if ogn_mask.sum().item() < ogn_thres: 77 | topk_val, topk_idx = torch.topk(gate_score, k=self.top_k) 78 | topk_val = F.softmax(topk_val, dim=-1) 79 | return topk_idx, topk_val 80 | 81 | with torch.no_grad(): 82 | top1_val[~ogn_mask] = float('-inf') 83 | _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres) 84 | cand = gate_score.clone() 85 | cand[:, lim_mask] = float('-inf') 86 | _, topk_idx = torch.topk(cand, k=self.top_k) 87 | topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn] 88 | 89 | idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2) 90 | topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k) 91 | 92 | topk_val = F.softmax(topk_val, dim=-1) 93 | 94 | return topk_idx, topk_val 95 | 96 | def gen_faster_gate(rank): 97 | def _gen(d_model, n_expert, world_size, top_k=2): 98 | assert top_k == 2 99 | return FasterGate(d_model, n_expert, world_size, rank // nw_per_node) 100 | return _gen -------------------------------------------------------------------------------- /custom_transformer.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import torch.nn as nn 6 | from custom_layers import FMoE 7 | from custom_layers import FMoELinear 8 | from custom_layers_opt import FMoEOpt 9 | 10 | 11 | class _Expert(nn.Module): 12 | r""" 13 | An expert using 2 FMoELinear modules to speed up the computation of experts 14 | within one worker. 15 | """ 16 | 17 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): 18 | super().__init__() 19 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank) 20 | self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank) 21 | self.activation = activation 22 | 23 | def forward(self, inp, fwd_expert_count): 24 | r""" 25 | First expand input to 4h (the hidden size is variable, but is called h4 26 | for convenience). Then perform activation. Finally shirink back to h. 27 | """ 28 | x = self.htoh4(inp, fwd_expert_count) 29 | x = self.activation(x) 30 | x = self.h4toh(x, fwd_expert_count) 31 | return x 32 | 33 | 34 | class FMoETransformerMLP(FMoE): 35 | r""" 36 | A complete MoE MLP module in a Transformer block. 37 | * `activation` is the activation function to be used in MLP in each expert. 38 | * `d_hidden` is the dimension of the MLP layer. 
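* `num_expert` is the number of experts placed on each worker. * `moe_top_k` is the number of experts each token is routed to.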
39 | """ 40 | 41 | def __init__( 42 | self, 43 | num_expert=32, 44 | d_model=1024, 45 | d_hidden=4096, 46 | activation=torch.nn.GELU(), 47 | expert_dp_comm="none", 48 | expert_rank=0, 49 | moe_top_k=2, 50 | **kwargs 51 | ): 52 | super().__init__( 53 | num_expert=num_expert, d_model=d_model, moe_top_k=moe_top_k, **kwargs 54 | ) 55 | self.experts = _Expert( 56 | num_expert, d_model, d_hidden, activation, rank=expert_rank 57 | ) 58 | self.mark_parallel_comm(expert_dp_comm) 59 | 60 | def forward(self, inp: torch.Tensor): 61 | r""" 62 | This module wraps up the FMoE module with reshape, residual and layer 63 | normalization. 64 | """ 65 | original_shape = inp.shape 66 | inp = inp.reshape(-1, self.d_model) 67 | output = super().forward(inp) 68 | return output.reshape(original_shape) 69 | 70 | 71 | class FMoETransformerMLPOpt(FMoEOpt): 72 | r""" 73 | A complete MoE MLP module in a Transformer block. 74 | * `activation` is the activation function to be used in MLP in each expert. 75 | * `d_hidden` is the dimension of the MLP layer. 76 | """ 77 | 78 | def __init__( 79 | self, 80 | num_expert=32, 81 | d_model=1024, 82 | d_hidden=4096, 83 | activation=torch.nn.GELU(), 84 | expert_dp_comm="none", 85 | expert_rank=0, 86 | moe_top_k=2, 87 | freq=0.0, 88 | alpha=0.0, 89 | act_experts="shuffle", 90 | g_blance=False, 91 | opt_blance=False, 92 | combine_gate=False, 93 | opt_loss="mse", 94 | **kwargs 95 | ): 96 | super().__init__( 97 | num_expert=num_expert, 98 | d_model=d_model, 99 | moe_top_k=moe_top_k, 100 | freq=freq, 101 | alpha=alpha, 102 | act_experts=act_experts, 103 | g_blance=g_blance, 104 | opt_blance=opt_blance, 105 | combine_gate=combine_gate, 106 | opt_loss=opt_loss, 107 | **kwargs 108 | ) 109 | self.experts = _Expert( 110 | num_expert, d_model, d_hidden, activation, rank=expert_rank 111 | ) 112 | self.mark_parallel_comm(expert_dp_comm) 113 | 114 | def forward(self, inp: torch.Tensor): 115 | r""" 116 | This module wraps up the FMoE module with reshape, residual and layer 117 | normalization. 
118 | """ 119 | original_shape = inp.shape 120 | inp = inp.reshape(-1, self.d_model) 121 | output = super().forward(inp) 122 | return output.reshape(original_shape) 123 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import tqdm 6 | 7 | 8 | def _tokenize(text_path, dictionary_to_update): 9 | """Tokenizes a text file.""" 10 | print("Tokenizing {}".format(text_path)) 11 | assert os.path.exists(text_path) 12 | 13 | nb_tokens_in_dictionary = len(dictionary_to_update) 14 | 15 | # Count nb of tokens in text and update the dictionary 16 | with open(text_path, "r", encoding="utf8") as f: 17 | for line in f: 18 | tokens = line.split() + [""] 19 | for token in tokens: 20 | if token not in dictionary_to_update: 21 | dictionary_to_update[token] = nb_tokens_in_dictionary 22 | nb_tokens_in_dictionary += 1 23 | 24 | # Assign to each token its identifier 25 | ids = [] 26 | with open(text_path, "r", encoding="utf8") as f: 27 | for line in f: 28 | tokens = line.split() + [""] 29 | for token in tokens: 30 | ids.append(dictionary_to_update[token]) 31 | ids = torch.LongTensor(ids) 32 | return ids 33 | 34 | 35 | class Corpus: 36 | def __init__(self, data_path): 37 | self._dictionary = {} 38 | self.train = _tokenize( 39 | text_path=os.path.join(data_path, "train.txt"), 40 | dictionary_to_update=self._dictionary, 41 | ) 42 | self.valid = _tokenize( 43 | text_path=os.path.join(data_path, "valid.txt"), 44 | dictionary_to_update=self._dictionary, 45 | ) 46 | self.test = _tokenize( 47 | text_path=os.path.join(data_path, "test.txt"), 48 | dictionary_to_update=self._dictionary, 49 | ) 50 | 51 | @property 52 | def vocab_size(self): 53 | return len(self._dictionary) 54 | 55 | 56 | def _batchify(data_tensor, batch_size): 57 | nb_batches = data_tensor.size(0) // batch_size 58 | # trim away some tokens to make whole batches 59 | data_tensor = data_tensor.narrow(0, 0, nb_batches * batch_size) 60 | data_tensor = data_tensor.view(batch_size, -1).contiguous() 61 | return data_tensor 62 | 63 | 64 | def _build_corpus(data_path, env_params, data_name=None): 65 | # save the corpus to a file so that it's faster next time 66 | corpus_path = os.path.join(data_path, "corpus.pt") 67 | if os.path.exists(corpus_path): 68 | print("Loading an existing corpus file from {}".format(corpus_path)) 69 | corpus = torch.load(corpus_path) 70 | else: 71 | print("Creating a corpus file at {}".format(corpus_path)) 72 | if env_params["distributed"]: 73 | # only one process need to create a corpus file 74 | if env_params["rank"] == 0: 75 | corpus = Corpus(data_path) 76 | torch.save(corpus, corpus_path) 77 | # sync with other processes 78 | torch.distributed.broadcast(torch.zeros(1).cuda(), src=0) 79 | else: 80 | print("Waiting rank0 to create a corpus file.") 81 | # sync with rank0 82 | torch.distributed.broadcast(torch.zeros(1).cuda(), src=0) 83 | corpus = torch.load(corpus_path) 84 | else: 85 | corpus = Corpus(data_path) 86 | torch.save(corpus, corpus_path) 87 | return corpus 88 | 89 | 90 | def _get_train_val_test_data(corpus, batch_size): 91 | return [ 92 | _batchify(corpus.train, batch_size), 93 | _batchify(corpus.valid, batch_size), 94 | _batchify(corpus.test, batch_size), 95 | ] 96 | 97 | 98 | def get_train_val_test_data(data_params, env_params, batch_size, device): 99 | corpus = _build_corpus(**data_params, env_params=env_params) 100 | 
data_params["vocab_size"] = corpus.vocab_size 101 | train_data, val_data, test_data = _get_train_val_test_data( 102 | corpus=corpus, batch_size=batch_size 103 | ) 104 | 105 | if env_params["distributed"]: 106 | # split the data into equal parts 107 | assert batch_size % env_params["world_size"] == 0 108 | device_batch_size = batch_size // env_params["world_size"] 109 | slice_data = slice( 110 | device_batch_size * env_params["rank"], 111 | device_batch_size * (env_params["rank"] + 1), 112 | ) 113 | train_data = train_data[slice_data] 114 | val_data = val_data[slice_data] 115 | test_data = test_data[slice_data] 116 | 117 | train_data = train_data.to(device) 118 | val_data = val_data.to(device) 119 | test_data = test_data.to(device) 120 | return train_data, val_data, test_data 121 | -------------------------------------------------------------------------------- /fastermoe/schedule.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | from torch.autograd.function import Function 4 | 5 | from fmoe.functions import prepare_forward, ensure_comm 6 | from fmoe.functions import _local_scatter, _local_gather 7 | import fmoe_cuda as fmoe_native 8 | from fmoe.fastermoe import expert_utils 9 | 10 | from .shadow_policy import get_shadow_policy 11 | 12 | class MoEForward(Function): 13 | @staticmethod 14 | def forward( 15 | ctx, 16 | expert_fn, 17 | experts, 18 | inp, # models, 19 | pos_s, pos_g, 20 | local_expert_count, global_expert_count, 21 | stored_models, 22 | fwd_batch_size, out_batch_size, 23 | world_size): 24 | local_input_buf = _local_scatter(inp, pos_s) 25 | 26 | ctx.gibs = [None] * (world_size * 2) 27 | ctx.gobs = [None] * (world_size * 2) 28 | def _expert_forward(x, y, idx): 29 | nothing = lambda a: a 30 | x = x.data 31 | with torch.enable_grad(): 32 | x.requires_grad = True 33 | # To skip torch autograd's version check. 
34 | with torch.autograd.graph.saved_tensors_hooks(nothing, nothing): 35 | y0 = expert_fn(x, [x.shape[0]]) 36 | ctx.gibs[idx] = x 37 | ctx.gobs[idx] = y0 38 | y.copy_(y0) 39 | 40 | ctx.experts = experts 41 | if stored_models.any(): 42 | ctx.expert_size = expert_utils.get_expert_param_size(experts) 43 | else: 44 | ctx.expert_size = 0 45 | get_param_fn = lambda out: expert_utils.get_expert_params(experts, out) 46 | pop_fn = lambda: expert_utils.pop_expert_params(experts) 47 | ctx.shadows = [None] * world_size 48 | def stash_fn(params, idx): 49 | expert_utils.stash_expert_params(experts, params) 50 | ctx.shadows[idx] = params 51 | 52 | local_output_buf, gib = fmoe_native.smart_sch_forward( 53 | local_input_buf, 54 | local_expert_count, global_expert_count, 55 | stored_models, fwd_batch_size, ctx.expert_size, 56 | world_size, _expert_forward, get_param_fn, stash_fn, pop_fn) 57 | 58 | out = _local_gather(local_output_buf, pos_g, out_batch_size, 59 | maybe_overlap=False) 60 | 61 | # gib and local_input_buf are necessary, because ctx.gibs are created 62 | # based on their memory 63 | variables = (pos_s, pos_g, local_expert_count, global_expert_count, 64 | stored_models, gib, local_input_buf) 65 | 66 | ctx.moe_args = fwd_batch_size, inp.shape[0], world_size 67 | ctx.save_for_backward(*variables) 68 | 69 | return out 70 | 71 | @staticmethod 72 | def backward(ctx, grad_out): 73 | (pos_s, pos_g, local_expert_count, global_expert_count, 74 | stored_models, _1, _2) = ctx.saved_tensors 75 | (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args 76 | 77 | def _expert_backward(grad_y, grad_x, idx): 78 | y = ctx.gobs[idx] 79 | x = ctx.gibs[idx] 80 | torch.autograd.backward([y], [grad_y]) 81 | grad_x.copy_(x.grad) 82 | 83 | experts = ctx.experts 84 | def stash_fn(idx): 85 | expert_utils.stash_expert_params(experts, ctx.shadows[idx]) 86 | pop_fn = lambda: expert_utils.pop_expert_params(experts) 87 | def collect_fn(idx, root): 88 | grad = ctx.shadows[idx] 89 | expert_utils.collect_expert_grads(experts, grad) 90 | fmoe_native.reduce_grad(grad, root, ctx.expert_size) 91 | set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx]) 92 | 93 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g) 94 | grad_in_buf = fmoe_native.smart_sch_backward( 95 | grad_out_buf, 96 | local_expert_count, global_expert_count, 97 | stored_models, 98 | pos_s.shape[0], fwd_batch_size, 99 | world_size, 100 | _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn) 101 | grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size) 102 | 103 | return (None, None, grad_in, None, None, None, None, None, None, None, None) 104 | 105 | policy_fn = None 106 | 107 | def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None): 108 | # TODO: Using multiple tensors as input is to be supported. 
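# Note: this function counts how many input rows are routed to each
# (worker, expert) pair, asks the shadow policy which experts are worth
# replicating locally, and then dispatches the computation through the
# smart-scheduled MoEForward function defined above.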
109 | assert(isinstance(inp, torch.Tensor)) 110 | # TODO: Support many experts on each process 111 | assert(n_expert == 1) 112 | ( 113 | pos, 114 | local_expert_count, 115 | global_expert_count, 116 | fwd_expert_count, 117 | fwd_batch_size, 118 | ) = prepare_forward(gate, n_expert, world_size) 119 | 120 | global policy_fn 121 | if policy_fn is None: 122 | policy_fn = get_shadow_policy(d_model=inp.shape[-1]) 123 | 124 | if stored_models is None: 125 | stored_models = policy_fn(local_expert_count, global_expert_count, 126 | n_expert, world_size) 127 | 128 | topk = 1 129 | if len(gate.shape) == 2: 130 | topk = gate.shape[1] 131 | out_batch_size = inp.shape[0] * topk 132 | 133 | return MoEForward.apply(expert_fn, experts, inp, 134 | torch.div(pos, topk, rounding_mode='floor'), pos, 135 | local_expert_count, global_expert_count, stored_models, 136 | fwd_batch_size, out_batch_size, world_size) -------------------------------------------------------------------------------- /gates/noisy_gate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from .base_gate import BaseGate 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.distributions.normal import Normal 8 | import math 9 | 10 | class NoisyGate(BaseGate): 11 | def __init__(self, d_model, num_expert, world_size, top_k=2): 12 | super().__init__(num_expert, world_size) 13 | self.w_gate = nn.Parameter( 14 | torch.zeros(d_model, self.tot_expert), requires_grad=True 15 | ) 16 | self.w_noise = nn.Parameter( 17 | torch.zeros(d_model, self.tot_expert), requires_grad=True 18 | ) 19 | self.top_k = top_k 20 | self.softplus = nn.Softplus() 21 | self.softmax = nn.Softmax(1) 22 | 23 | self.noise_epsilon = 1e-2 24 | 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | # Approach is the same as in torch.nn.Linear 29 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 30 | 31 | torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5)) 32 | torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5)) 33 | 34 | def _gates_to_load(self, gates): 35 | """Compute the true load per expert, given the gates. 36 | The load is the number of examples for which the corresponding gate is >0. 37 | Args: 38 | gates: a `Tensor` of shape [batch_size, n] 39 | Returns: 40 | a float32 `Tensor` of shape [n] 41 | """ 42 | return (gates > 0).sum(0) 43 | 44 | def _prob_in_top_k( 45 | self, clean_values, noisy_values, noise_stddev, noisy_top_values 46 | ): 47 | """Helper function to NoisyTopKGating. 48 | Computes the probability that value is in top k, given different random noise. 49 | This gives us a way of backpropagating from a loss that balances the number 50 | of times each expert is in the top k experts per example. 51 | In the case of no noise, pass in None for noise_stddev, and the result will 52 | not be differentiable. 53 | Args: 54 | clean_values: a `Tensor` of shape [batch, n]. 55 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 56 | normally distributed noise with standard deviation noise_stddev. 57 | noise_stddev: a `Tensor` of shape [batch, n], or None 58 | noisy_top_values: a `Tensor` of shape [batch, m]. 59 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 60 | Returns: 61 | a `Tensor` of shape [batch, n]. 
62 | """ 63 | 64 | batch = clean_values.size(0) 65 | m = noisy_top_values.size(1) 66 | top_values_flat = noisy_top_values.flatten() 67 | threshold_positions_if_in = ( 68 | torch.arange(batch, device=clean_values.device) * m + self.top_k 69 | ) 70 | threshold_if_in = torch.unsqueeze( 71 | torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 72 | ) 73 | is_in = torch.gt(noisy_values, threshold_if_in) 74 | threshold_positions_if_out = threshold_positions_if_in - 1 75 | threshold_if_out = torch.unsqueeze( 76 | torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 77 | ) 78 | # is each value currently in the top k. 79 | normal = Normal( 80 | torch.tensor([0.0], device=clean_values.device), 81 | torch.tensor([1.0], device=clean_values.device), 82 | ) 83 | prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev) 84 | prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev) 85 | prob = torch.where(is_in, prob_if_in, prob_if_out) 86 | return prob 87 | 88 | def cv_squared(self, x): 89 | """The squared coefficient of variation of a sample. 90 | Useful as a loss to encourage a positive distribution to be more uniform. 91 | Epsilons added for numerical stability. 92 | Returns 0 for an empty Tensor. 93 | Args: 94 | x: a `Tensor`. 95 | Returns: 96 | a `Scalar`. 97 | """ 98 | eps = 1e-10 99 | # if only num_expert = 1 100 | if x.shape[0] == 1: 101 | return torch.Tensor([0]) 102 | return x.float().var() / (x.float().mean() ** 2 + eps) 103 | 104 | def forward(self, inp): 105 | clean_logits = inp @ self.w_gate 106 | raw_noise_stddev = inp @ self.w_noise 107 | noise_stddev = ( 108 | self.softplus(raw_noise_stddev) + self.noise_epsilon 109 | ) * self.training 110 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 111 | logits = noisy_logits 112 | 113 | # calculate topk + 1 that will be needed for the noisy gates 114 | top_logits, top_indices = logits.topk( 115 | min(self.top_k + 1, self.tot_expert), dim=1 116 | ) 117 | top_k_logits = top_logits[:, : self.top_k] 118 | top_k_indices = top_indices[:, : self.top_k] 119 | top_k_gates = self.softmax(top_k_logits) 120 | 121 | zeros = torch.zeros_like(logits, requires_grad=True) 122 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 123 | 124 | if self.top_k < self.tot_expert: 125 | load = ( 126 | self._prob_in_top_k( 127 | clean_logits, noisy_logits, noise_stddev, top_logits 128 | ) 129 | ).sum(0) 130 | else: 131 | load = self._gates_to_load(gates) 132 | 133 | importance = gates.sum(0) 134 | loss = self.cv_squared(importance) + self.cv_squared(load) 135 | self.set_loss(loss) 136 | 137 | return ( 138 | top_k_indices.contiguous().view(-1), 139 | top_k_gates.contiguous().unsqueeze(1), 140 | ) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import tqdm 6 | 7 | from custom_gates import * 8 | 9 | 10 | def _train_step(model, load_balance, X, Y, h_cache, eval_only, loss_div=1): 11 | """Single training step.""" 12 | 13 | out, h_cache = model(X, h_cache) 14 | out = out.view(-1, out.size(-1)) 15 | loss = torch.nn.functional.nll_loss(out, Y.view(-1)) 16 | loss_value = loss.item() / loss_div 17 | 18 | if not eval_only: 19 | # loss term from adaptive-span 20 | if model.module.layers[0].attn.attn.adapt_span_enabled: 21 | loss += sum( 22 | 
model.module.layers[layer_i].attn.attn.adaptive_span.get_loss() 23 | for layer_i in range(model.module.attn_layer_count) 24 | ) 25 | 26 | if load_balance > 0: 27 | balance_loss = 0 28 | for name, m in model.named_modules(): 29 | if isinstance(m, CustomNaiveGate_Balance_SMoE) or isinstance( 30 | m, CustomNaiveGate_Balance_XMoE 31 | ): 32 | if m.loss is not None: 33 | balance_loss += m.loss 34 | loss += load_balance * balance_loss 35 | (loss / loss_div).backward(retain_graph=True) 36 | return loss_value, h_cache 37 | 38 | 39 | def _train_batch( 40 | model, load_balance, optimizer, scheduler, X, Y, h_cache, eval_only, batch_split 41 | ): 42 | """Train on a batch.""" 43 | 44 | optimizer.zero_grad() 45 | 46 | if batch_split == 1: 47 | # process a batch in a single step (default behaviour) 48 | loss_value, h_cache = _train_step(model, load_balance, X, Y, h_cache, eval_only) 49 | else: 50 | # split a batch into multiple pieces that each can fit in memory 51 | assert X.size(0) % batch_split == 0 52 | split_size = X.size(0) // batch_split 53 | loss_value = 0 54 | h_cache_list = [] 55 | for split_ind in range(batch_split): 56 | split_slice = slice(split_ind * split_size, (split_ind + 1) * split_size) 57 | split_h_cache = [h[split_slice, :, :] for h in h_cache] 58 | split_loss_value, split_h_cache = _train_step( 59 | model, 60 | load_balance, 61 | X[split_slice, :], 62 | Y[split_slice], 63 | split_h_cache, 64 | eval_only, 65 | batch_split, 66 | ) 67 | loss_value += split_loss_value 68 | h_cache_list.append(split_h_cache) 69 | h_cache = [ 70 | torch.cat([h_cache_list[i][l] for i in range(batch_split)], dim=0) 71 | for l in range(len(h_cache)) 72 | ] 73 | if not eval_only: 74 | if scheduler is not None: 75 | scheduler.step() 76 | optimizer.step() 77 | 78 | # make sure span parameters are in a correct range 79 | if model.module.layers[0].attn.attn.adapt_span_enabled: 80 | for layer in model.module.layers: 81 | if layer.use_attn: 82 | layer.attn.attn.adaptive_span.clamp_param() 83 | return loss_value, h_cache 84 | 85 | 86 | def train_iteration( 87 | model, 88 | load_balance, 89 | optimizer, 90 | scheduler, 91 | data, 92 | nb_batches_per_iter, 93 | block_size, 94 | eval_only, 95 | train_pos, 96 | h_cache, 97 | batch_split, 98 | checkpoint_path, 99 | ): 100 | """Single training iteration.""" 101 | if eval_only: 102 | model.eval() 103 | else: 104 | model.train() 105 | 106 | nb_batches_per_iter_max = nb_batches_per_iter 107 | if eval_only: 108 | # eval on fewer batches during training for speed-up 109 | nb_batches_per_iter_max = max(1, nb_batches_per_iter // 10) 110 | nb_batches_per_iter_max = min( 111 | nb_batches_per_iter_max, math.ceil(data.size(1) / block_size) 112 | ) 113 | 114 | loss_all = 0 115 | actual_nb_batches_per_iter = 0 116 | for _ in tqdm.tqdm(range(nb_batches_per_iter_max)): 117 | actual_nb_batches_per_iter += 1 118 | X = data[:, train_pos : train_pos + block_size].contiguous() 119 | Y = data[:, train_pos + 1 : train_pos + block_size + 1].contiguous() 120 | 121 | loss, h_cache = _train_batch( 122 | model=model, 123 | load_balance=load_balance, 124 | optimizer=optimizer, 125 | scheduler=scheduler, 126 | X=X, 127 | Y=Y, 128 | h_cache=h_cache, 129 | eval_only=eval_only, 130 | batch_split=batch_split, 131 | ) 132 | loss_all += loss 133 | train_pos += block_size 134 | if train_pos >= data.size(1) - block_size: 135 | # reached the end. 
randomize the offset to reduce overfitting 136 | train_pos = random.randrange(block_size) 137 | # reset the cache 138 | for h in h_cache: 139 | h.fill_(0) 140 | 141 | loss_all = loss_all / actual_nb_batches_per_iter 142 | return loss_all, train_pos, h_cache 143 | 144 | 145 | # do full evaluation 146 | def full_eval(model, optimizer, scheduler, data, block_size, hidden_size): 147 | model.eval() 148 | train_pos = 0 149 | nb_batches_per_iter_max = math.ceil(data.size(1) / block_size) 150 | h_cache = [ 151 | torch.zeros( 152 | data.size(0), 153 | model.module.layers[layer_i].attn.attn.get_cache_size(), 154 | hidden_size, 155 | ).to(data.device) 156 | for layer_i in range(model.module.attn_layer_count) 157 | ] 158 | 159 | loss_all = 0 160 | actual_nb_batches_per_iter = 0 161 | for _ in tqdm.tqdm(range(nb_batches_per_iter_max)): 162 | actual_nb_batches_per_iter += 1 163 | X = data[:, train_pos : train_pos + block_size].contiguous() 164 | Y = data[:, train_pos + 1 : train_pos + block_size + 1].contiguous() 165 | 166 | loss, h_cache = _train_batch( 167 | model=model, 168 | load_balance=0, 169 | optimizer=optimizer, 170 | scheduler=scheduler, 171 | X=X, 172 | Y=Y, 173 | h_cache=h_cache, 174 | eval_only=True, 175 | batch_split=1, 176 | ) 177 | loss_all += loss 178 | train_pos += block_size 179 | if train_pos >= data.size(1) - block_size: 180 | # Skip the remaining tokens as it can't make a whole block. 181 | # An effect on performance should be negligable for a large data. 182 | break 183 | 184 | loss_all = loss_all / actual_nb_batches_per_iter 185 | return loss_all 186 | -------------------------------------------------------------------------------- /custom_functions.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import fmoe_cuda 6 | from torch.autograd import Function 7 | from custom_utils import get_torch_default_comm 8 | 9 | _moe_group = None 10 | 11 | 12 | def ensure_comm(t, comm): 13 | if comm is None: 14 | comm = get_torch_default_comm() 15 | global _moe_group 16 | _moe_group = comm 17 | fmoe_cuda.ensure_nccl(comm, t) 18 | 19 | 20 | def get_moe_group(): 21 | return _moe_group 22 | 23 | 24 | def count_by_gate(gate, num_expert, world_size, require_pos=True): 25 | with torch.no_grad(): 26 | local_expert_count = torch.zeros( 27 | num_expert * world_size, device=gate.device, dtype=torch.int32 28 | ) 29 | fmoe_cuda.expert_count(gate, local_expert_count) 30 | local_expert_count = local_expert_count.long() 31 | 32 | if world_size > 1: 33 | global_expert_count = fmoe_cuda.expert_exchange( 34 | local_expert_count, num_expert, world_size 35 | ) 36 | else: 37 | global_expert_count = local_expert_count 38 | if not require_pos: 39 | pos = None 40 | else: 41 | lec_cum = torch.cumsum(local_expert_count, dim=0).int() 42 | pos_size = lec_cum[-1].item() 43 | pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long) 44 | fmoe_cuda.assign_pos(lec_cum, gate, pos) 45 | return pos, local_expert_count, global_expert_count 46 | 47 | 48 | def prepare_forward(gate, num_expert, world_size): 49 | r""" 50 | Prepare necessary information from gate output for MoE computation. 51 | 52 | Args: 53 | gate: a 1-d Long Tensor representing the target expert of each input 54 | sample. 55 | num_expert: number of experts on each worker. 56 | world_size: number of workers that hold different experts. 57 | comm: the communicator of all workers in the expert-parallel group. 
58 | """ 59 | pos, local_expert_count, global_expert_count = count_by_gate( 60 | gate, num_expert, world_size 61 | ) 62 | with torch.no_grad(): 63 | fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0) 64 | fwd_batch_size = int(fwd_expert_count.sum().item()) 65 | return ( 66 | pos, 67 | local_expert_count.cpu(), 68 | global_expert_count.cpu(), 69 | fwd_expert_count.cpu(), 70 | fwd_batch_size, 71 | ) 72 | 73 | 74 | def _local_scatter(inp, pos): 75 | inp_buf = torch.index_select(inp, 0, pos) 76 | return inp_buf 77 | 78 | 79 | def _local_gather(inp, pos, out_batch_size, maybe_overlap=True): 80 | inp_buf = torch.zeros( 81 | out_batch_size, inp.shape[-1], dtype=inp.dtype, device=inp.device 82 | ) 83 | if maybe_overlap: 84 | inp_buf.index_add_(0, pos, inp) 85 | else: 86 | inp_buf.index_copy_(0, pos, inp) 87 | return inp_buf 88 | 89 | 90 | class MOEScatter(Function): 91 | r""" 92 | Scatter input samples from [batch x sequences] to contiguous alone experts. 93 | If `world_size` is greater than 1, the samples will first be locally 94 | scattered, and then exchanged across workers. 95 | """ 96 | 97 | @staticmethod 98 | def forward( 99 | ctx, 100 | inp, 101 | pos, 102 | local_expert_count, 103 | global_expert_count, 104 | fwd_batch_size, 105 | world_size, 106 | ): 107 | local_input_buf = _local_scatter(inp, pos) 108 | if world_size > 1: 109 | global_input_buf = fmoe_cuda.global_scatter( 110 | local_input_buf, 111 | local_expert_count, 112 | global_expert_count, 113 | fwd_batch_size, 114 | world_size, 115 | ) 116 | else: 117 | global_input_buf = local_input_buf 118 | ctx.moe_args = inp.shape[0], pos.shape[0], world_size 119 | variables = (pos, local_expert_count, global_expert_count) 120 | ctx.save_for_backward(*variables) 121 | return global_input_buf 122 | 123 | @staticmethod 124 | def backward(ctx, global_grad_in): 125 | (pos, local_expert_count, global_expert_count) = ctx.saved_tensors 126 | (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args 127 | 128 | if world_size > 1: 129 | local_grad_in = fmoe_cuda.global_gather( 130 | global_grad_in, 131 | local_expert_count, 132 | global_expert_count, 133 | buf_batch_size, 134 | world_size, 135 | ) 136 | else: 137 | local_grad_in = global_grad_in 138 | grad_in = _local_gather(local_grad_in, pos, inp_batch_size) 139 | return grad_in, None, None, None, None, None 140 | 141 | 142 | class MOEGather(Function): 143 | r""" 144 | Gather output samples from contiguous alone experts back to [batch x 145 | sequences]. Works symmetrically with MOEScatter. 
146 | """ 147 | 148 | @staticmethod 149 | def forward( 150 | ctx, 151 | global_output_buf, 152 | pos, 153 | local_expert_count, 154 | global_expert_count, 155 | local_batch_size, 156 | world_size, 157 | ): 158 | if world_size > 1: 159 | local_output_buf = fmoe_cuda.global_gather( 160 | global_output_buf, 161 | local_expert_count, 162 | global_expert_count, 163 | pos.shape[0], 164 | world_size, 165 | ) 166 | else: 167 | local_output_buf = global_output_buf 168 | output = _local_gather( 169 | local_output_buf, pos, local_batch_size, maybe_overlap=False 170 | ) 171 | 172 | ctx.moe_args = (global_output_buf.shape[0], world_size) 173 | variables = (pos, local_expert_count, global_expert_count) 174 | ctx.save_for_backward(*variables) 175 | return output 176 | 177 | @staticmethod 178 | def backward(ctx, grad_out): 179 | pos, local_expert_count, global_expert_count = ctx.saved_tensors 180 | fwd_batch_size, world_size = ctx.moe_args 181 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos) 182 | if world_size > 1: 183 | global_grad_out_buf = fmoe_cuda.global_scatter( 184 | grad_out_buf, 185 | local_expert_count, 186 | global_expert_count, 187 | fwd_batch_size, 188 | world_size, 189 | ) 190 | else: 191 | global_grad_out_buf = grad_out_buf 192 | return global_grad_out_buf, None, None, None, None, None 193 | 194 | 195 | class AllGather(Function): 196 | r""" 197 | A wrapper for the All-Gather function to support auto-differentiation. 198 | """ 199 | 200 | @staticmethod 201 | def forward(ctx, inp, rank, world_size, group): 202 | tensor_list = [torch.empty_like(inp) for _ in range(world_size)] 203 | torch.distributed.all_gather(tensor_list, inp, group=group) 204 | torch.cuda.synchronize() 205 | output = torch.cat(tensor_list, dim=0) 206 | ctx.args = rank, inp.shape[0] 207 | return output 208 | 209 | @staticmethod 210 | def backward(ctx, grad_out): 211 | rank, dim0 = ctx.args 212 | return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None 213 | 214 | 215 | class Slice(Function): 216 | r""" 217 | A wrapper for the Slice function to support auto-differentiation. 
218 | """ 219 | 220 | @staticmethod 221 | def forward(ctx, inp, rank, world_size, group): 222 | B: int = inp.shape[0] 223 | local_batch_size = B // world_size 224 | batch_start = local_batch_size * rank 225 | batch_end = min(batch_start + local_batch_size, B) 226 | inp = inp[batch_start:batch_end] 227 | ctx.args = world_size, group 228 | return inp 229 | 230 | @staticmethod 231 | def backward(ctx, grad_out): 232 | world_size, group = ctx.args 233 | tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)] 234 | torch.distributed.all_gather(tensor_list, grad_out, group=group) 235 | torch.cuda.synchronize() 236 | grad_out = torch.cat(tensor_list, dim=0) 237 | return grad_out, None, None, None 238 | -------------------------------------------------------------------------------- /custom_gates.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import pdb 9 | import numpy as np 10 | from fmoe.gates.base_gate import BaseGate 11 | 12 | __all__ = [ 13 | "CustomNaiveGate_Balance_SMoE", 14 | "CustomNaiveGate_Balance_XMoE", 15 | "CustomNaiveGate_Balance_StableMoE", 16 | ] 17 | 18 | 19 | class CustomNaiveGate_Balance_SMoE(BaseGate): 20 | def __init__(self, d_model, num_expert, world_size, top_k=2, g_blance=False): 21 | super().__init__(num_expert, world_size) 22 | self.gate = nn.Linear(d_model, self.tot_expert) 23 | self.top_k = top_k 24 | self.dense_moe_flag = False 25 | self.g_blance = g_blance 26 | self.loss = None 27 | 28 | def set_load_balance(self, gate, gate_top_k_idx): 29 | 30 | score = F.softmax(gate, dim=-1) 31 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1] 32 | fraction_expert = ( 33 | torch.scatter_add( 34 | torch.zeros(self.tot_expert, device=valid_idx.device), 35 | 0, 36 | valid_idx, 37 | torch.ones_like(valid_idx, dtype=torch.float), 38 | ) 39 | / valid_idx.numel() 40 | ) 41 | prob_expert = score.sum(dim=0) / valid_idx.numel() 42 | 43 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 44 | self.loss = loss 45 | 46 | def forward(self, inp, return_all_scores=False): 47 | 48 | gate = self.gate(inp) 49 | 50 | if self.dense_moe_flag: 51 | gate = torch.ones_like(gate) # average the importance of all experts 52 | gate_top_k_val, gate_top_k_idx = torch.topk( 53 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 54 | ) 55 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 56 | else: 57 | gate_top_k_val, gate_top_k_idx = torch.topk( 58 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 59 | ) # [.. 
x top_k] 60 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) # (BxL) x 1 x top_k 61 | 62 | gate_score = F.softmax(gate_top_k_val, dim=-1) 63 | if self.g_blance: 64 | self.set_load_balance(gate, gate_top_k_idx) 65 | 66 | if return_all_scores: 67 | return gate_top_k_idx, gate_score, gate 68 | return gate_top_k_idx, gate_score 69 | 70 | 71 | class CustomNaiveGate_Balance_XMoE(BaseGate): 72 | def __init__(self, d_model, num_expert, world_size, top_k=2, g_balance=False): 73 | super().__init__(num_expert, world_size) 74 | self.gate = nn.Linear(d_model, self.tot_expert) 75 | self.top_k = top_k 76 | self.dense_moe_flag = False 77 | self.g_balance = g_balance 78 | self.loss = 0.0 79 | 80 | expert_embeddings = torch.empty(num_expert, 8) 81 | torch.nn.init.orthogonal_(expert_embeddings, gain=0.32) 82 | self.register_parameter( 83 | "expert_embeddings", torch.nn.Parameter(expert_embeddings) 84 | ) 85 | 86 | self.inp_reduction = torch.nn.Linear(d_model, 8, bias=False) 87 | 88 | def set_load_balance(self, gate, gate_top_k_idx): 89 | # gate_top_k_idx (tokens_number, top-k) 90 | # gate_top_k_val (tokens_number, top-k) 91 | 92 | score = F.softmax(gate / 0.3, dim=-1) 93 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1] 94 | fraction_expert = ( 95 | torch.scatter_add( 96 | torch.zeros(self.tot_expert, device=valid_idx.device), 97 | 0, 98 | valid_idx, 99 | torch.ones_like(valid_idx, dtype=torch.float), 100 | ) 101 | / valid_idx.numel() 102 | ) 103 | prob_expert = score.sum(dim=0) / valid_idx.numel() 104 | 105 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 106 | self.loss = loss 107 | 108 | def forward(self, inp, return_all_scores=False): 109 | 110 | reduced_inp = self.inp_reduction(inp) 111 | with torch.no_grad(): 112 | expert_embeddings_norm = self.expert_embeddings.norm( 113 | p=2.0, dim=1, keepdim=True 114 | ) 115 | self.expert_embeddings.mul_(1.5 / expert_embeddings_norm) 116 | 117 | gate = self._cosine(reduced_inp, self.expert_embeddings) 118 | gate = self._make_finite(gate) 119 | 120 | if self.dense_moe_flag: 121 | gate = torch.ones_like(gate) # average the importance of all experts 122 | gate_top_k_val, gate_top_k_idx = torch.topk( 123 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 124 | ) 125 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 126 | else: 127 | gate_top_k_val, gate_top_k_idx = torch.topk( 128 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 129 | ) # [.. 
x top_k] 130 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) # (BxL) x 1 x top_k 131 | 132 | gate_score = F.softmax(gate_top_k_val, dim=-1) 133 | if self.g_balance: 134 | self.set_load_balance(gate, gate_top_k_idx) 135 | 136 | if return_all_scores: 137 | return gate_top_k_idx, gate_score, gate 138 | return gate_top_k_idx, gate_score 139 | 140 | def _cosine(self, mat1, mat2, eps=1e-4): 141 | assert mat1.dim() == 2 142 | assert mat2.dim() == 2 143 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps) 144 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps) 145 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1) 146 | 147 | def _make_finite(self, scores): 148 | ok = scores.isfinite() 149 | if not ok.all(): 150 | # NaNs here can break the assignment algorithm 151 | scores[~ok] = scores[ok].min() 152 | return scores 153 | 154 | 155 | class CustomNaiveGate_Balance_StableMoE(BaseGate): 156 | r""" 157 | Naive Gate StableMoE 158 | """ 159 | 160 | def __init__(self, d_model, num_expert, world_size, top_k=2, g_balance=False): 161 | super().__init__(num_expert, world_size) 162 | self.top_k = top_k 163 | self.dense_moe_flag = False 164 | self.g_balance = g_balance 165 | self.loss = 0.0 166 | 167 | expert_embeddings = torch.empty(num_expert, d_model) 168 | torch.nn.init.orthogonal_(expert_embeddings, gain=0.32) 169 | self.register_parameter( 170 | "expert_embeddings", torch.nn.Parameter(expert_embeddings) 171 | ) 172 | 173 | def set_load_balance(self, gate, gate_top_k_idx): 174 | 175 | score = F.softmax(gate / 0.3, dim=-1) 176 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1] 177 | fraction_expert = ( 178 | torch.scatter_add( 179 | torch.zeros(self.tot_expert, device=valid_idx.device), 180 | 0, 181 | valid_idx, 182 | torch.ones_like(valid_idx, dtype=torch.float), 183 | ) 184 | / valid_idx.numel() 185 | ) 186 | prob_expert = score.sum(dim=0) / valid_idx.numel() 187 | 188 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 189 | self.loss = loss 190 | 191 | def forward(self, inp, return_all_scores=False): 192 | 193 | gate = self._cosine(inp, self.expert_embeddings) 194 | gate = self._make_finite(gate) 195 | 196 | if self.dense_moe_flag: 197 | gate = torch.ones_like(gate) # average the importance of all experts 198 | gate_top_k_val, gate_top_k_idx = torch.topk( 199 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 200 | ) 201 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 202 | else: 203 | gate_top_k_val, gate_top_k_idx = torch.topk( 204 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 205 | ) # [.. 
x top_k] 206 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 207 | # (BxL) x 1 x top_k 208 | 209 | gate_score = F.softmax(gate_top_k_val, dim=-1) 210 | if self.g_balance: 211 | self.set_load_balance(gate, gate_top_k_idx) 212 | 213 | if return_all_scores: 214 | return gate_top_k_idx, gate_score, gate 215 | return gate_top_k_idx, gate_score 216 | 217 | def _cosine(self, mat1, mat2, eps=1e-4): 218 | assert mat1.dim() == 2 219 | assert mat2.dim() == 2 220 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps) 221 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps) 222 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1) 223 | 224 | def _make_finite(self, scores): 225 | ok = scores.isfinite() 226 | if not ok.all(): 227 | # NaNs here can break the assignment algorithm 228 | scores[~ok] = scores[ok].min() 229 | return scores 230 | -------------------------------------------------------------------------------- /finetune_trainer.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import tqdm 6 | import pdb 7 | import torch.nn as nn 8 | from custom_gates import * 9 | 10 | 11 | def _train_step(model, load_balance, X, Y, h_cache, eval_only, loss_div=1): 12 | """Single training step.""" 13 | acc_num, acc_value = 0, 0.0 14 | out, h_cache = model(X, h_cache) 15 | # print(model.module.layers[0].attn.proj_key.weight) 16 | out = out.view(-1, out.size(-1)) 17 | # loss = torch.nn.functional.nll_loss(out, Y.view(-1)) 18 | criterion = nn.CrossEntropyLoss() 19 | # pdb.set_trace() 20 | # print(out, Y) 21 | loss = criterion(out, Y) 22 | loss_value = loss.item() / loss_div 23 | # loss = loss.float() 24 | # loss = loss.item() / loss_div 25 | acc_value += (out.argmax(-1) == Y).sum().item() 26 | acc_num += Y.shape[0] 27 | 28 | if not eval_only: 29 | # loss term from adaptive-span 30 | if model.module.layers[0].attn.attn.adapt_span_enabled: 31 | loss += sum( 32 | model.module.layers[layer_i].attn.attn.adaptive_span.get_loss() 33 | for layer_i in range(model.module.attn_layer_count) 34 | ) 35 | 36 | if load_balance > 0: 37 | balance_loss = 0 38 | for name, m in model.named_modules(): 39 | if isinstance(m, CustomNaiveGate_Balance_SMoE) or isinstance( 40 | m, CustomNaiveGate_Balance_XMoE 41 | ): 42 | balance_loss += m.loss 43 | loss += load_balance * balance_loss 44 | 45 | (loss / loss_div).backward() 46 | # print(torch.norm(model.module.layers[0].attn.proj_key.weight.grad)) 47 | return loss_value, acc_num, acc_value, h_cache 48 | 49 | 50 | def _train_batch( 51 | model, load_balance, optimizer, scheduler, X, Y, h_cache, eval_only, batch_split 52 | ): 53 | """Train on a batch.""" 54 | 55 | optimizer.zero_grad() 56 | total_len, total_acc = 0, 0.0 57 | if batch_split == 1: 58 | # process a batch in a single step (default behaviour) 59 | loss_value, total_len, total_acc, h_cache = _train_step( 60 | model, load_balance, X, Y, h_cache, eval_only 61 | ) 62 | else: 63 | # split a batch into multiple pieces that each can fit in memory 64 | assert X.size(0) % batch_split == 0 65 | split_size = X.size(0) // batch_split 66 | loss_value = 0 67 | h_cache_list = [] 68 | for split_ind in range(batch_split): 69 | split_slice = slice(split_ind * split_size, (split_ind + 1) * split_size) 70 | split_h_cache = [h[split_slice, :, :] for h in h_cache] 71 | tmp_len, tmp_acc = 0, 0.0 72 | # pdb.set_trace() 73 | split_loss_value, tmp_len, tmp_acc, split_h_cache = _train_step( 74 | model=model, 75 | 
load_balance=load_balance, 76 | X=X[split_slice, :], 77 | Y=Y[split_slice], 78 | h_cache=split_h_cache, 79 | eval_only=eval_only, 80 | loss_div=batch_split, 81 | ) 82 | loss_value += split_loss_value 83 | total_len += tmp_len 84 | total_acc += tmp_acc 85 | h_cache_list.append(split_h_cache) 86 | h_cache = [ 87 | torch.cat([h_cache_list[i][l] for i in range(batch_split)], dim=0) 88 | for l in range(len(h_cache)) 89 | ] 90 | if not eval_only: 91 | if scheduler is not None: 92 | scheduler.step() 93 | optimizer.step() 94 | 95 | # make sure span parameters are in a correct range 96 | if model.module.layers[0].attn.attn.adapt_span_enabled: 97 | for layer in model.module.layers: 98 | if layer.use_attn: 99 | layer.attn.attn.adaptive_span.clamp_param() 100 | return loss_value, total_len, total_acc, h_cache 101 | 102 | 103 | def train_iteration( 104 | model, 105 | load_balance, 106 | optimizer, 107 | scheduler, 108 | data, 109 | nb_batches_per_iter, 110 | block_size, 111 | eval_only, 112 | train_pos, 113 | h_cache, 114 | batch_split, 115 | checkpoint_path, 116 | ): 117 | """Single training iteration.""" 118 | if eval_only: 119 | model.eval() 120 | else: 121 | model.train() 122 | 123 | # nb_batches_per_iter_max = nb_batches_per_iter 124 | # if eval_only: 125 | # # eval on fewer batches during training for speed-up 126 | # nb_batches_per_iter_max = max(1, nb_batches_per_iter // 10) 127 | # for _temp, _, _ in data: 128 | # ch_block = _temp.size(1) # data.size(1) 129 | # break 130 | # nb_batches_per_iter_max = min( 131 | # nb_batches_per_iter_max, math.ceil(ch_block / block_size) 132 | # ) 133 | nb_batches_per_iter_max = nb_batches_per_iter 134 | if eval_only: 135 | # eval on fewer batches during training for speed-up 136 | nb_batches_per_iter_max = max(1, nb_batches_per_iter // 10) 137 | nb_batches_per_iter_max = min( 138 | nb_batches_per_iter_max, math.ceil(data.n_step / block_size) 139 | ) 140 | 141 | loss_all = 0 142 | # all accuracy and number 143 | total_len, total_acc = 0, 0.0 144 | actual_nb_batches_per_iter = 0 145 | for _data, _att_mask, _target in tqdm.tqdm(data): 146 | actual_nb_batches_per_iter += 1 147 | _data = _data.cuda() 148 | _target = _target.cuda() 149 | X = _data.permute( 150 | 1, 0 151 | ).contiguous() # data[:, train_pos: train_pos + block_size].contiguous() 152 | Y = _target # .permute(1,0) #.contiguous() #data[:, train_pos + 1: train_pos + block_size + 1].contiguous() 153 | # print(Y) 154 | 155 | loss, tmp_len, tmp_acc, h_cache = _train_batch( 156 | model=model, 157 | load_balance=load_balance, 158 | optimizer=optimizer, 159 | scheduler=scheduler, 160 | X=X, 161 | Y=Y, 162 | h_cache=h_cache, 163 | eval_only=eval_only, 164 | batch_split=batch_split, 165 | ) 166 | # print(tmp_acc) 167 | loss_all += loss 168 | total_len += tmp_len 169 | total_acc += tmp_acc 170 | train_pos += block_size 171 | if train_pos >= _data.size(1) - block_size: 172 | # reached the end. 
randomize the offset to reduce overfitting 173 | train_pos = random.randrange(block_size) 174 | # reset the cache 175 | for h in h_cache: 176 | h.fill_(0) 177 | 178 | loss_all = loss_all / actual_nb_batches_per_iter 179 | acc_all = 100 * total_acc / total_len 180 | return loss_all, acc_all, train_pos, h_cache 181 | 182 | 183 | # do full evaluation 184 | def full_eval(model, optimizer, scheduler, data, block_size, hidden_size): 185 | model.eval() 186 | train_pos = 0 187 | # nb_batches_per_iter_max = math.ceil(data.encoded.size(1) / block_size) 188 | h_cache = [ 189 | torch.zeros( 190 | data.bsz, 191 | model.module.layers[layer_i].attn.attn.get_cache_size(), 192 | hidden_size, 193 | ).cuda() 194 | for layer_i in range(model.module.attn_layer_count) 195 | ] 196 | 197 | loss_all = 0 198 | actual_nb_batches_per_iter = 0 199 | total_len, total_acc = 0, 0.0 200 | 201 | for _data, _att_mask, _target in tqdm.tqdm(data): 202 | actual_nb_batches_per_iter += 1 203 | _data = _data.cuda() 204 | _target = _target.cuda() 205 | X = _data.permute( 206 | 1, 0 207 | ).contiguous() # data[:, train_pos: train_pos + block_size].contiguous() 208 | Y = _target # .permute(1,0) #.contiguous() #data[:, train_pos + 1: train_pos + block_size + 1].contiguous() 209 | # pdb.set_trace() 210 | # for _ in tqdm.tqdm(range(nb_batches_per_iter_max)): 211 | # actual_nb_batches_per_iter += 1 212 | # X = data[:, train_pos: train_pos + block_size].contiguous() 213 | # Y = data[:, train_pos + 1: train_pos + block_size + 1].contiguous() 214 | 215 | loss, tmp_len, tmp_acc, h_cache = _train_batch( 216 | model=model, 217 | load_balance=0, 218 | optimizer=optimizer, 219 | scheduler=scheduler, 220 | X=X, 221 | Y=Y, 222 | h_cache=h_cache, 223 | eval_only=True, 224 | batch_split=1, 225 | ) 226 | loss_all += loss 227 | total_len += tmp_len 228 | total_acc += tmp_acc 229 | train_pos += block_size 230 | if train_pos >= _data.size(1) - block_size: 231 | # Skip the remaining tokens as it can't make a whole block. 232 | # An effect on performance should be negligable for a large data. 
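# The reported loss/accuracy are averaged only over the batches actually
# processed (actual_nb_batches_per_iter and total_len), so dropping the final
# partial block only shrinks the evaluated sample rather than biasing the mean.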
233 | break 234 | 235 | loss_all = loss_all / actual_nb_batches_per_iter 236 | acc_all = 100 * total_acc / total_len 237 | return loss_all, acc_all 238 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import functools 5 | import os, shutil 6 | import torch 7 | import tqdm 8 | from models import CustomizedMoEPositionwiseFFOpt 9 | 10 | 11 | def logging(s, log_path, print_=True, log_=True): 12 | if print_: 13 | print(s) 14 | if log_: 15 | with open(log_path, "a+") as f_log: 16 | f_log.write(s + "\n") 17 | 18 | 19 | def get_logger(log_path, **kwargs): 20 | return functools.partial(logging, log_path=log_path, **kwargs) 21 | 22 | 23 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 24 | if debug: 25 | print("Debug Mode : no experiment dir created") 26 | return functools.partial(logging, log_path=None, log_=False) 27 | 28 | if not os.path.exists(dir_path): 29 | os.makedirs(dir_path) 30 | 31 | print("Experiment dir : {}".format(dir_path)) 32 | if scripts_to_save is not None: 33 | script_path = os.path.join(dir_path, "scripts") 34 | if not os.path.exists(script_path): 35 | os.makedirs(script_path) 36 | for script in scripts_to_save: 37 | dst_file = os.path.join(dir_path, "scripts", os.path.basename(script)) 38 | shutil.copyfile(script, dst_file) 39 | 40 | return get_logger(log_path=os.path.join(dir_path, "log.txt")) 41 | 42 | 43 | def freeze_gate_weight(model): 44 | print("* Freeze Router") 45 | for name, p in model.named_parameters(): 46 | if "gate.gate" in name: 47 | print("Freeze: ", name) 48 | p.requires_grad = False 49 | 50 | 51 | def set_freq_optimal_search(model, threshold): 52 | print(f"* Set Freq Optimal Search: ") 53 | for name, m in model.named_modules(): 54 | if isinstance(m, CustomizedMoEPositionwiseFFOpt): 55 | if random.random() > (1 - threshold): 56 | print(f"* Set Freq of {name} to 1.0") 57 | m.freq = 1.0 58 | else: 59 | print(f"* Set Freq of {name} to 0.0") 60 | m.freq = 0.0 61 | 62 | 63 | def _parse_args(params_config, args): 64 | parser = argparse.ArgumentParser() 65 | for params_category in params_config: # e.g., 'model_params' 66 | for param_flag, param_config in params_config[params_category].items(): 67 | # e.g., param_flag = '--block-sz' 68 | parser.add_argument(param_flag, **param_config) 69 | return parser.parse_args(args) 70 | 71 | 72 | def get_params(params_config, args=None): 73 | namespace = _parse_args(params_config, args) 74 | return { 75 | params_category: { 76 | param_config["dest"]: namespace.__getattribute__(param_config["dest"]) 77 | for param_config in params_config[params_category].values() 78 | } 79 | for params_category in params_config 80 | } 81 | 82 | 83 | ############################################################################## 84 | # ENVIRONMENT 85 | ############################################################################## 86 | 87 | 88 | def _torch_distributed_init_process_group(local_rank): 89 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 90 | rank = torch.distributed.get_rank() 91 | world_size = torch.distributed.get_world_size() 92 | print("my rank={} local_rank={}".format(rank, local_rank)) 93 | torch.cuda.set_device(local_rank) 94 | return { 95 | "rank": rank, 96 | "world_size": world_size, 97 | } 98 | 99 | 100 | def set_up_env(env_params): 101 | assert torch.cuda.is_available() 102 | if 
env_params["distributed"]: 103 | env_params.update( 104 | _torch_distributed_init_process_group(local_rank=env_params["local_rank"]) 105 | ) 106 | env_params["device"] = torch.device("cuda") 107 | 108 | 109 | ############################################################################## 110 | # OPTIMIZER AND SCHEDULER 111 | ############################################################################## 112 | 113 | 114 | def _get_grad_requiring_params(model): 115 | nb_parameters = 0 116 | grad_requiring_params = [] 117 | for param in model.parameters(): 118 | if param.requires_grad: 119 | nb_parameters += param.numel() 120 | grad_requiring_params.append(param) 121 | print("nb_parameters={:.2f}M".format(nb_parameters / 1e6)) 122 | return grad_requiring_params 123 | 124 | 125 | def _get_optimizer(model, optim, lr: float, momentum: float, grad_clip: float): 126 | if optim == "sgd": 127 | return torch.optim.SGD( 128 | _get_grad_requiring_params(model), lr=lr, momentum=momentum 129 | ) 130 | elif optim == "adam": 131 | return torch.optim.Adam( 132 | _get_grad_requiring_params(model), 133 | lr=lr, 134 | ) 135 | else: 136 | raise RuntimeError("wrong type of optimizer - must be 'sgd' or 'adam'") 137 | 138 | 139 | def _get_scheduler(optimizer, lr_warmup): 140 | if lr_warmup > 0: 141 | return torch.optim.lr_scheduler.LambdaLR( 142 | optimizer, lambda ep: min(1, ep / lr_warmup) 143 | ) 144 | return None 145 | 146 | 147 | def get_optimizer_and_scheduler(model, optim_params): 148 | optimizer = _get_optimizer( 149 | model=model, 150 | optim=optim_params["optim"], 151 | lr=optim_params["lr"], 152 | momentum=optim_params["momentum"], 153 | grad_clip=optim_params["grad_clip"], 154 | ) 155 | scheduler = _get_scheduler(optimizer=optimizer, lr_warmup=optim_params["lr_warmup"]) 156 | return optimizer, scheduler 157 | 158 | 159 | ############################################################################## 160 | # CHECKPOINT 161 | ############################################################################## 162 | 163 | 164 | def _load_checkpoint(checkpoint_path, model, optimizer, scheduler, logger, distributed): 165 | print("loading from a checkpoint at {}".format(checkpoint_path)) 166 | if distributed: 167 | # the model is saved from gpu0 so we need to map it to CPU first 168 | checkpoint_state = torch.load( 169 | checkpoint_path, map_location=lambda storage, loc: storage 170 | ) 171 | else: 172 | checkpoint_state = torch.load(checkpoint_path) 173 | iter_init = checkpoint_state["nb_batches_per_iter"] + 1 # next iteration 174 | # del checkpoint_state["model"]["module.in_emb.weight"] 175 | # del checkpoint_state["model"]["module.out_emb.weight"] 176 | # del checkpoint_state["model"]["module.out_emb.bias"] 177 | model.load_state_dict(checkpoint_state["model"]) 178 | optimizer.load_state_dict(checkpoint_state["optimizer"]) 179 | if "scheduler_iter" in checkpoint_state: 180 | # we only need the step count 181 | scheduler.step(checkpoint_state["scheduler_iter"]) 182 | return iter_init 183 | 184 | 185 | def load_checkpoint(checkpoint_path, model, optimizer, scheduler, logger, distributed, resume): 186 | print(checkpoint_path) 187 | if resume and os.path.exists(checkpoint_path): 188 | return _load_checkpoint( 189 | checkpoint_path=checkpoint_path, 190 | model=model, 191 | optimizer=optimizer, 192 | scheduler=scheduler, 193 | logger=logger, 194 | distributed=distributed, 195 | ) 196 | return 0 197 | 198 | 199 | def save_checkpoint( 200 | checkpoint_path, nb_batches_per_iter, model, optimizer, scheduler, logger 201 | ): 
202 | if checkpoint_path: 203 | checkpoint_state = { 204 | "nb_batches_per_iter": nb_batches_per_iter, # last completed iteration 205 | "model": model.state_dict(), 206 | "optimizer": optimizer.state_dict(), 207 | } 208 | if scheduler is not None: 209 | checkpoint_state["scheduler_iter"] = scheduler.last_epoch 210 | torch.save(checkpoint_state, checkpoint_path) 211 | 212 | 213 | ############################################################################## 214 | # LOGGER 215 | ############################################################################## 216 | 217 | 218 | class Logger: 219 | def __init__(self): 220 | self._state_dict = dict() 221 | 222 | def load_state_dict(self, state_dict): 223 | self._state_dict = state_dict 224 | 225 | def state_dict(self): 226 | return self._state_dict 227 | 228 | def _log(self, title, value): 229 | if title not in self._state_dict: 230 | self._state_dict[title] = [] 231 | self._state_dict[title].append(value) 232 | 233 | def log_iter( 234 | self, iter_no, nb_batches_per_iter, loss_train, loss_val, elapsed, model 235 | ): 236 | step = (iter_no + 1) * nb_batches_per_iter 237 | train_bpc = float(loss_train / math.log(2)) 238 | val_bpc = float(loss_val / math.log(2)) 239 | msg = "steps: {}".format(step) 240 | msg += "\ttrain: {:.3f}bpc\tval: {:.3f}bpc".format(train_bpc, val_bpc) 241 | msg += "\tms/batch: {:.1f}".format(elapsed) 242 | self._log(title="step", value=step) 243 | self._log(title="train_bpc", value=train_bpc) 244 | self._log(title="val_bpc", value=val_bpc) 245 | 246 | if model.module.layers[0].attn.attn.adapt_span_enabled: 247 | avg_spans = [] 248 | max_spans = [] 249 | for layer in model.module.layers: 250 | if layer.use_attn: 251 | avg_spans.append( 252 | layer.attn.attn.adaptive_span.get_current_avg_span() 253 | ) 254 | max_spans.append( 255 | layer.attn.attn.adaptive_span.get_current_max_span() 256 | ) 257 | span_avg = float(sum(avg_spans)) / len(avg_spans) 258 | span_max = float(max(max_spans)) 259 | self._log("span_avg", span_avg) 260 | self._log("span_max", span_max) 261 | msg += "\tspan_avg: {:.0f}\tspan_max: {:.0f}".format(span_avg, span_max) 262 | 263 | print(msg) 264 | -------------------------------------------------------------------------------- /finetune_train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | import argparse 7 | import math, random 8 | import torch 9 | import time 10 | import pdb 11 | from config import PARAMS_CONFIG 12 | 13 | from finetune_data import get_lm_corpus 14 | from finetune_models import TransformerSeq 15 | from finetune_trainer import train_iteration, full_eval 16 | import datetime 17 | from utils import ( 18 | get_params, 19 | set_up_env, 20 | get_optimizer_and_scheduler, 21 | load_checkpoint, 22 | save_checkpoint, 23 | create_exp_dir, 24 | freeze_gate_weight, 25 | Logger, 26 | ) 27 | 28 | 29 | def launch( 30 | env_params, 31 | model_params, 32 | adapt_span_params, 33 | optim_params, 34 | data_params, 35 | trainer_params, 36 | ): 37 | # global val 38 | best_val_loss = None 39 | # ENVIRONMENT (device, distributed, etc.) 
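# set_up_env (defined in utils.py above) asserts CUDA is available and, when
# --distributed is set, initializes the NCCL process group from environment
# variables (init_method="env://"), records rank/world_size and pins this
# process to its local GPU via torch.cuda.set_device(local_rank); without
# --distributed the model is later wrapped in torch.nn.DataParallel across all
# visible GPUs instead.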
40 | set_up_env(env_params) 41 | device = env_params["device"] 42 | distributed = env_params["distributed"] 43 | 44 | if distributed == False or env_params["rank"] == 0: 45 | print("data_params:\t", data_params) 46 | print("model_params:\t", model_params) 47 | print("optim_params:\t", optim_params) 48 | print("trainer_params:\t", trainer_params) 49 | print("adapt_span_params:\t", adapt_span_params) 50 | 51 | # DATA 52 | corpus = get_lm_corpus(data_params["data_path"], data_params["data_name"]) 53 | ntokens = len(corpus.vocab) 54 | 55 | if data_params["data_name"] in ["sst2", "imdb"]: 56 | num_classes = 2 57 | elif data_params["data_name"] == "sst5": 58 | num_classes = 5 59 | elif data_params["data_name"] == "banking77": 60 | num_classes = 77 61 | 62 | eval_batch_size = 10 63 | train_data = corpus.get_iterator("train", trainer_params["batch_size"]) 64 | val_data = corpus.get_iterator("valid", eval_batch_size) 65 | test_data = val_data # corpus.get_iterator('test', eval_batch_size) 66 | 67 | # MODEL data_params['vocab_size'] 68 | model = TransformerSeq( 69 | vocab_size=ntokens, 70 | **model_params, 71 | num_classes=num_classes, 72 | adapt_span_params=adapt_span_params, 73 | ) 74 | print(model) 75 | if distributed: 76 | local_rank = env_params["local_rank"] 77 | model = model.to(device) 78 | model = torch.nn.parallel.DistributedDataParallel( 79 | model, 80 | device_ids=[local_rank], 81 | output_device=local_rank, 82 | find_unused_parameters=True, 83 | ) 84 | else: 85 | model = torch.nn.DataParallel(model) 86 | model = model.to(device) 87 | 88 | # OPTIMIZER AND SCHEDULER 89 | optimizer, scheduler = get_optimizer_and_scheduler( 90 | model=model, optim_params=optim_params 91 | ) 92 | 93 | # create logger 94 | logger = Logger() 95 | fold_name = trainer_params["checkpoint_path"].split("/")[-1].split(".")[0] 96 | folder_path = "/".join(trainer_params["checkpoint_path"].split("/")[:-1]) 97 | logging = create_exp_dir(f"{folder_path}/experiments/{fold_name}") 98 | # log paramters 99 | logging(f"Training Parameters:\n {trainer_params}") 100 | # logging time 101 | current_time = datetime.datetime.now() 102 | logging(str(current_time)) 103 | # log model 104 | logging(str(model)) 105 | logging(f"Total of Parameters: {sum(p.numel() for p in model.parameters())}") 106 | logging( 107 | f"Total of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}" 108 | ) 109 | # load check points 110 | 111 | logging("=" * 100) 112 | logging( 113 | "==== loading pretrained model from {} ====".format( 114 | trainer_params["pretrained_weight"] 115 | ) 116 | ) 117 | logging("=" * 100) 118 | 119 | # Load the best saved model. 
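# In fine-tuning mode the block below loads the checkpoint given by
# --pretrained_weight and keeps only parameters whose names exist in the
# current model and whose shapes match (mismatches, e.g. a new classification
# head, are logged and skipped), then calls load_state_dict(..., strict=False).
# In --full-eval-mode the fine-tuned checkpoint at --checkpoint is resumed via
# load_checkpoint instead; note that this call appears to omit the trailing
# `resume` argument that utils.load_checkpoint (above) requires.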
120 | if not trainer_params["full_eval_mode"]: 121 | with open(trainer_params["pretrained_weight"], "rb") as f: 122 | pretrained_model = torch.load(f) 123 | # pdb.set_trace() 124 | pretrained_model_checkpoint = pretrained_model["model"] # .state_dict() 125 | filtered_checkpoint = {} 126 | for key in pretrained_model_checkpoint.keys(): 127 | if not key in model.state_dict(): 128 | logging("Can not load {}".format(key)) 129 | elif ( 130 | not pretrained_model_checkpoint[key].shape 131 | == model.state_dict()[key].shape 132 | ): 133 | logging("Can not load {}, shape do not match".format(key)) 134 | else: 135 | filtered_checkpoint[key] = pretrained_model_checkpoint[key] 136 | 137 | model.load_state_dict(filtered_checkpoint, strict=False) 138 | iter_init = 0 139 | else: 140 | # resume training from last checkpoint if exists 141 | iter_init = load_checkpoint( 142 | trainer_params["checkpoint_path"], 143 | model, 144 | optimizer, 145 | scheduler, 146 | logger, 147 | distributed, 148 | ) 149 | 150 | # fix gate 151 | if model_params["smoe_dropout"]: 152 | freeze_gate_weight(model) 153 | # calculate time 154 | start_time = time.time() 155 | # eval model 156 | if trainer_params["full_eval_mode"]: 157 | # evaluate the model on test data 158 | with torch.no_grad(): 159 | loss_val, acc_val = full_eval( 160 | model, 161 | optimizer, 162 | scheduler, 163 | val_data, 164 | model_params["block_size"], 165 | model_params["hidden_size"], 166 | ) 167 | loss_test, acc_test = full_eval( 168 | model, 169 | optimizer, 170 | scheduler, 171 | test_data, 172 | model_params["block_size"], 173 | model_params["hidden_size"], 174 | ) 175 | if distributed: 176 | # collect results into rank0 177 | stats = torch.tensor([loss_val, loss_test]).to(device) 178 | torch.distributed.reduce(stats, 0) 179 | if env_params["rank"] == 0: 180 | loss_val = stats[0] / env_params["world_size"] 181 | loss_test = stats[1] / env_params["world_size"] 182 | else: 183 | return 184 | 185 | # log accuracy score 186 | logging("Val: {:.3f} Acc".format(acc_val)) 187 | logging("Test: {:.3f} Acc".format(acc_test)) 188 | 189 | return 190 | 191 | # position of current batch 192 | data_pos = [0] * 2 193 | # initialize caches for train and valid 194 | hid_cache = [ 195 | [ 196 | torch.zeros( 197 | train_data.bsz, 198 | model.module.layers[layer_i].attn.attn.get_cache_size(), 199 | model_params["hidden_size"], 200 | ).to(device) 201 | for layer_i in range(model.module.attn_layer_count) 202 | ] 203 | for _ in range(2) 204 | ] 205 | 206 | nb_batches_per_iter = trainer_params["nb_batches_per_iter"] 207 | for iter_no in range(iter_init, trainer_params["nb_iter"]): 208 | t_sta = time.time() 209 | loss_train, acc_train, data_pos[0], hid_cache[0] = train_iteration( 210 | model, 211 | model_params["load_balance"], 212 | optimizer, 213 | scheduler, 214 | train_data, 215 | nb_batches_per_iter, 216 | model_params["block_size"], 217 | False, 218 | data_pos[0], 219 | hid_cache[0], 220 | trainer_params["batch_split"], 221 | trainer_params["checkpoint_path"], 222 | ) 223 | elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter 224 | with torch.no_grad(): 225 | loss_val, acc_val, data_pos[1], hid_cache[1] = train_iteration( 226 | model, 227 | model_params["load_balance"], 228 | optimizer, 229 | scheduler, 230 | val_data, 231 | nb_batches_per_iter, 232 | model_params["block_size"], 233 | True, 234 | data_pos[1], 235 | hid_cache[1], 236 | trainer_params["batch_split"], 237 | trainer_params["checkpoint_path"], 238 | ) 239 | 240 | if distributed: 241 | # collect results 
into rank0 242 | stats = torch.tensor([loss_train, loss_val]).to(device) 243 | torch.distributed.reduce(stats, 0) 244 | if env_params["rank"] == 0: 245 | loss_train = stats[0] / env_params["world_size"] 246 | loss_val = stats[1] / env_params["world_size"] 247 | else: 248 | continue 249 | logging(f"=================== EPOCHS {iter_no} ======================") 250 | # if ('enwik8' in data_params['data_path']) or ('text8' in data_params['data_path']): 251 | msg_result = "Epochs: {} | loss_train: {:.3f} ~ {:.3f} Acc | loss_val: {:.3f} ~ {:.3f} Acc | elapsed: {:.1f}".format( 252 | iter_no, loss_train, acc_train, loss_val, acc_val, elapsed 253 | ) 254 | logging(msg_result) 255 | # Save the model if the validation loss is the best we've seen so far. 256 | if (best_val_loss is None) or loss_val < best_val_loss: 257 | best_val_loss = loss_val 258 | save_checkpoint( 259 | trainer_params["checkpoint_path"], 260 | iter_no, 261 | model, 262 | optimizer, 263 | scheduler, 264 | logger, 265 | ) 266 | # save_checkpoint(trainer_params['checkpoint_path'], nb_batches_per_iter, model, optimizer, scheduler, logger) 267 | end_time = time.time() 268 | logging(f"Training time total: {(end_time - start_time)/3600} h") 269 | 270 | 271 | if __name__ == "__main__": 272 | launch(**get_params(params_config=PARAMS_CONFIG)) 273 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | import argparse 7 | import math, random 8 | import torch 9 | import time 10 | 11 | from config import PARAMS_CONFIG 12 | from data import get_train_val_test_data 13 | from models import TransformerSeq 14 | from trainer import train_iteration, full_eval 15 | import datetime 16 | import wandb 17 | import os 18 | from utils import ( 19 | get_params, 20 | set_up_env, 21 | get_optimizer_and_scheduler, 22 | load_checkpoint, 23 | save_checkpoint, 24 | create_exp_dir, 25 | freeze_gate_weight, 26 | Logger, 27 | set_freq_optimal_search, 28 | ) 29 | 30 | 31 | def launch( 32 | env_params, 33 | model_params, 34 | adapt_span_params, 35 | optim_params, 36 | data_params, 37 | trainer_params, 38 | wandb_params, 39 | ): 40 | wandb_flag = wandb_params["wandb_flag"] 41 | if wandb_flag: 42 | wandb.init(project=wandb_params["project_name"]) 43 | wandb.run.name = wandb_params["job_name"] 44 | wandb.config.update(model_params) 45 | # global val 46 | best_val_loss = None 47 | # ENVIRONMENT (device, distributed, etc.) 
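# Note on launching: config.py builds the --local_rank default from
# os.environ['LOCAL_RANK'] at import time, so the script expects a launcher
# that sets LOCAL_RANK for every process, e.g. (illustrative command only,
# paths are placeholders):
#   torchrun --nproc_per_node=4 train.py --distributed \
#     --data /path/to/data --checkpoint /path/to/checkpoint.pt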
48 | set_up_env(env_params) 49 | device = env_params["device"] 50 | distributed = env_params["distributed"] 51 | resume = trainer_params["resume"] 52 | 53 | if distributed == False or env_params["rank"] == 0: 54 | print("data_params:\t", data_params) 55 | print("model_params:\t", model_params) 56 | print("optim_params:\t", optim_params) 57 | print("trainer_params:\t", trainer_params) 58 | print("adapt_span_params:\t", adapt_span_params) 59 | 60 | # DATA 61 | train_data, val_data, test_data = get_train_val_test_data( 62 | data_params=data_params, 63 | env_params=env_params, 64 | batch_size=trainer_params["batch_size"], 65 | device=device, 66 | ) 67 | 68 | # MODEL 69 | model = TransformerSeq( 70 | vocab_size=data_params["vocab_size"], 71 | **model_params, 72 | adapt_span_params=adapt_span_params, 73 | ) 74 | print(model) 75 | if distributed: 76 | local_rank = env_params["local_rank"] 77 | model = model.to(device) 78 | model = torch.nn.parallel.DistributedDataParallel( 79 | model, 80 | device_ids=[local_rank], 81 | output_device=local_rank, 82 | find_unused_parameters=True, 83 | ) 84 | else: 85 | model = torch.nn.DataParallel(model) 86 | model = model.to(device) 87 | 88 | # OPTIMIZER AND SCHEDULER 89 | optimizer, scheduler = get_optimizer_and_scheduler( 90 | model=model, optim_params=optim_params 91 | ) 92 | 93 | # create logger 94 | logger = Logger() 95 | fold_name = trainer_params["checkpoint_path"].split("/")[-1].split(".")[0] 96 | folder_path = "/".join(trainer_params["checkpoint_path"].split("/")[:-1]) 97 | logging = create_exp_dir(f"{folder_path}/experiments/{fold_name}") 98 | # log paramters 99 | logging(f"Training Parameters:\n {trainer_params}") 100 | logging(f"Models Parameters:\n {model_params}") 101 | # logging time 102 | current_time = datetime.datetime.now() 103 | logging(str(current_time)) 104 | # log model 105 | logging(str(model)) 106 | logging(f"Total of Parameters: {sum(p.numel() for p in model.parameters())}") 107 | logging( 108 | f"Total of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}" 109 | ) 110 | # resume training from last checkpoint if exists 111 | iter_init = load_checkpoint( 112 | trainer_params["checkpoint_path"], 113 | model, 114 | optimizer, 115 | scheduler, 116 | logger, 117 | distributed, 118 | resume, 119 | ) 120 | # fix gate 121 | if model_params["smoe_dropout"]: 122 | freeze_gate_weight(model) 123 | # calculate time 124 | start_time = time.time() 125 | # eval model 126 | if trainer_params["full_eval_mode"]: 127 | # evaluate the model on test data 128 | with torch.no_grad(): 129 | loss_val = full_eval( 130 | model, 131 | optimizer, 132 | scheduler, 133 | val_data, 134 | model_params["block_size"], 135 | model_params["hidden_size"], 136 | ) 137 | loss_test = full_eval( 138 | model, 139 | optimizer, 140 | scheduler, 141 | test_data, 142 | model_params["block_size"], 143 | model_params["hidden_size"], 144 | ) 145 | if distributed: 146 | # collect results into rank0 147 | stats = torch.tensor([loss_val, loss_test]).to(device) 148 | torch.distributed.reduce(stats, 0) 149 | if env_params["rank"] == 0: 150 | loss_val = stats[0] / env_params["world_size"] 151 | loss_test = stats[1] / env_params["world_size"] 152 | else: 153 | return 154 | 155 | # print('Test BPC: {:.4f}'.format(loss_test / math.log(2))) 156 | if ("enwik8" in data_params["data_path"]) or ( 157 | "text8" in data_params["data_path"] 158 | ): 159 | logging("Val: {:.3f} BPC".format(loss_val / math.log(2))) 160 | logging("Test: {:.3f} BPC".format(loss_test / 
math.log(2))) 161 | else: 162 | logging("Val: {:.3f} PPL".format(math.exp(loss_val))) 163 | logging("Test: {:.3f} PPL".format(math.exp(loss_test))) 164 | return 165 | 166 | # position of current batch 167 | data_pos = [0] * 2 168 | # initialize caches for train and valid 169 | hid_cache = [ 170 | [ 171 | torch.zeros( 172 | train_data.size(0), 173 | model.module.layers[layer_i].attn.attn.get_cache_size(), 174 | model_params["hidden_size"], 175 | ).to(device) 176 | for layer_i in range(model.module.attn_layer_count) 177 | ] 178 | for _ in range(2) 179 | ] 180 | 181 | nb_batches_per_iter = trainer_params["nb_batches_per_iter"] 182 | for iter_no in range(iter_init, trainer_params["nb_iter"]): 183 | # freq type 184 | if model_params["freq_type"] == "function": 185 | _threshold = 2.0 / (2.0 + math.sqrt((iter_no + 1))) 186 | set_freq_optimal_search(model, _threshold) 187 | 188 | # time storing 189 | t_sta = time.time() 190 | loss_train, data_pos[0], hid_cache[0] = train_iteration( 191 | model, 192 | model_params["load_balance"], 193 | optimizer, 194 | scheduler, 195 | train_data, 196 | nb_batches_per_iter, 197 | model_params["block_size"], 198 | False, 199 | data_pos[0], 200 | hid_cache[0], 201 | trainer_params["batch_split"], 202 | trainer_params["checkpoint_path"], 203 | ) 204 | elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter 205 | with torch.no_grad(): 206 | loss_val, data_pos[1], hid_cache[1] = train_iteration( 207 | model, 208 | model_params["load_balance"], 209 | optimizer, 210 | scheduler, 211 | val_data, 212 | nb_batches_per_iter, 213 | model_params["block_size"], 214 | True, 215 | data_pos[1], 216 | hid_cache[1], 217 | trainer_params["batch_split"], 218 | trainer_params["checkpoint_path"], 219 | ) 220 | 221 | if distributed: 222 | # collect results into rank0 223 | stats = torch.tensor([loss_train, loss_val]).to(device) 224 | torch.distributed.reduce(stats, 0) 225 | if env_params["rank"] == 0: 226 | loss_train = stats[0] / env_params["world_size"] 227 | loss_val = stats[1] / env_params["world_size"] 228 | else: 229 | continue 230 | logging(f"=================== EPOCHS {iter_no} ======================") 231 | if ("enwik8" in data_params["data_path"]) or ( 232 | "text8" in data_params["data_path"] 233 | ): 234 | msg_result = "Epochs: {} | loss_train: {:.3f} ~ {:.3f} BPC | loss_val: {:.3f} ~ {:.3f} BPC | elapsed: {:.1f}".format( 235 | iter_no, 236 | loss_train, 237 | float(loss_train / math.log(2)), 238 | loss_val, 239 | float(loss_val / math.log(2)), 240 | elapsed, 241 | ) 242 | else: 243 | msg_result = "Epochs: {} | loss_train: {:.3f} ~ {:.3f} PPL | loss_val: {:.3f} ~ {:.3f} PPL | elapsed: {:.1f}".format( 244 | iter_no, 245 | loss_train, 246 | float(math.exp(loss_train)), 247 | loss_val, 248 | float(math.exp(loss_val)), 249 | elapsed, 250 | ) 251 | logging(msg_result) 252 | if wandb_flag: 253 | wandb.log({'train_ppl':float(math.exp(loss_train)),'Epoch':iter_no,'valid_ppl':float(math.exp(loss_val))}) 254 | logger.log_iter(iter_no, nb_batches_per_iter, loss_train, loss_val, elapsed, model) 255 | # Save the model if the validation loss is the best we've seen so far. 
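# save_checkpoint (utils.py) writes the completed iteration index
# ("nb_batches_per_iter"), the model and optimizer state dicts and, when a
# scheduler is used, its step count ("scheduler_iter") to --checkpoint, so only
# the best-validation weights are kept there. Because the model is saved while
# wrapped in DDP/DataParallel, parameter keys carry a "module." prefix.
# A minimal sketch for inspecting such a checkpoint offline (illustrative only;
# the path is a placeholder):
#   state = torch.load("/path/to/checkpoint.pt", map_location="cpu")
#   print(state["nb_batches_per_iter"], list(state["model"].keys())[:5])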
256 | if (best_val_loss is None) or loss_val < best_val_loss: 257 | best_val_loss = loss_val 258 | save_checkpoint( 259 | trainer_params["checkpoint_path"], 260 | iter_no, 261 | model, 262 | optimizer, 263 | scheduler, 264 | logger, 265 | ) 266 | # save_checkpoint(trainer_params['checkpoint_path'], nb_batches_per_iter, model, optimizer, scheduler, logger) 267 | end_time = time.time() 268 | logging(f"Training time total: {(end_time - start_time)/3600} h") 269 | 270 | 271 | if __name__ == "__main__": 272 | launch(**get_params(params_config=PARAMS_CONFIG)) 273 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import tqdm 6 | 7 | PARAMS_CONFIG = { 8 | # env-specific 9 | "env_params": { 10 | "--distributed": { 11 | "action": "store_true", 12 | "default": False, 13 | "help": "enable distributed training." 14 | "(otherwise will use all available GPUs with dataparallel)", 15 | "dest": "distributed", 16 | }, 17 | "--local_rank": { 18 | "type": int, 19 | "default": int(os.environ['LOCAL_RANK']), 20 | # "default": 0, 21 | "help": "used in distributed training", 22 | "dest": "local_rank", 23 | }, 24 | }, 25 | # data-specific 26 | "data_params": { 27 | "--data": { 28 | "type": str, 29 | "default": "data/text8", 30 | "help": "data location " "(must contain train.txt, valid.txt and test.txt)", 31 | "dest": "data_path", 32 | }, 33 | "--data_name": { 34 | "type": str, 35 | "default": "text8", 36 | "help": "The name of dataset", 37 | "dest": "data_name", 38 | }, 39 | }, 40 | # model-specific 41 | "model_params": { 42 | "--hid-sz": { 43 | "type": int, 44 | "default": 256, 45 | "help": "hidden size (i.e. 
model size)", 46 | "dest": "hidden_size", 47 | }, 48 | "--inner-hid-sz": { 49 | "type": int, 50 | "default": 1024, 51 | "help": "inner hidden size of FF layer", 52 | "dest": "inner_hidden_size", 53 | }, 54 | "--nlayers": { 55 | "type": int, 56 | "default": 8, 57 | "help": "number of layers", 58 | "dest": "nb_layers", 59 | }, 60 | "--block-sz": { 61 | "type": int, 62 | "default": 64, 63 | "help": "block size " "(the length of sequence to process in parallel)", 64 | "dest": "block_size", 65 | }, 66 | "--nheads": { 67 | "type": int, 68 | "default": 2, 69 | "help": "number of self-attention heads", 70 | "dest": "nb_heads", 71 | }, 72 | "--attn-span": { 73 | "type": int, 74 | "default": 32, 75 | "help": "length of the attention span", 76 | "dest": "attn_span", 77 | }, 78 | "--dropout": { 79 | "type": float, 80 | "default": 0.2, 81 | "help": "dropout rate of ReLU and attention", 82 | "dest": "dropout", 83 | }, 84 | "--architecture": { 85 | "type": str, 86 | "default": None, 87 | "help": "arch", 88 | "dest": "architecture", 89 | }, 90 | "--base_arch": { 91 | "type": str, 92 | "default": None, 93 | "help": "arch", 94 | "dest": "base_arch", 95 | }, 96 | "--smoe_dropout": { 97 | "action": "store_true", 98 | "default": False, 99 | "help": "enable SMoE-drop - Freeze gate", 100 | "dest": "smoe_dropout", 101 | }, 102 | "--optimal_policy": { 103 | "action": "store_true", 104 | "default": False, 105 | "help": "Searching the best routing policy", 106 | "dest": "optimal_policy", 107 | }, 108 | "--load_balance": { 109 | "type": float, 110 | "default": 1.0, 111 | "help": "Ratio of blance loss", 112 | "dest": "load_balance", 113 | }, 114 | "--moe_top_k": { 115 | "type": int, 116 | "default": 2, 117 | "help": "Number of activate experts", 118 | "dest": "moe_top_k", 119 | }, 120 | "--freq": { 121 | "type": float, 122 | "default": 0.03, 123 | "help": "Frequent for searching optimal policy", 124 | "dest": "freq", 125 | }, 126 | "--freq_type": { 127 | "type": str, 128 | "default": "fix", 129 | "help": "Type of frequent for searching optimal policy. 
Choice: fix or function", 130 | "dest": "freq_type", 131 | }, 132 | "--alpha": { 133 | "type": float, 134 | "default": 1.0, 135 | "help": "Weight of the optimal-policy loss", 136 | "dest": "alpha", 137 | }, 138 | "--gate_name": { 139 | "type": str, 140 | "default": "smoe", 141 | "help": "Names of gates: smoe, smoe-dropout, xmoe, stablemoe", 142 | "dest": "gate_name", 143 | }, 144 | "--act_experts": { 145 | "type": str, 146 | "default": "shuffle", 147 | "help": "How to activate all experts: shuffle OR linear", 148 | "dest": "act_experts", 149 | }, 150 | "--g_blance": { 151 | "action": "store_true", 152 | "default": False, 153 | "help": "Activate balance loss for router", 154 | "dest": "g_blance", 155 | }, 156 | "--opt_blance": { 157 | "action": "store_true", 158 | "default": False, 159 | "help": "Activate balancing for the optimal router", 160 | "dest": "opt_blance", 161 | }, 162 | "--combine_gate": { 163 | "action": "store_true", 164 | "default": False, 165 | "help": "Utilize previous information for better consistency", 166 | "dest": "combine_gate", 167 | }, 168 | "--opt_loss": { 169 | "type": str, 170 | "default": "mse", 171 | "help": "Type of loss for optimal policy searching", 172 | "dest": "opt_loss", 173 | }, 174 | "--gamma1": { 175 | "type": float, 176 | "default": 1.0, 177 | "help": "Adam decay parameter", 178 | "dest": "gamma1", 179 | }, 180 | "--gamma2": { 181 | "type": float, 182 | "default": 1.0, 183 | "help": "Momentum learning rate", 184 | "dest": "gamma2", 185 | }, 186 | "--mu": { 187 | "type": float, 188 | "default": 0.9, 189 | "help": "Momentum parameter", 190 | "dest": "mu", 191 | }, 192 | "--beta1": { 193 | "type": float, 194 | "default": 0.9, 195 | "help": "Adam beta1 parameter", 196 | "dest": "beta1", 197 | }, 198 | "--beta2": { 199 | "type": float, 200 | "default": 0.999, 201 | "help": "Adam beta2 parameter", 202 | "dest": "beta2", 203 | }, 204 | }, 205 | # optimization-specific 206 | "optim_params": { 207 | "--lr": {"type": float, "default": 0.03, "help": "learning rate", "dest": "lr"}, 208 | "--momentum": { 209 | "type": float, 210 | "default": 0.9, 211 | "help": "SGD momentum", 212 | "dest": "momentum", 213 | }, 214 | "--optim": { 215 | "type": str, 216 | "default": "sgd", 217 | "help": "optimization method: sgd | adagrad | adam", 218 | "dest": "optim", 219 | }, 220 | "--lr-warmup": { 221 | "type": int, 222 | "default": 0, 223 | "help": "linearly increase LR from 0 " "during first lr_warmup updates", 224 | "dest": "lr_warmup", 225 | }, 226 | "--grad-clip": { 227 | "type": float, 228 | "default": 0, 229 | "help": "[only works with adagrad!] 
" 230 | "clip gradient of each module parameters by a given " 231 | "value", 232 | "dest": "grad_clip", 233 | }, 234 | }, 235 | # trainer-specific 236 | "trainer_params": { 237 | "--batch-sz": { 238 | "type": int, 239 | "default": 64, 240 | "help": "batch size", 241 | "dest": "batch_size", 242 | }, 243 | "--batch-split": { 244 | "type": int, 245 | "default": 1, 246 | "help": "split a batch into smaller parts to fit in GPU memory", 247 | "dest": "batch_split", 248 | }, 249 | "--nbatches": { 250 | "type": int, 251 | "default": 1000, 252 | "help": "number of batches in each iteration", 253 | "dest": "nb_batches_per_iter", 254 | }, 255 | "--niter": { 256 | "type": int, 257 | "default": 1000, 258 | "help": "number of iterations to train", 259 | "dest": "nb_iter", 260 | }, 261 | "--checkpoint": { 262 | "type": str, 263 | "default": "", 264 | "help": "path to save/load model", 265 | "dest": "checkpoint_path", 266 | }, 267 | "--resume": { 268 | "action": "store_true", 269 | "default": False, 270 | "help": "resume training", 271 | "dest": "resume", 272 | }, 273 | "--pretrained_weight": { 274 | "type": str, 275 | "default": "", 276 | "help": "path to save/load model", 277 | "dest": "pretrained_weight", 278 | }, 279 | "--full-eval-mode": { 280 | "action": "store_true", 281 | "default": False, 282 | "help": "do evaluation on the whole validation and the test data", 283 | "dest": "full_eval_mode", 284 | }, 285 | }, 286 | # adaptive attention span specific params 287 | "adapt_span_params": { 288 | "--adapt-span": { 289 | "action": "store_true", 290 | "default": False, 291 | "help": "enable adaptive attention span", 292 | "dest": "adapt_span_enabled", 293 | }, 294 | "--adapt-span-loss": { 295 | "type": float, 296 | "default": 0, 297 | "help": "the loss coefficient for span lengths", 298 | "dest": "adapt_span_loss", 299 | }, 300 | "--adapt-span-ramp": { 301 | "type": int, 302 | "default": 32, 303 | "help": "ramp length of the soft masking function", 304 | "dest": "adapt_span_ramp", 305 | }, 306 | "--adapt-span-init": { 307 | "type": float, 308 | "default": 0, 309 | "help": "initial attention span ratio", 310 | "dest": "adapt_span_init", 311 | }, 312 | "--adapt-span-cache": { 313 | "action": "store_true", 314 | "default": False, 315 | "help": "adapt cache size as well to reduce memory usage", 316 | "dest": "adapt_span_cache", 317 | }, 318 | }, 319 | "wandb_params": { 320 | "--project-name": { 321 | "type": str, 322 | "default": "project_name", 323 | "help": "wandb project name", 324 | "dest": "project_name", 325 | }, 326 | "--job-name": { 327 | "type": str, 328 | "default": "job_name", 329 | "help": "wandb job name", 330 | "dest": "job_name", 331 | }, 332 | "--wandb-flag": { 333 | "action": "store_true", 334 | "default": False, 335 | "help": "use wandb", 336 | "dest": "wandb_flag", 337 | }, 338 | }, 339 | } 340 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /custom_layers.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import math, random 4 | import torch 5 | import torch.nn as nn 6 | import tree 7 | from custom_functions import prepare_forward, ensure_comm 8 | from custom_functions import MOEScatter, MOEGather 9 | from custom_functions import AllGather, Slice 10 | from gates import NaiveGate 11 | 12 | from fastermoe.config import switch_from_env 13 | 14 | 15 | def mark_module_parallel_comm(module, comm): 16 | r""" 17 | Mark all parameters in `module` as doing data parallel in `comm`, where 18 | `comm` may be one of `'world', 'dp', 'none'`. 19 | """ 20 | for p in module.parameters(): 21 | setattr(p, "dp_comm", comm) 22 | 23 | 24 | def _fmoe_general_global_forward( 25 | inp, gate, expert_fn, num_expert, world_size, **kwargs 26 | ): 27 | r""" 28 | A private function that performs the following steps to complete the MoE 29 | computation. 30 | * Count the number of tokens from each worker to each expert. 31 | * Send the features to their target position so that input features to each 32 | expert are contiguous in memory. 33 | * Perform the forward computation of the experts using `expert_fn` 34 | * Gather the output features of experts back, and reorder them as sentences. 35 | Intermediate results like expert counts are hidden from users by this 36 | function. 37 | """ 38 | ( 39 | pos, 40 | local_expert_count, 41 | global_expert_count, 42 | fwd_expert_count, 43 | fwd_batch_size, 44 | ) = prepare_forward(gate, num_expert, world_size) 45 | topk = 1 46 | if len(gate.shape) == 2: 47 | topk = gate.shape[1] 48 | 49 | def scatter_func(tensor): 50 | return MOEScatter.apply( 51 | tensor, 52 | torch.div(pos, topk, rounding_mode="floor"), 53 | local_expert_count, 54 | global_expert_count, 55 | fwd_batch_size, 56 | world_size, 57 | ) 58 | 59 | x = tree.map_structure(scatter_func, inp) 60 | 61 | x = expert_fn(x, fwd_expert_count) 62 | 63 | out_batch_size = tree.flatten(inp)[0].shape[0] 64 | if len(gate.shape) == 2: 65 | out_batch_size *= gate.shape[1] 66 | 67 | def gather_func(tensor): 68 | return MOEGather.apply( 69 | tensor, 70 | pos, 71 | local_expert_count, 72 | global_expert_count, 73 | out_batch_size, 74 | world_size, 75 | ) 76 | 77 | outp = tree.map_structure(gather_func, x) 78 | return outp 79 | 80 | 81 | fmoe_faster_schedule = False 82 | if switch_from_env("FMOE_FASTER_SCHEDULE_ENABLE", False): 83 | fmoe_faster_schedule = True 84 | from .fastermoe.schedule import _fmoe_general_global_forward 85 | 86 | 87 | class FMoE(nn.Module): 88 | r""" 89 | A general moe implementation that supports an arbitrary module as the 90 | expert. 91 | * `num_expert` stands for the number of experts on **each** worker. 92 | * `world_size` stands for the total number of workers that contains 93 | different experts. 
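    * Globally there are therefore `num_expert * world_size` experts in total.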
94 | * `slice_group` can be a torch's communication group, indicating that 95 | specific model parallel is applied across the group, and workers in the 96 | group hold the same copy of input feature, and requires the same copy of 97 | the output. For each worker, FMoE only computes the output of a certain 98 | slice of the input batch, and will all-gather the outputs after 99 | computation. 100 | * `top_k` stands for the number of experts each token is going to. 101 | * `gate` is a gate class which can found in `fmoe.gates`. 102 | * `expert` can be specified as a module class, it is used to generate 103 | `num_expert` expert modules. 104 | """ 105 | 106 | def __init__( 107 | self, 108 | num_expert=32, 109 | d_model=1024, 110 | world_size=1, 111 | mp_group=None, # being deprecated 112 | slice_group=None, 113 | moe_group=None, 114 | moe_top_k=2, 115 | gate=NaiveGate, 116 | expert=None, 117 | gate_hook=None, 118 | mask=None, 119 | mask_dict=None, 120 | ): 121 | super().__init__() 122 | self.num_expert = num_expert 123 | self.d_model = d_model 124 | self.world_size = world_size 125 | 126 | self.slice_group = slice_group 127 | if mp_group is not None: 128 | print("[Warning] mp_group is being deprecated") 129 | self.slice_group = mp_group 130 | if self.slice_group is None: 131 | self.slice_size = 1 132 | self.slice_rank = 0 133 | else: 134 | self.slice_size = self.slice_group.size() 135 | self.slice_rank = self.slice_group.rank() 136 | 137 | self.top_k = moe_top_k 138 | if type(expert) is list: 139 | self.experts = nn.ModuleList([e(d_model) for e in expert]) 140 | self.experts_fused = False 141 | self.num_expert = num_expert = len(expert) 142 | elif expert is not None: 143 | self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)]) 144 | self.experts_fused = False 145 | else: 146 | self.experts_fused = True 147 | 148 | self.gate = gate(d_model, num_expert, world_size, moe_top_k) 149 | self.gate_hook = gate_hook 150 | self.mask = mask 151 | self.mask_dict = mask_dict 152 | self.moe_group = moe_group 153 | 154 | def expert_fn(self, inp, fwd_expert_count): 155 | r""" 156 | The default expert function which either calls the experts as a whole 157 | or as separate experts. 158 | """ 159 | if self.experts_fused: 160 | return self.experts(inp, fwd_expert_count) 161 | if isinstance(fwd_expert_count, torch.Tensor): 162 | fwd_expert_count = fwd_expert_count.cpu().numpy() 163 | outputs = [] 164 | base_idx = 0 165 | for i in range(self.num_expert): 166 | batch_size = fwd_expert_count[i] 167 | inp_slice = inp[base_idx : base_idx + batch_size] 168 | outputs.append(self.experts[i](inp_slice)) 169 | base_idx += batch_size 170 | return torch.cat(outputs, dim=0) 171 | 172 | def mark_parallel_comm(self, expert_dp_comm="none"): 173 | r""" 174 | Automatically mark the data parallel comms of the parameters within the 175 | module. This can be typically called at the end of the __init__ function 176 | in child classes. 177 | """ 178 | if self.experts is not None: 179 | comm = expert_dp_comm 180 | if isinstance(self.experts, list): 181 | for e in self.experts: 182 | mark_module_parallel_comm(e, comm) 183 | else: 184 | mark_module_parallel_comm(self.experts, comm) 185 | mark_module_parallel_comm(self.gate, "gate") 186 | 187 | def forward(self, moe_inp): 188 | r""" 189 | The FMoE module first computes gate output, and then conduct MoE forward 190 | according to the gate. The score of the selected gate given by the 191 | expert is multiplied to the experts' output tensors as a weight. 
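        Both `moe_inp` and the returned output keep the shape (num_tokens, d_model); routing only permutes tokens internally before the top-k expert outputs are recombined.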
192 | """ 193 | 194 | moe_inp_batch_size = tree.flatten( 195 | tree.map_structure(lambda tensor: tensor.shape[0], moe_inp) 196 | ) 197 | assert all( 198 | [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size] 199 | ), "MoE inputs must have the same batch size" 200 | 201 | if self.world_size > 1: 202 | 203 | def ensure_comm_func(tensor): 204 | ensure_comm(tensor, self.moe_group) 205 | 206 | tree.map_structure(ensure_comm_func, moe_inp) 207 | if self.slice_size > 1: 208 | 209 | def slice_func(tensor): 210 | return Slice.apply( 211 | tensor, self.slice_rank, self.slice_size, self.slice_group 212 | ) 213 | 214 | moe_inp = tree.map_structure(slice_func, moe_inp) 215 | 216 | gate_top_k_idx, gate_score = self.gate(moe_inp) 217 | 218 | if hasattr(self.gate, "dynamic_top_k"): 219 | self.top_k = self.gate.dynamic_top_k 220 | 221 | if self.gate_hook is not None: 222 | self.gate_hook(gate_top_k_idx, gate_score, None) 223 | 224 | # delete masked tensors 225 | if self.mask is not None and self.mask_dict is not None: 226 | # TODO: to fix 227 | def delete_mask_func(tensor): 228 | # to: (BxL') x d_model 229 | tensor = tensor[mask == 0, :] 230 | return tensor 231 | 232 | mask = self.mask.view(-1) 233 | moe_inp = tree.map_structure(delete_mask_func, moe_inp) 234 | gate_top_k_idx = gate_top_k_idx[mask == 0, :] 235 | 236 | fwd = _fmoe_general_global_forward( 237 | moe_inp, 238 | gate_top_k_idx, 239 | self.expert_fn, 240 | self.num_expert, 241 | self.world_size, 242 | experts=self.experts, 243 | ) 244 | 245 | # recover deleted tensors 246 | if self.mask is not None and self.mask_dict is not None: 247 | 248 | def recover_func(tensor): 249 | # to: (BxL') x top_k x dim 250 | dim = tensor.shape[-1] 251 | tensor = tensor.view(-1, self.top_k, dim) 252 | # to: (BxL) x top_k x d_model 253 | x = torch.zeros( 254 | mask.shape[0], 255 | self.top_k, 256 | dim, 257 | device=tensor.device, 258 | dtype=tensor.dtype, 259 | ) 260 | # recover 261 | x[mask == 0] = tensor 262 | for k, v in self.mask_dict.items(): 263 | x[mask == k] = v 264 | return x 265 | 266 | moe_outp = tree.map_structure(recover_func, fwd) 267 | else: 268 | 269 | def view_func(tensor): 270 | dim = tensor.shape[-1] 271 | tensor = tensor.view(-1, self.top_k, dim) 272 | return tensor 273 | 274 | moe_outp = tree.map_structure(view_func, fwd) 275 | 276 | gate_score = gate_score.view(-1, 1, self.top_k) 277 | 278 | def bmm_func(tensor): 279 | dim = tensor.shape[-1] 280 | tensor = torch.bmm(gate_score, tensor).reshape(-1, dim) 281 | return tensor 282 | 283 | moe_outp = tree.map_structure(bmm_func, moe_outp) 284 | 285 | if self.slice_size > 1: 286 | 287 | def all_gather_func(tensor): 288 | return AllGather.apply( 289 | tensor, self.slice_rank, self.slice_size, self.slice_group 290 | ) 291 | 292 | moe_outp = tree.map_structure(all_gather_func, moe_outp) 293 | 294 | moe_outp_batch_size = tree.flatten( 295 | tree.map_structure(lambda tensor: tensor.shape[0], moe_outp) 296 | ) 297 | assert all( 298 | [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size] 299 | ), "MoE outputs must have the same batch size" 300 | return moe_outp 301 | 302 | 303 | ############################################################################## 304 | 305 | import torch 306 | import torch.nn as nn 307 | import math 308 | import fmoe_cuda 309 | from torch.autograd import Function 310 | 311 | 312 | class MOELinear(Function): 313 | r""" 314 | Computes linear operators within one GPU on different experts simutaneously. 
315 | """ 316 | 317 | @staticmethod 318 | def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None): 319 | global_output_buf = fmoe_cuda.linear_forward( 320 | global_input_buf, fwd_expert_count, weight, bias 321 | ) 322 | variables = (global_input_buf, fwd_expert_count, weight, bias) 323 | ctx.save_for_backward(*variables) 324 | return global_output_buf 325 | 326 | @staticmethod 327 | def backward(ctx, grad_out): 328 | (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors 329 | grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward( 330 | grad_out, input_buf, fwd_expert_count, weight, bias 331 | ) 332 | 333 | if not torch.is_tensor(bias): 334 | grad_bias = None 335 | 336 | return grad_inp_buf, None, grad_weight, grad_bias 337 | 338 | 339 | class FMoELinear(nn.Module): 340 | r""" 341 | A linear layer that contains multiple experts. 342 | As multiple experts can be placed on the same worker, the computation can be 343 | performed in parallel to increase the performance. 344 | The FMoELinear module provides such function. 345 | """ 346 | 347 | def __init__( 348 | self, 349 | num_expert: int, 350 | in_feat: int, 351 | out_feat: int, 352 | bias: bool = True, 353 | rank: int = 0, 354 | ): 355 | super().__init__() 356 | self.num_expert = num_expert 357 | self.in_feat = in_feat 358 | self.out_feat = out_feat 359 | self.rank = rank 360 | self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat)) 361 | if bias: 362 | self.bias = nn.Parameter(torch.zeros(num_expert, out_feat)) 363 | else: 364 | self.register_parameter("bias", None) 365 | 366 | self.reset_parameters() 367 | 368 | def forward(self, inp, fwd_expert_count): 369 | r""" 370 | Call MOE function 371 | """ 372 | x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias) 373 | return x 374 | 375 | def extra_repr(self) -> str: 376 | return "num_expert={}, in_features={}, \ 377 | out_features={}, bias={}, rank={}".format( 378 | self.num_expert, 379 | self.in_feat, 380 | self.out_feat, 381 | self.bias is not None, 382 | self.rank, 383 | ) 384 | 385 | def reset_parameters(self): 386 | # Approach is the same as in torch.nn.Linear 387 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 388 | # bias is left to zero, similar as megatron 389 | 390 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 391 | -------------------------------------------------------------------------------- /vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import csv 4 | import json 5 | import torch 6 | from collections import Counter, OrderedDict 7 | import torch.nn.functional as F 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | 11 | class Vocab(object): 12 | def __init__( 13 | self, 14 | special=[], 15 | min_freq=0, 16 | max_size=None, 17 | lower_case=True, 18 | delimiter=None, 19 | vocab_file=None, 20 | ): 21 | self.counter = Counter() 22 | self.special = special 23 | self.min_freq = min_freq 24 | self.max_size = max_size 25 | self.lower_case = lower_case 26 | self.delimiter = delimiter 27 | self.vocab_file = vocab_file 28 | 29 | def tokenize( 30 | self, 31 | line, 32 | add_eos=False, 33 | add_double_eos=False, 34 | add_cls_token=False, 35 | add_s=False, 36 | add_cls_token_last=False, 37 | ): 38 | line = line.strip() 39 | # convert to lower case 40 | if self.lower_case: 41 | line = line.lower() 42 | 43 | # empty delimiter '' will evaluate False 44 | if self.delimiter == "": 45 | symbols = 
line 46 | else: 47 | symbols = line.split(self.delimiter) 48 | 49 | if add_cls_token: 50 | return [""] + symbols + [""] 51 | elif add_cls_token_last: 52 | return [""] + symbols + [""] 53 | elif add_double_eos: # lm1b 54 | return ["<S>"] + symbols + ["<S>"] 55 | elif add_eos: 56 | return symbols + ["<eos>"] 57 | elif add_s: 58 | return symbols + [""] 59 | else: 60 | return symbols 61 | 62 | def count_file(self, path, verbose=False, add_eos=False): 63 | if verbose: 64 | print("counting file {} ...".format(path)) 65 | assert os.path.exists(path) 66 | 67 | sents = [] 68 | with open(path, "r", encoding="utf-8") as f: 69 | for idx, line in enumerate(f): 70 | if verbose and idx > 0 and idx % 500000 == 0: 71 | print(" line {}".format(idx)) 72 | symbols = self.tokenize(line, add_eos=add_eos) 73 | self.counter.update(symbols) 74 | sents.append(symbols) 75 | 76 | return sents 77 | 78 | def count_csqa( 79 | self, 80 | path, 81 | num_classes=5, 82 | verbose=False, 83 | add_eos=False, 84 | add_double_eos=False, 85 | add_cls_token=False, 86 | ): 87 | if verbose: 88 | print("counting file {} ...".format(path)) 89 | assert os.path.exists(path) 90 | 91 | sents = [] 92 | with open(path, "r", encoding="utf-8") as f: 93 | for idx, line in enumerate(f): 94 | if verbose and idx > 0 and idx % 500000 == 0: 95 | print(" line {}".format(idx)) 96 | example = json.loads(line.strip()) 97 | question = example["question"]["stem"] 98 | assert len(example["question"]["choices"]) == num_classes 99 | # format: ` Q: Where would I not want a fox? A: hen house ` 100 | question = "Q: " + question 101 | question_toks = self.tokenize( 102 | question, 103 | add_eos=add_eos, 104 | add_double_eos=add_double_eos, 105 | add_cls_token=add_cls_token, 106 | ) 107 | for i, choice in enumerate(example["question"]["choices"]): 108 | src = "A: " + choice["text"] 109 | assert (ord(choice["label"]) - ord("A")) == i 110 | src_bin = self.tokenize(src, add_eos=add_eos) 111 | question_toks.extend(src_bin) 112 | self.counter.update(question_toks) 113 | sents.append(question_toks) 114 | return sents 115 | 116 | def count_sst2( 117 | self, 118 | path, 119 | verbose=False, 120 | add_eos=False, 121 | add_double_eos=False, 122 | add_cls_token=False, 123 | ): 124 | if verbose: 125 | print("counting file {} ...".format(path)) 126 | assert os.path.exists(path) 127 | sents = [] 128 | with open(path, "r", encoding="utf-8") as f: 129 | tsv_file = csv.reader(f, delimiter="\t") 130 | for line in tsv_file: 131 | if not line[1] in ["0", "1"]: 132 | print('* Ignore ', line) 133 | continue 134 | sentence, label = line[0], int(line[1]) 135 | assert label in [0, 1] 136 | sentence_toks = self.tokenize( 137 | sentence, 138 | add_eos=add_eos, 139 | add_double_eos=add_double_eos, 140 | add_cls_token=add_cls_token, 141 | ) 142 | self.counter.update(sentence_toks) 143 | sents.append(sentence_toks) 144 | return sents 145 | 146 | def count_sst5( 147 | self, 148 | dataset, 149 | verbose=False, 150 | add_eos=False, 151 | add_double_eos=False, 152 | add_cls_token=False, 153 | ): 154 | # if verbose: 155 | # print("counting file {} ...".format(path)) 156 | # assert os.path.exists(path) 157 | sents = [] 158 | for sample in dataset: 159 | # print("here:", sample) 160 | sample = sample.to_labeled_lines()[0] 161 | if not sample[0] in [0, 1, 2, 3, 4]: 162 | print("* Ignore ", sample) 163 | continue 164 | sentence, label = sample[1], sample[0] 165 | # assert label in [0, 1, 2, 3, 4] 166 | sentence_toks = self.tokenize( 167 | sentence, 168 | add_eos=add_eos, 169 | add_double_eos=add_double_eos, 170 
| add_cls_token=add_cls_token, 171 | ) 172 | self.counter.update(sentence_toks) 173 | sents.append(sentence_toks) 174 | return sents 175 | 176 | def count_banking77( 177 | self, 178 | path, 179 | verbose=False, 180 | add_eos=False, 181 | add_double_eos=False, 182 | add_cls_token=False, 183 | ): 184 | if verbose: 185 | print("counting file {} ...".format(path)) 186 | assert os.path.exists(path) 187 | sents = [] 188 | with open(path, "r", encoding="utf-8") as f: 189 | tsv_file = csv.reader(f, delimiter="\t") 190 | for line in tsv_file: 191 | if not line[1] in [str(x) for x in list(range(77))]: 192 | # print('* Ignore ', line) 193 | continue 194 | sentence, label = line[0], int(line[1]) 195 | assert label in list(range(77)) 196 | sentence_toks = self.tokenize( 197 | sentence, 198 | add_eos=add_eos, 199 | add_double_eos=add_double_eos, 200 | add_cls_token=add_cls_token, 201 | ) 202 | self.counter.update(sentence_toks) 203 | sents.append(sentence_toks) 204 | return sents 205 | 206 | def count_sents(self, sents, verbose=False): 207 | """ 208 | sents : a list of sentences, each a list of tokenized symbols 209 | """ 210 | if verbose: 211 | print("counting {} sents ...".format(len(sents))) 212 | for idx, symbols in enumerate(sents): 213 | if verbose and idx > 0 and idx % 500000 == 0: 214 | print(" line {}".format(idx)) 215 | self.counter.update(symbols) 216 | 217 | def _build_from_file(self, vocab_file): 218 | self.idx2sym = [] 219 | self.sym2idx = OrderedDict() 220 | 221 | with open(vocab_file, "r", encoding="utf-8") as f: 222 | for line in f: 223 | symb = line.strip().split()[0] 224 | self.add_symbol(symb) 225 | self.unk_idx = self.sym2idx[""] 226 | 227 | def build_vocab(self): 228 | if self.vocab_file: 229 | print("building vocab from {}".format(self.vocab_file)) 230 | self._build_from_file(self.vocab_file) 231 | print("final vocab size {}".format(len(self))) 232 | else: 233 | print( 234 | "building vocab with min_freq={}, max_size={}".format( 235 | self.min_freq, self.max_size 236 | ) 237 | ) 238 | self.idx2sym = [] 239 | self.sym2idx = OrderedDict() 240 | 241 | for sym in self.special: 242 | self.add_special(sym) 243 | 244 | for sym, cnt in self.counter.most_common(self.max_size): 245 | if cnt < self.min_freq: 246 | break 247 | self.add_symbol(sym) 248 | 249 | print( 250 | "final vocab size {} from {} unique tokens".format( 251 | len(self), len(self.counter) 252 | ) 253 | ) 254 | 255 | def encode_file( 256 | self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False 257 | ): 258 | if verbose: 259 | print("encoding file {} ...".format(path)) 260 | assert os.path.exists(path) 261 | encoded = [] 262 | with open(path, "r", encoding="utf-8") as f: 263 | for idx, line in enumerate(f): 264 | if verbose and idx > 0 and idx % 500000 == 0: 265 | print(" line {}".format(idx)) 266 | symbols = self.tokenize( 267 | line, add_eos=add_eos, add_double_eos=add_double_eos 268 | ) 269 | encoded.append(self.convert_to_tensor(symbols)) 270 | 271 | if ordered: 272 | encoded = torch.cat(encoded) 273 | 274 | return encoded 275 | 276 | def encode_csqa_file( 277 | self, 278 | path, 279 | ordered=False, 280 | num_classes=5, 281 | verbose=False, 282 | add_eos=False, 283 | add_double_eos=False, 284 | add_cls_token=False, 285 | ): 286 | if verbose: 287 | print("encoding file {} ...".format(path)) 288 | assert os.path.exists(path) 289 | encoded = [[] for i in range(num_classes)] 290 | labels = [] 291 | 292 | with open(path, "r", encoding="utf-8") as f: 293 | for idx, line in enumerate(f): 294 | if verbose 
and idx > 0 and idx % 500000 == 0: 295 | print(" line {}".format(idx)) 296 | example = json.loads(line.strip()) 297 | if "answerKey" in example: 298 | label = ord(example["answerKey"]) - ord("A") 299 | labels.append(label) 300 | question = example["question"]["stem"] 301 | assert len(example["question"]["choices"]) == num_classes 302 | # format: ` Q: Where would I not want a fox? A: hen house ` 303 | question = "Q: " + question 304 | question_bin = self.tokenize( 305 | question, 306 | add_eos=add_eos, 307 | add_double_eos=add_double_eos, 308 | add_cls_token=add_cls_token, 309 | ) 310 | for i, choice in enumerate(example["question"]["choices"]): 311 | src = " A: " + choice["text"] 312 | assert (ord(choice["label"]) - ord("A")) == i 313 | src_bin = question_bin + self.tokenize(src, add_s=True) 314 | encoded[i].append(self.convert_to_tensor(src_bin)) 315 | 316 | labels = torch.LongTensor(labels) 317 | 318 | # pdb.set_trace() 319 | 320 | # if ordered: 321 | # for idx in range(num_classes): 322 | # encoded[idx] = pad_sequence(encoded[idx]) 323 | 324 | # encoded = pad_sequence(encoded) 325 | # print(encoded.shape) 326 | 327 | return [encoded, labels] 328 | 329 | def encode_sst2_file( 330 | self, 331 | path, 332 | verbose=False, 333 | add_eos=False, 334 | add_double_eos=False, 335 | add_cls_token=False, 336 | ): 337 | if verbose: 338 | print("encoding file {} ...".format(path)) 339 | assert os.path.exists(path) 340 | encoded = [] 341 | labels = [] 342 | with open(path, "r", encoding="utf-8") as f: 343 | tsv_file = csv.reader(f, delimiter="\t") 344 | for line in tsv_file: 345 | if not line[1] in ["0", "1"]: 346 | print(line[0]) 347 | print("* Ignore ", line) 348 | continue 349 | sentence, label = line[0], int(line[1]) 350 | assert label in [0, 1] 351 | sentence_toks = self.tokenize( 352 | sentence, 353 | add_eos=add_eos, 354 | add_double_eos=add_double_eos, 355 | add_cls_token=add_cls_token, 356 | ) 357 | encoded.append(self.convert_to_tensor(sentence_toks)) 358 | labels.append(label) 359 | 360 | labels = torch.LongTensor(labels) 361 | return [encoded, labels] 362 | 363 | def encode_sst5_file( 364 | self, 365 | dataset, 366 | verbose=False, 367 | add_eos=False, 368 | add_double_eos=False, 369 | add_cls_token=False, 370 | ): 371 | # if verbose: 372 | # print("encoding file {} ...".format(path)) 373 | # assert os.path.exists(path) 374 | encoded = [] 375 | labels = [] 376 | for sample in dataset: 377 | sample = sample.to_labeled_lines()[0] 378 | if not sample[0] in [0, 1, 2, 3, 4]: 379 | print("* Ignore ", sample) 380 | continue 381 | sentence, label = sample[1], sample[0] 382 | # assert label in [0, 1, 2, 3, 4] 383 | sentence_toks = self.tokenize( 384 | sentence, 385 | add_eos=add_eos, 386 | add_double_eos=add_double_eos, 387 | add_cls_token=add_cls_token, 388 | ) 389 | encoded.append(self.convert_to_tensor(sentence_toks)) 390 | labels.append(label) 391 | 392 | labels = torch.LongTensor(labels) 393 | return [encoded, labels] 394 | 395 | def encode_banking77_file( 396 | self, 397 | path, 398 | verbose=False, 399 | add_eos=False, 400 | add_double_eos=False, 401 | add_cls_token=False, 402 | ): 403 | if verbose: 404 | print("encoding file {} ...".format(path)) 405 | assert os.path.exists(path) 406 | encoded = [] 407 | labels = [] 408 | with open(path, "r", encoding="utf-8") as f: 409 | tsv_file = csv.reader(f, delimiter="\t") 410 | for line in tsv_file: 411 | if not line[1] in [str(x) for x in list(range(77))]: 412 | print("* Ignore ", line) 413 | continue 414 | sentence, label = line[0], int(line[1]) 
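                # banking77 has 77 intent classes; rows with any other label value were skipped above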
415 | assert label in list(range(77)) 416 | sentence_toks = self.tokenize( 417 | sentence, 418 | add_eos=add_eos, 419 | add_double_eos=add_double_eos, 420 | add_cls_token=add_cls_token, 421 | ) 422 | encoded.append(self.convert_to_tensor(sentence_toks)) 423 | labels.append(label) 424 | 425 | labels = torch.LongTensor(labels) 426 | return [encoded, labels] 427 | 428 | def encode_sents(self, sents, ordered=False, verbose=False): 429 | if verbose: 430 | print("encoding {} sents ...".format(len(sents))) 431 | encoded = [] 432 | for idx, symbols in enumerate(sents): 433 | if verbose and idx > 0 and idx % 500000 == 0: 434 | print(" line {}".format(idx)) 435 | encoded.append(self.convert_to_tensor(symbols)) 436 | 437 | if ordered: 438 | encoded = torch.cat(encoded) 439 | 440 | return encoded 441 | 442 | def add_special(self, sym): 443 | if sym not in self.sym2idx: 444 | self.idx2sym.append(sym) 445 | self.sym2idx[sym] = len(self.idx2sym) - 1 446 | setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym]) 447 | 448 | def add_symbol(self, sym): 449 | if sym not in self.sym2idx: 450 | self.idx2sym.append(sym) 451 | self.sym2idx[sym] = len(self.idx2sym) - 1 452 | 453 | def get_sym(self, idx): 454 | assert 0 <= idx < len(self), "Index {} out of range".format(idx) 455 | return self.idx2sym[idx] 456 | 457 | def get_idx(self, sym): 458 | if sym in self.sym2idx: 459 | return self.sym2idx[sym] 460 | else: 461 | # print('encounter unk {}'.format(sym)) 462 | print(sym) 463 | assert "<eos>" not in sym 464 | assert hasattr(self, "unk_idx") 465 | return self.sym2idx.get(sym, self.unk_idx) 466 | 467 | def get_symbols(self, indices): 468 | return [self.get_sym(idx) for idx in indices] 469 | 470 | def get_indices(self, symbols): 471 | return [self.get_idx(sym) for sym in symbols] 472 | 473 | def convert_to_tensor(self, symbols): 474 | return torch.LongTensor(self.get_indices(symbols)) 475 | 476 | def convert_to_sent(self, indices, exclude=None): 477 | if exclude is None: 478 | return " ".join([self.get_sym(idx) for idx in indices]) 479 | else: 480 | return " ".join( 481 | [self.get_sym(idx) for idx in indices if idx not in exclude] 482 | ) 483 | 484 | def __len__(self): 485 | return len(self.idx2sym) 486 | -------------------------------------------------------------------------------- /finetune_data.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import glob 3 | import pdb 4 | from collections import Counter, OrderedDict 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from vocabulary import Vocab 9 | import torch.nn.functional as F 10 | from torch.nn.utils.rnn import pad_sequence 11 | import pytreebank 12 | 13 | 14 | def pad_sequence_reverse(data): 15 | # data should be a list of 1D tensors 16 | 17 | assert data[0].dim() == 1 18 | device = data[0].device 19 | length_list = [] 20 | for item in data: 21 | length_list.append(item.shape[0]) 22 | max_length = max(length_list) 23 | 24 | # padding 25 | padded_data_list = [] 26 | for item in data: 27 | padded_item = torch.cat( 28 | [torch.zeros(max_length - item.shape[0], dtype=item.dtype).to(device), item] 29 | ).reshape(-1, 1) 30 | padded_data_list.append(padded_item) 31 | padded_data_list = torch.cat(padded_data_list, dim=1) 32 | return padded_data_list 33 | 34 | 35 | class LMOrderedIterator(object): 36 | def __init__(self, data, bsz, bptt, device="cpu", ext_len=None): 37 | """ 38 | data -- LongTensor -- the LongTensor is strictly ordered 39 | """ 40 | 
self.bsz = bsz 41 | self.bptt = bptt 42 | self.ext_len = ext_len if ext_len is not None else 0 43 | 44 | self.device = device 45 | 46 | # Work out how cleanly we can divide the dataset into bsz parts. 47 | self.n_step = data.size(0) // bsz 48 | 49 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 50 | data = data.narrow(0, 0, self.n_step * bsz) 51 | 52 | # Evenly divide the data across the bsz batches. 53 | self.data = data.view(bsz, -1).t().contiguous().to(device) 54 | 55 | # Number of mini-batches 56 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 57 | 58 | def get_batch(self, i, bptt=None): 59 | if bptt is None: 60 | bptt = self.bptt 61 | seq_len = min(bptt, self.data.size(0) - 1 - i) 62 | 63 | end_idx = i + seq_len 64 | beg_idx = max(0, i - self.ext_len) 65 | 66 | data = self.data[beg_idx:end_idx] 67 | target = self.data[i + 1 : i + 1 + seq_len] 68 | 69 | return data, target, seq_len 70 | 71 | def get_fixlen_iter(self, start=0): 72 | for i in range(start, self.data.size(0) - 1, self.bptt): 73 | yield self.get_batch(i) 74 | 75 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 76 | max_len = self.bptt + max_deviation * std 77 | i = start 78 | while True: 79 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0 80 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 81 | data, target, seq_len = self.get_batch(i, bptt) 82 | i += seq_len 83 | yield data, target, seq_len 84 | if i >= self.data.size(0) - 2: 85 | break 86 | 87 | def __iter__(self): 88 | return self.get_fixlen_iter() 89 | 90 | 91 | class SST2Iterator(object): 92 | def __init__(self, data, bsz): 93 | """ 94 | data: [encoded, labels] 95 | """ 96 | 97 | self.bsz = bsz 98 | 99 | self.encoded = data[0] 100 | self.labels = data[1] # Tensor 101 | 102 | self.n_step = self.labels.size(0) // bsz 103 | self.cur_step = 0 104 | self.n_samples = self.labels.size(0) 105 | self.sequence_array = np.arange(self.n_samples) 106 | 107 | def get_batch(self, index_list): 108 | 109 | subencoded = [] 110 | mask_idx_pre = [] 111 | sublabels = [] 112 | 113 | for idx in index_list: 114 | subencoded.append(self.encoded[idx]) 115 | sublabels.append(self.labels[idx]) 116 | mask_idx_pre.append(torch.ones(self.encoded[idx].shape[0])) 117 | 118 | subencoded = pad_sequence_reverse(subencoded) 119 | mask_idx = 1 - pad_sequence_reverse(mask_idx_pre) 120 | length = mask_idx.shape[0] 121 | 122 | expand_mask_idx = mask_idx.unsqueeze(1).repeat( 123 | 1, length, 1 124 | ) # length, length, batch-size 125 | expand_mask_idx = ((expand_mask_idx + mask_idx) > 0).byte() 126 | 127 | # mask_idx = pad_sequence(mask_idx) 128 | sublabels = torch.LongTensor(sublabels) 129 | 130 | return subencoded, expand_mask_idx, sublabels 131 | 132 | def get_varlen_iter(self, start=0): 133 | sample_array = np.random.permutation(self.n_samples) 134 | for i in range(self.n_step): 135 | sub_index = sample_array[i * self.bsz : i * self.bsz + self.bsz] 136 | yield self.get_batch(sub_index) 137 | 138 | def get_fixlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 139 | # print(self.n_step) 140 | for i in range(self.cur_step, self.cur_step+(self.n_step//5)): 141 | # for i in range(self.n_step): 142 | sub_index = self.sequence_array[i * self.bsz : i * self.bsz + self.bsz] 143 | yield self.get_batch(sub_index) 144 | self.cur_step = i + 1 145 | 146 | def __iter__(self): 147 | return self.get_fixlen_iter() 148 | 149 | 150 | class LMShuffledIterator(object): 151 | def __init__(self, data, bsz, bptt, 
device="cpu", ext_len=None, shuffle=False): 152 | """ 153 | data -- list[LongTensor] -- there is no order among the LongTensors 154 | """ 155 | self.data = data 156 | 157 | self.bsz = bsz 158 | self.bptt = bptt 159 | self.ext_len = ext_len if ext_len is not None else 0 160 | 161 | self.device = device 162 | self.shuffle = shuffle 163 | 164 | def get_sent_stream(self): 165 | # index iterator 166 | epoch_indices = ( 167 | np.random.permutation(len(self.data)) 168 | if self.shuffle 169 | else np.array(range(len(self.data))) 170 | ) 171 | 172 | # sentence iterator 173 | for idx in epoch_indices: 174 | yield self.data[idx] 175 | 176 | def stream_iterator(self, sent_stream): 177 | # streams for each data in the batch 178 | streams = [None] * self.bsz 179 | 180 | data = torch.LongTensor(self.bptt, self.bsz) 181 | target = torch.LongTensor(self.bptt, self.bsz) 182 | 183 | n_retain = 0 184 | 185 | while True: 186 | # data : [n_retain+bptt x bsz] 187 | # target : [bptt x bsz] 188 | data[n_retain:].fill_(-1) 189 | target.fill_(-1) 190 | 191 | valid_batch = True 192 | 193 | for i in range(self.bsz): 194 | n_filled = 0 195 | try: 196 | while n_filled < self.bptt: 197 | if streams[i] is None or len(streams[i]) <= 1: 198 | streams[i] = next(sent_stream) 199 | # number of new tokens to fill in 200 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 201 | # first n_retain tokens are retained from last batch 202 | data[n_retain + n_filled : n_retain + n_filled + n_new, i] = ( 203 | streams[i][:n_new] 204 | ) 205 | target[n_filled : n_filled + n_new, i] = streams[i][ 206 | 1 : n_new + 1 207 | ] 208 | streams[i] = streams[i][n_new:] 209 | n_filled += n_new 210 | except StopIteration: 211 | valid_batch = False 212 | break 213 | 214 | if not valid_batch: 215 | return 216 | 217 | data = data.to(self.device) 218 | target = target.to(self.device) 219 | 220 | yield data, target, self.bptt 221 | 222 | n_retain = min(data.size(0), self.ext_len) 223 | if n_retain > 0: 224 | data[:n_retain] = data[-n_retain:] 225 | data.resize_(n_retain + self.bptt, data.size(1)) 226 | 227 | def __iter__(self): 228 | # sent_stream is an iterator 229 | sent_stream = self.get_sent_stream() 230 | 231 | for batch in self.stream_iterator(sent_stream): 232 | yield batch 233 | 234 | 235 | class LMMultiFileIterator(LMShuffledIterator): 236 | def __init__( 237 | self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False 238 | ): 239 | 240 | self.paths = paths 241 | self.vocab = vocab 242 | 243 | self.bsz = bsz 244 | self.bptt = bptt 245 | self.ext_len = ext_len if ext_len is not None else 0 246 | 247 | self.device = device 248 | self.shuffle = shuffle 249 | 250 | def get_sent_stream(self, path): 251 | sents = self.vocab.encode_file(path, add_double_eos=True) 252 | if self.shuffle: 253 | np.random.shuffle(sents) 254 | sent_stream = iter(sents) 255 | 256 | return sent_stream 257 | 258 | def __iter__(self): 259 | if self.shuffle: 260 | np.random.shuffle(self.paths) 261 | 262 | for path in self.paths: 263 | # sent_stream is an iterator 264 | sent_stream = self.get_sent_stream(path) 265 | for batch in self.stream_iterator(sent_stream): 266 | yield batch 267 | 268 | 269 | class Corpus(object): 270 | def __init__(self, path, dataset, *args, **kwargs): 271 | self.dataset = dataset 272 | self.vocab = Vocab(*args, **kwargs) 273 | 274 | if self.dataset in ["ptb", "wt2", "enwik8", "text8"]: 275 | self.vocab.count_file(os.path.join(path, "train.txt")) 276 | self.vocab.count_file(os.path.join(path, "valid.txt")) 277 | 
self.vocab.count_file(os.path.join(path, "test.txt")) 278 | elif self.dataset == "wt103": 279 | self.vocab.count_file(os.path.join(path, "train.txt")) 280 | elif self.dataset == "lm1b": 281 | train_path_pattern = os.path.join( 282 | path, 283 | "1-billion-word-language-modeling-benchmark-r13output", 284 | "training-monolingual.tokenized.shuffled", 285 | "news.en-*", 286 | ) 287 | train_paths = glob.glob(train_path_pattern) 288 | # the vocab will load from file when build_vocab() is called 289 | 290 | elif self.dataset == "csqa": 291 | self.vocab.count_csqa( 292 | os.path.join(path, "train_rand_split.jsonl"), add_cls_token=True 293 | ) 294 | self.vocab.count_csqa( 295 | os.path.join(path, "dev_rand_split.jsonl"), add_cls_token=True 296 | ) 297 | self.vocab.count_csqa( 298 | os.path.join(path, "test_rand_split_no_answers.jsonl"), 299 | add_cls_token=True, 300 | ) 301 | 302 | elif self.dataset in ["sst2", "imdb"]: 303 | self.vocab.count_sst2(os.path.join(path, "train.tsv"), add_cls_token=True) 304 | self.vocab.count_sst2(os.path.join(path, "dev.tsv"), add_cls_token=True) 305 | 306 | elif self.dataset == "sst5": 307 | dataset = pytreebank.load_sst(path) 308 | train = dataset['train'] 309 | val = dataset['dev'] 310 | test = dataset['test'] 311 | # self.vocab.count_sst5(os.path.join(path, "train.tsv"), add_cls_token=True) 312 | # self.vocab.count_sst5(os.path.join(path, "dev.tsv"), add_cls_token=True) 313 | self.vocab.count_sst5(train, add_cls_token=True) 314 | self.vocab.count_sst5(val, add_cls_token=True) 315 | 316 | elif self.dataset == "banking77": 317 | self.vocab.count_banking77( 318 | os.path.join(path, "train.tsv"), add_cls_token=True 319 | ) 320 | self.vocab.count_banking77( 321 | os.path.join(path, "dev.tsv"), add_cls_token=True 322 | ) 323 | 324 | self.vocab.build_vocab() 325 | 326 | if self.dataset in ["ptb", "wt2", "wt103"]: 327 | self.train = self.vocab.encode_file( 328 | os.path.join(path, "train.txt"), ordered=True 329 | ) 330 | self.valid = self.vocab.encode_file( 331 | os.path.join(path, "valid.txt"), ordered=True 332 | ) 333 | self.test = self.vocab.encode_file( 334 | os.path.join(path, "test.txt"), ordered=True 335 | ) 336 | elif self.dataset in ["enwik8", "text8"]: 337 | self.train = self.vocab.encode_file( 338 | os.path.join(path, "train.txt"), ordered=True, add_eos=False 339 | ) 340 | self.valid = self.vocab.encode_file( 341 | os.path.join(path, "valid.txt"), ordered=True, add_eos=False 342 | ) 343 | self.test = self.vocab.encode_file( 344 | os.path.join(path, "test.txt"), ordered=True, add_eos=False 345 | ) 346 | elif self.dataset == "lm1b": 347 | self.train = train_paths 348 | self.valid = self.vocab.encode_file( 349 | os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True 350 | ) 351 | self.test = self.vocab.encode_file( 352 | os.path.join(path, "test.txt"), ordered=False, add_double_eos=True 353 | ) 354 | elif self.dataset == "csqa": 355 | self.train = self.vocab.encode_csqa_file( 356 | os.path.join(path, "train_rand_split.jsonl"), 357 | ordered=True, 358 | add_cls_token=True, 359 | ) 360 | self.valid = self.vocab.encode_csqa_file( 361 | os.path.join(path, "dev_rand_split.jsonl"), 362 | ordered=True, 363 | add_cls_token=True, 364 | ) 365 | elif self.dataset in ["sst2", "imdb"]: 366 | self.train = self.vocab.encode_sst2_file( 367 | os.path.join(path, "train.tsv"), add_cls_token=True 368 | ) 369 | self.valid = self.vocab.encode_sst2_file( 370 | os.path.join(path, "dev.tsv"), add_cls_token=True 371 | ) 372 | 373 | elif self.dataset == "sst5": 374 | self.train 
= self.vocab.encode_sst5_file( 375 | train, add_cls_token=True 376 | ) 377 | self.valid = self.vocab.encode_sst5_file( 378 | val, add_cls_token=True 379 | ) 380 | # self.test = self.vocab.encode_sst5_file( 381 | # test, add_cls_token=True 382 | # ) 383 | elif self.dataset == "banking77": 384 | self.train = self.vocab.encode_banking77_file( 385 | os.path.join(path, "train.tsv"), add_cls_token=True 386 | ) 387 | self.valid = self.vocab.encode_banking77_file( 388 | os.path.join(path, "dev.tsv"), add_cls_token=True 389 | ) 390 | 391 | def get_iterator(self, split, *args, **kwargs): 392 | 393 | if split == "train": 394 | if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: 395 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 396 | elif self.dataset == "lm1b": 397 | kwargs["shuffle"] = True 398 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 399 | elif self.dataset == "csqa": 400 | data_iter = CSQAIterator(self.train, *args, **kwargs) 401 | elif self.dataset in ["sst2", "imdb", "sst5", "banking77"]: 402 | data_iter = SST2Iterator(self.train, *args, **kwargs) 403 | # dataset = CSQADataset(self.train) 404 | # data_iter = DataLoader(dataset, *args, shuffle=True, 405 | # num_workers=4, drop_last=False, pin_memory=True) 406 | 407 | elif split in ["valid", "test"]: 408 | data = self.valid if split == "valid" else self.test 409 | if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: 410 | data_iter = LMOrderedIterator(data, *args, **kwargs) 411 | elif self.dataset == "lm1b": 412 | data_iter = LMShuffledIterator(data, *args, **kwargs) 413 | elif self.dataset == "csqa": 414 | data_iter = CSQAIterator(self.valid, *args, **kwargs) 415 | elif self.dataset in ["sst2", "imdb", "sst5", "banking77"]: 416 | data_iter = SST2Iterator(self.valid, *args, **kwargs) 417 | 418 | # dataset = CSQADataset(self.valid) 419 | # data_iter = DataLoader(dataset, *args, shuffle=False, 420 | # num_workers=4, drop_last=False, pin_memory=True) 421 | return data_iter 422 | 423 | 424 | def get_lm_corpus(datadir, dataset): 425 | fn = os.path.join(datadir, "cache.pt") 426 | # print(fn) 427 | if os.path.exists(fn): 428 | print("Loading cached dataset...") 429 | corpus = torch.load(fn) 430 | else: 431 | print("Producing dataset {}...".format(dataset)) 432 | kwargs = {} 433 | if dataset in ["wt103", "wt2"]: 434 | kwargs["special"] = ["<eos>"] 435 | kwargs["lower_case"] = False 436 | elif dataset == "ptb": 437 | kwargs["special"] = ["<eos>"] 438 | kwargs["lower_case"] = True 439 | elif dataset == "lm1b": 440 | kwargs["special"] = [] 441 | kwargs["lower_case"] = False 442 | kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt") 443 | elif dataset in ["csqa", "sst2", "imdb", "sst5", "banking77"]: 444 | kwargs["special"] = [""] 445 | elif dataset in ["enwik8", "text8"]: 446 | pass 447 | 448 | corpus = Corpus(datadir, dataset, **kwargs) 449 | torch.save(corpus, fn) 450 | 451 | return corpus 452 | 453 | 454 | if __name__ == "__main__": 455 | import argparse 456 | 457 | parser = argparse.ArgumentParser(description="unit test") 458 | parser.add_argument( 459 | "--datadir", type=str, default="./enwik8", help="location of the data corpus" 460 | ) 461 | parser.add_argument( 462 | "--dataset", 463 | type=str, 464 | default="enwik8", 465 | choices=["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"], 466 | help="dataset name", 467 | ) 468 | args = parser.parse_args() 469 | 470 | corpus = get_lm_corpus(args.datadir, args.dataset) 471 | print("Vocab size : {}".format(len(corpus.vocab.idx2sym))) 
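    # Hypothetical usage sketch: build a training iterator from the corpus loaded
    # above and walk a few batches; the bsz/bptt values below are illustrative only.
    # tr_iter = corpus.get_iterator("train", bsz=16, bptt=512, device="cpu")
    # for batch_idx, (data, target, seq_len) in enumerate(tr_iter):
    #     print(batch_idx, data.shape, target.shape, seq_len)
    #     if batch_idx >= 2:
    #         break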
472 | -------------------------------------------------------------------------------- /custom_layers_opt.py: -------------------------------------------------------------------------------- 1 | r""" 2 | FMoE core layer 3 | """ 4 | 5 | import tree 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | 10 | from custom_functions import prepare_forward, ensure_comm 11 | from custom_functions import MOEScatter, MOEGather 12 | from custom_functions import AllGather, Slice 13 | from gates import NaiveGate 14 | 15 | from fastermoe.config import switch_from_env 16 | import random 17 | import torch.nn.functional as F 18 | from torchmetrics.regression import KLDivergence 19 | 20 | 21 | def kl_divergence(softmax_1, softmax_2): 22 | kl = KLDivergence(log_prob=False).cuda()  # expects probability (not log-prob) inputs 23 | return kl(softmax_1, softmax_2) 24 | 25 | 26 | def cal_mse_loss(input, target): 27 | mse_loss = nn.MSELoss() 28 | _loss = mse_loss(input, target) 29 | return _loss 30 | 31 | 32 | def mark_module_parallel_comm(module, comm): 33 | r""" 34 | Mark all parameters in `module` as doing data parallel in `comm`, where 35 | `comm` may be one of `'world', 'dp', 'none'`. 36 | """ 37 | for p in module.parameters(): 38 | setattr(p, "dp_comm", comm) 39 | 40 | 41 | def _fmoe_general_global_forward( 42 | inp, gate, expert_fn, num_expert, world_size, **kwargs 43 | ): 44 | r""" 45 | A private function that performs the following steps to complete the MoE 46 | computation. 47 | * Count the number of tokens from each worker to each expert. 48 | * Send the features to their target position so that input features to each 49 | expert are contiguous in memory. 50 | * Perform the forward computation of the experts using `expert_fn`. 51 | * Gather the output features of experts back, and reorder them as sentences. 52 | Intermediate results like expert counts are hidden from users by this 53 | function. 54 | """ 55 | ( 56 | pos, 57 | local_expert_count, 58 | global_expert_count, 59 | fwd_expert_count, 60 | fwd_batch_size, 61 | ) = prepare_forward(gate, num_expert, world_size) 62 | topk = 1 63 | if len(gate.shape) == 2: 64 | topk = gate.shape[1] 65 | 66 | def scatter_func(tensor): 67 | return MOEScatter.apply( 68 | tensor, 69 | torch.div(pos, topk, rounding_mode="floor"), 70 | local_expert_count, 71 | global_expert_count, 72 | fwd_batch_size, 73 | world_size, 74 | ) 75 | 76 | x = tree.map_structure(scatter_func, inp) 77 | 78 | x = expert_fn(x, fwd_expert_count) 79 | 80 | out_batch_size = tree.flatten(inp)[0].shape[0] 81 | if len(gate.shape) == 2: 82 | out_batch_size *= gate.shape[1] 83 | 84 | def gather_func(tensor): 85 | return MOEGather.apply( 86 | tensor, 87 | pos, 88 | local_expert_count, 89 | global_expert_count, 90 | out_batch_size, 91 | world_size, 92 | ) 93 | 94 | outp = tree.map_structure(gather_func, x) 95 | return outp 96 | 97 | 98 | fmoe_faster_schedule = False 99 | if switch_from_env("FMOE_FASTER_SCHEDULE_ENABLE", False): 100 | fmoe_faster_schedule = True 101 | from fastermoe.schedule import _fmoe_general_global_forward 102 | 103 | 104 | class FMoEOpt(nn.Module): 105 | r""" 106 | A general MoE implementation that supports an arbitrary module as the 107 | expert. 108 | * `num_expert` stands for the number of experts on **each** worker. 109 | * `world_size` stands for the total number of workers that contain 110 | different experts.
111 | * `slice_group` can be a torch's communication group, indicating that 112 | specific model parallel is applied across the group, and workers in the 113 | group hold the same copy of input feature, and requires the same copy of 114 | the output. For each worker, FMoE only computes the output of a certain 115 | slice of the input batch, and will all-gather the outputs after 116 | computation. 117 | * `top_k` stands for the number of experts each token is going to. 118 | * `gate` is a gate class which can found in `fmoe.gates`. 119 | * `expert` can be specified as a module class, it is used to generate 120 | `num_expert` expert modules. 121 | """ 122 | 123 | def __init__( 124 | self, 125 | num_expert=32, 126 | d_model=1024, 127 | world_size=1, 128 | mp_group=None, # being deprecated 129 | slice_group=None, 130 | moe_group=None, 131 | moe_top_k=2, 132 | gate=NaiveGate, 133 | expert=None, 134 | gate_hook=None, 135 | mask=None, 136 | mask_dict=None, 137 | freq=0.0, 138 | alpha=0.0, 139 | act_experts="shuffle", 140 | g_blance=False, 141 | opt_blance=False, 142 | combine_gate=False, 143 | opt_loss="mse", 144 | ): 145 | super().__init__() 146 | self.num_expert = num_expert 147 | self.d_model = d_model 148 | self.world_size = world_size 149 | self.freq = freq 150 | self.alpha = alpha 151 | self.act_experts = act_experts 152 | self.opt_blance = opt_blance 153 | self.combine_gate = combine_gate 154 | self.opt_loss = opt_loss 155 | self.slice_group = slice_group 156 | if mp_group is not None: 157 | print("[Warning] mp_group is being deprecated") 158 | self.slice_group = mp_group 159 | if self.slice_group is None: 160 | self.slice_size = 1 161 | self.slice_rank = 0 162 | else: 163 | self.slice_size = self.slice_group.size() 164 | self.slice_rank = self.slice_group.rank() 165 | 166 | self.top_k = moe_top_k 167 | if type(expert) is list: 168 | self.experts = nn.ModuleList([e(d_model) for e in expert]) 169 | self.experts_fused = False 170 | self.num_expert = num_expert = len(expert) 171 | elif expert is not None: 172 | self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)]) 173 | self.experts_fused = False 174 | else: 175 | self.experts_fused = True 176 | 177 | self.gate = gate(d_model, num_expert, world_size, moe_top_k, g_blance) 178 | self.gate_hook = gate_hook 179 | self.mask = mask 180 | self.mask_dict = mask_dict 181 | self.moe_group = moe_group 182 | 183 | def expert_fn(self, inp, fwd_expert_count): 184 | r""" 185 | The default expert function which either calls the experts as a whole 186 | or as separate experts. 187 | """ 188 | if self.experts_fused: 189 | return self.experts(inp, fwd_expert_count) 190 | if isinstance(fwd_expert_count, torch.Tensor): 191 | fwd_expert_count = fwd_expert_count.cpu().numpy() 192 | outputs = [] 193 | base_idx = 0 194 | for i in range(self.num_expert): 195 | batch_size = fwd_expert_count[i] 196 | inp_slice = inp[base_idx : base_idx + batch_size] 197 | outputs.append(self.experts[i](inp_slice)) 198 | base_idx += batch_size 199 | return torch.cat(outputs, dim=0) 200 | 201 | def mark_parallel_comm(self, expert_dp_comm="none"): 202 | r""" 203 | Automatically mark the data parallel comms of the parameters within the 204 | module. This can be typically called at the end of the __init__ function 205 | in child classes. 
206 | """ 207 | if self.experts is not None: 208 | comm = expert_dp_comm 209 | if isinstance(self.experts, list): 210 | for e in self.experts: 211 | mark_module_parallel_comm(e, comm) 212 | else: 213 | mark_module_parallel_comm(self.experts, comm) 214 | mark_module_parallel_comm(self.gate, "gate") 215 | 216 | def cal_load_balance(self, gate, gate_top_k_idx): 217 | 218 | score = F.softmax(gate, dim=-1) 219 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1] 220 | fraction_expert = ( 221 | torch.scatter_add( 222 | torch.zeros(self.num_expert, device=valid_idx.device), 223 | 0, 224 | valid_idx, 225 | torch.ones_like(valid_idx, dtype=torch.float), 226 | ) 227 | / valid_idx.numel() 228 | ) 229 | prob_expert = score.sum(dim=0) / valid_idx.numel() 230 | 231 | loss = (fraction_expert * prob_expert).sum() * self.num_expert 232 | return loss 233 | 234 | def forward(self, moe_inp): 235 | r""" 236 | The FMoE module first computes gate output, and then conduct MoE forward 237 | according to the gate. The score of the selected gate given by the 238 | expert is multiplied to the experts' output tensors as a weight. 239 | """ 240 | 241 | moe_inp_batch_size = tree.flatten( 242 | tree.map_structure(lambda tensor: tensor.shape[0], moe_inp) 243 | ) 244 | assert all( 245 | [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size] 246 | ), "MoE inputs must have the same batch size" 247 | 248 | if self.world_size > 1: 249 | 250 | def ensure_comm_func(tensor): 251 | ensure_comm(tensor, self.moe_group) 252 | 253 | tree.map_structure(ensure_comm_func, moe_inp) 254 | if self.slice_size > 1: 255 | 256 | def slice_func(tensor): 257 | return Slice.apply( 258 | tensor, self.slice_rank, self.slice_size, self.slice_group 259 | ) 260 | 261 | moe_inp = tree.map_structure(slice_func, moe_inp) 262 | flip_ = random.random() 263 | gate_top_k_idx, gate_score, gate_ = self.gate(moe_inp, return_all_scores=True) 264 | 265 | if self.training: 266 | if flip_ > (1 - self.freq): 267 | # all experts score 268 | gate_top_k_val_org, _ = torch.topk( 269 | gate_, k=self.num_expert, dim=-1, largest=True, sorted=False 270 | ) 271 | gate_top_k_val_org = gate_top_k_val_org.view( 272 | -1, self.num_expert 273 | ) # (BxL) x 1 x top_k 274 | gate_score_org = F.softmax(gate_top_k_val_org, dim=-1) 275 | # activate all experts with shuffle index 276 | if self.act_experts == "shuffle": 277 | # searching best routing optimal 278 | gate_dense = torch.ones_like( 279 | gate_ 280 | ) # average the importance of all experts 281 | gate_top_k_val_opt, gate_top_k_idx_opt = torch.topk( 282 | gate_dense, 283 | k=self.num_expert, 284 | dim=-1, 285 | largest=True, 286 | sorted=False, 287 | ) # [.. 
x top_k] 288 | gate_top_k_val_opt = gate_top_k_val_opt.view( 289 | -1, self.num_expert 290 | ) # (BxL) x 1 x num_experts 291 | gate_score_opt = F.softmax(gate_top_k_val_opt, dim=-1) 292 | 293 | if hasattr(self.gate, "dynamic_top_k"): 294 | self.top_k = self.gate.dynamic_top_k 295 | 296 | if self.gate_hook is not None: 297 | self.gate_hook(gate_top_k_idx_opt, gate_score_opt, None) 298 | 299 | # delete masked tensors 300 | if self.mask is not None and self.mask_dict is not None: 301 | # TODO: to fix 302 | def delete_mask_func(tensor): 303 | # to: (BxL') x d_model 304 | tensor = tensor[mask == 0, :] 305 | return tensor 306 | 307 | mask = self.mask.view(-1) 308 | moe_inp = tree.map_structure(delete_mask_func, moe_inp) 309 | gate_top_k_idx_opt = gate_top_k_idx_opt[mask == 0, :] 310 | bs = moe_inp.shape[0] 311 | fwd_tmp = _fmoe_general_global_forward( 312 | moe_inp, 313 | gate_top_k_idx_opt, 314 | self.expert_fn, 315 | self.num_expert, 316 | self.world_size, 317 | experts=self.experts, 318 | ).reshape(bs, self.num_expert, -1) 319 | # cal norm of output experts 320 | fwd_norm = torch.norm(fwd_tmp, dim=2) 321 | # activate all experts without shuffle index 322 | else: 323 | # activate with grad 324 | if self.opt_blance: 325 | fwd_tmp = None 326 | for i in range(self.num_expert): 327 | 328 | temp_ = ( 329 | moe_inp @ self.experts.htoh4.weight[i].T 330 | + self.experts.htoh4.bias[i] 331 | ) 332 | temp_ = F.relu(temp_) 333 | temp_ = ( 334 | temp_ @ self.experts.h4toh.weight[i].T 335 | + self.experts.h4toh.bias[i] 336 | ) 337 | temp_ = torch.unsqueeze(temp_, -1) 338 | if fwd_tmp is None: 339 | fwd_tmp = temp_.clone() 340 | else: 341 | fwd_tmp = torch.concat([fwd_tmp, temp_], dim=-1) 342 | # activate without grad 343 | else: 344 | fwd_tmp = None 345 | for i in range(self.num_expert): 346 | 347 | temp_ = ( 348 | moe_inp @ self.experts.htoh4.weight[i].T 349 | + self.experts.htoh4.bias[i] 350 | ) 351 | temp_ = F.relu(temp_) 352 | temp_ = ( 353 | temp_ @ self.experts.h4toh.weight[i].T 354 | + self.experts.h4toh.bias[i] 355 | ) 356 | temp_ = torch.unsqueeze(temp_, -1) 357 | if fwd_tmp is None: 358 | fwd_tmp = temp_.clone() 359 | else: 360 | fwd_tmp = torch.concat([fwd_tmp, temp_], dim=-1) 361 | # cal norm of output experts 362 | fwd_norm = torch.norm(fwd_tmp, dim=1) 363 | # ensemble with gate information 364 | if self.combine_gate: 365 | fwd_norm = fwd_norm * 0.5 + gate_ * 0.5 366 | gate_top_k_val_optim, gate_top_k_idx_optim = torch.topk( 367 | fwd_norm, k=self.top_k, dim=-1, largest=True, sorted=False 368 | ) 369 | # if balance loss 370 | if self.opt_blance: 371 | opt_bl_loss = self.cal_load_balance(fwd_norm, gate_top_k_idx_optim) 372 | # get output 373 | gate_top_k_val_optim = gate_top_k_val_optim.view(-1, self.top_k) 374 | # push low score to zeros 375 | gate_score2 = torch.zeros( 376 | (gate_top_k_val_optim.shape[0], self.num_expert) 377 | ).cuda() 378 | gate_score2.fill_(-10e9) 379 | # fill with topk score value 380 | gate_score2 = gate_score2.scatter( 381 | 1, gate_top_k_idx_optim, gate_top_k_val_optim 382 | ) 383 | gate_score_optimal = F.softmax(gate_score2, dim=1) 384 | # # calculate loss 385 | if self.opt_loss == "mse": 386 | add_loss = cal_mse_loss( 387 | gate_score_org, 388 | gate_score_optimal, 389 | ) 390 | else: 391 | add_loss = kl_divergence( 392 | gate_score_org, 393 | gate_score_optimal, 394 | ) 395 | # if balance loss 396 | if self.opt_blance: 397 | add_loss += opt_bl_loss 398 | # add to balance loss 399 | self.gate.loss = add_loss * self.alpha 400 | else: 401 | self.gate.loss = add_loss * 
self.alpha 402 | #update gate policy 403 | gate_top_k_idx = gate_top_k_idx_optim 404 | gate_score = F.softmax(gate_top_k_val_optim, dim=1) 405 | 406 | 407 | if hasattr(self.gate, "dynamic_top_k"): 408 | self.top_k = self.gate.dynamic_top_k 409 | 410 | if self.gate_hook is not None: 411 | self.gate_hook(gate_top_k_idx, gate_score, None) 412 | 413 | # delete masked tensors 414 | if self.mask is not None and self.mask_dict is not None: 415 | # TODO: to fix 416 | def delete_mask_func(tensor): 417 | # to: (BxL') x d_model 418 | tensor = tensor[mask == 0, :] 419 | return tensor 420 | 421 | mask = self.mask.view(-1) 422 | moe_inp = tree.map_structure(delete_mask_func, moe_inp) 423 | gate_top_k_idx = gate_top_k_idx[mask == 0, :] 424 | 425 | fwd = _fmoe_general_global_forward( 426 | moe_inp, 427 | gate_top_k_idx, 428 | self.expert_fn, 429 | self.num_expert, 430 | self.world_size, 431 | experts=self.experts, 432 | ) 433 | 434 | # recover deleted tensors 435 | if self.mask is not None and self.mask_dict is not None: 436 | 437 | def recover_func(tensor): 438 | # to: (BxL') x top_k x dim 439 | dim = tensor.shape[-1] 440 | tensor = tensor.view(-1, self.top_k, dim) 441 | # to: (BxL) x top_k x d_model 442 | x = torch.zeros( 443 | mask.shape[0], 444 | self.top_k, 445 | dim, 446 | device=tensor.device, 447 | dtype=tensor.dtype, 448 | ) 449 | # recover 450 | x[mask == 0] = tensor 451 | for k, v in self.mask_dict.items(): 452 | x[mask == k] = v 453 | return x 454 | 455 | moe_outp = tree.map_structure(recover_func, fwd) 456 | else: 457 | 458 | def view_func(tensor): 459 | dim = tensor.shape[-1] 460 | tensor = tensor.view(-1, self.top_k, dim) 461 | return tensor 462 | 463 | moe_outp = tree.map_structure(view_func, fwd) 464 | 465 | gate_score = gate_score.view(-1, 1, self.top_k) 466 | 467 | def bmm_func(tensor): 468 | dim = tensor.shape[-1] 469 | tensor = torch.bmm(gate_score, tensor).reshape(-1, dim) 470 | return tensor 471 | 472 | moe_outp = tree.map_structure(bmm_func, moe_outp) 473 | 474 | if self.slice_size > 1: 475 | 476 | def all_gather_func(tensor): 477 | return AllGather.apply( 478 | tensor, self.slice_rank, self.slice_size, self.slice_group 479 | ) 480 | 481 | moe_outp = tree.map_structure(all_gather_func, moe_outp) 482 | 483 | moe_outp_batch_size = tree.flatten( 484 | tree.map_structure(lambda tensor: tensor.shape[0], moe_outp) 485 | ) 486 | assert all( 487 | [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size] 488 | ), "MoE outputs must have the same batch size" 489 | return moe_outp 490 | --------------------------------------------------------------------------------
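The optimal-routing search inside FMoEOpt.forward activates every expert densely, ranks the experts per token by the norm of their outputs, and then regularises the gate towards that ranking (via MSE or KL divergence between gate_score_org and gate_score_optimal). A minimal, self-contained sketch of that norm-based re-routing step follows; it uses plain PyTorch only, and the helper name norm_based_routing and the toy Linear experts are illustrative assumptions rather than part of this repository:

import torch
import torch.nn.functional as F

def norm_based_routing(x, experts, top_k=2):
    # x: (tokens, d_model); experts: list of modules mapping d_model -> d_model
    outs = torch.stack([e(x) for e in experts], dim=-1)   # (tokens, d_model, n_experts)
    norms = outs.norm(dim=1)                              # (tokens, n_experts)
    topk_val, topk_idx = norms.topk(top_k, dim=-1)        # (tokens, top_k)
    # mask non-selected experts with a large negative score, then normalise
    scores = torch.full_like(norms, -1e9)
    scores.scatter_(1, topk_idx, topk_val)
    return topk_idx, F.softmax(scores, dim=-1)            # selected indices and target routing distribution

if __name__ == "__main__":
    d_model, n_experts = 16, 4
    experts = [torch.nn.Linear(d_model, d_model) for _ in range(n_experts)]
    idx, target = norm_based_routing(torch.randn(8, d_model), experts)
    print(idx.shape, target.shape)  # torch.Size([8, 2]) torch.Size([8, 4])

In custom_layers_opt.py the analogous distribution is built from fwd_norm (optionally blended with the gate scores when combine_gate is set), and the resulting consistency loss is scaled by alpha and written to self.gate.loss.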