├── fastermoe
│   ├── __init__.py
│   ├── config.py
│   ├── expert_utils.py
│   ├── shadow_policy.py
│   └── schedule.py
├── fastmoe.egg-info
│   ├── dependency_links.txt
│   ├── top_level.txt
│   ├── SOURCES.txt
│   └── PKG-INFO
├── requirements.txt
├── .gitignore
├── gates
│   ├── __init__.py
│   ├── base_gate.py
│   ├── zero_gate.py
│   ├── utils.py
│   ├── naive_gate.py
│   ├── swipe_gate.py
│   ├── gshard_gate.py
│   ├── switch_gate.py
│   ├── faster_gate.py
│   └── noisy_gate.py
├── scripts
│   ├── smoe-s.sh
│   ├── glam-m.sh
│   ├── smoe-m.sh
│   ├── smoe-l.sh
│   ├── smoe-mom-s.sh
│   ├── glam-adam-m.sh
│   ├── glam-mom-m.sh
│   ├── smoe-adam-m.sh
│   ├── smoe-mom-m.sh
│   └── smoe-mom-l.sh
├── custom_utils.py
├── .github
│   └── workflows
│       └── typecheck.yaml
├── README.md
├── setup.py
├── custom_transformer.py
├── data.py
├── trainer.py
├── custom_functions.py
├── custom_gates.py
├── finetune_trainer.py
├── utils.py
├── finetune_train.py
├── train.py
├── config.py
├── LICENSE.txt
├── custom_layers.py
├── vocabulary.py
├── finetune_data.py
└── custom_layers_opt.py
/fastermoe/__init__.py:
--------------------------------------------------------------------------------
1 | import os, sys
--------------------------------------------------------------------------------
/fastmoe.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/fastmoe.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | fmoe_cuda
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # torch==2.2.0
2 | # numpy
3 | ninja
4 | dm-tree
5 | tqdm
6 | torchmetrics==1.3.1
7 |
--------------------------------------------------------------------------------
/fastmoe.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | setup.py
3 | fastmoe.egg-info/PKG-INFO
4 | fastmoe.egg-info/SOURCES.txt
5 | fastmoe.egg-info/dependency_links.txt
6 | fastmoe.egg-info/top_level.txt
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 |
3 | # Compiled and optimized Python code and libraries
4 | __pycache__/
5 | *.py[cod]
6 |
7 | # Temporary editor files
8 | *~
9 | [._]*.sw[nop]
10 |
11 | # Temporary macOS files
12 | .DS_Store
13 |
--------------------------------------------------------------------------------
/gates/__init__.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from .zero_gate import ZeroGate
3 | from .naive_gate import NaiveGate
4 | from .noisy_gate import NoisyGate
5 |
6 | from .gshard_gate import GShardGate
7 | from .switch_gate import SwitchGate
8 |
9 | from .swipe_gate import SwipeGate
--------------------------------------------------------------------------------
/fastmoe.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: fastmoe
3 | Version: 1.1.0
4 | Summary: An efficient Mixture-of-Experts system for PyTorch
5 | Author: Jiaao He, Jiezhong Qiu, Aohan Zeng, Tiago Antunes, Jinjun Peng, Qin Li, Mingshu Zhai
6 | Author-email: hja20@mails.tsinghua.edu.cn
7 | License: Apache-2
8 |
--------------------------------------------------------------------------------
/fastermoe/config.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | def float_from_env(key, default=-1):
4 | if key in os.environ:
5 | return float(os.environ[key])
6 | return default
7 |
8 | def switch_from_env(key, default=False):
9 | if key in os.environ:
10 | return os.environ[key] in ['1', 'ON']
11 | return default
--------------------------------------------------------------------------------
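A minimal usage sketch of the two helpers above (assuming the repository root is on `PYTHONPATH`; the environment-variable names are the ones read by `fastermoe/shadow_policy.py`, and the values here are purely illustrative):

```python
import os
from fastermoe.config import float_from_env, switch_from_env

# Hypothetical settings, for illustration only.
os.environ['FMOE_FASTER_GLBPLC_NETBW'] = '6.25e9'
os.environ['FMOE_FASTER_SHADOW_ENABLE'] = 'ON'

print(float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8))  # 6.25e9 (from env)
print(float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12))       # 1.15e13 (default)
print(switch_from_env('FMOE_FASTER_SHADOW_ENABLE'))              # True ('ON' enables it)
```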
/gates/base_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch.nn as nn
3 |
4 | class BaseGate(nn.Module):
5 | def __init__(self, num_expert, world_size):
6 | super().__init__()
7 | self.world_size = world_size
8 | self.num_expert = num_expert
9 | self.tot_expert = world_size * num_expert
10 | self.loss = None
11 |
12 | def forward(self, x):
13 | raise NotImplementedError('Base gate cannot be directly used for fwd')
14 |
15 | def set_loss(self, loss):
16 | self.loss = loss
17 |
18 | def get_loss(self, clear=True):
19 | loss = self.loss
20 | if clear:
21 | self.loss = None
22 | return loss
23 |
24 | @property
25 | def has_loss(self):
26 | return self.loss is not None
--------------------------------------------------------------------------------
/gates/zero_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from .base_gate import BaseGate
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | class ZeroGate(BaseGate):
9 | r"""
10 | Guide all input samples to gate 0.
11 | """
12 |
13 | def __init__(self, _1, num_expert, world_size, top_k=2):
14 | super().__init__(num_expert, world_size)
15 | self.top_k = top_k
16 |
17 | def forward(self, inp):
18 | r"""
 19 |         Route all inputs to expert 0
20 | """
21 | idx = torch.zeros(
22 | inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
23 | )
24 | gate_score = (
25 | torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
26 | )
27 | return idx, gate_score.reshape(-1, 1, self.top_k)
--------------------------------------------------------------------------------
/scripts/smoe-s.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture sgsgsg \
7 | --gate_name smoe \
8 | --nlayers 3 \
9 | --hid-sz 128 \
10 | --inner-hid-sz 128 \
11 | --nheads 8 \
12 | --block-sz 256 \
13 | --attn-span 256 \
14 | --dropout 0.7 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 3000 \
19 | --niter 60 \
20 | --batch-sz 96 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
25 | "
26 |
27 | echo "Training ..."
28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
29 |
30 | echo "Evaluation ..."
31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
--------------------------------------------------------------------------------
/gates/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | from fmoe.functions import count_by_gate
4 | import fmoe_cuda as fmoe_native
5 |
6 | def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
7 | with torch.no_grad():
8 | capacity = torch.ones(num_expert, dtype=torch.int32,
9 | device=topk_idx.device) * capacity
10 |
11 | pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
12 | require_pos=False)
13 | new_gec = fmoe_native.limit_by_capacity(gec, capacity,
14 | num_expert, world_size)
15 | if world_size > 1:
16 | new_lec = fmoe_native.expert_exchange(new_gec, num_expert,
17 | world_size)
18 | else:
19 | new_lec = new_gec
20 |
21 | topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx,
22 | new_lec.to(torch.int32), num_expert, world_size)
23 | return new_lec, new_gec, topk_idx
--------------------------------------------------------------------------------
/scripts/glam-m.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch glam \
6 | --architecture sgsfsgsfsgsfsgsfsgsfsgsf \
7 | --gate_name smoe \
8 | --nlayers 6 \
9 | --hid-sz 352 \
10 | --inner-hid-sz 352 \
11 | --nheads 8 \
12 | --block-sz 512 \
13 | --attn-span 2048 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.00007 \
18 | --lr-warmup 4000 \
19 | --niter 120 \
20 | --batch-sz 48 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
25 | "
26 |
27 | echo "Training ..."
28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
29 |
30 | echo "Evaluation ..."
31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
--------------------------------------------------------------------------------
/scripts/smoe-m.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture sgsgsgsgsgsg \
7 | --gate_name smoe \
8 | --nlayers 6 \
9 | --hid-sz 352 \
10 | --inner-hid-sz 352 \
11 | --nheads 8 \
12 | --block-sz 512 \
13 | --attn-span 1024 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 4000 \
19 | --niter 80 \
20 | --batch-sz 48 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
25 | "
26 |
27 | echo "Training ..."
28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
29 |
30 | echo "Evaluation ..."
31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
32 |
--------------------------------------------------------------------------------
/scripts/smoe-l.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture sgsgsgsgsgsgsgsgsgsgsgsg \
7 | --gate_name smoe \
8 | --nlayers 12 \
9 | --hid-sz 512 \
10 | --inner-hid-sz 512 \
11 | --nheads 8 \
12 | --block-sz 1024 \
13 | --attn-span 2048 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 5000 \
19 | --niter 80 \
20 | --batch-sz 24 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
25 | "
26 |
27 | echo "Training ..."
28 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
29 |
30 | echo "Evaluation ..."
31 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
32 |
--------------------------------------------------------------------------------
/scripts/smoe-mom-s.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture smsmsm \
7 | --gate_name smoe \
8 | --nlayers 3 \
9 | --hid-sz 128 \
10 | --inner-hid-sz 128 \
11 | --nheads 8 \
12 | --block-sz 256 \
13 | --attn-span 256 \
14 | --dropout 0.7 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 3000 \
19 | --niter 60 \
20 | --batch-sz 96 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --gamma1 0.8 \
25 | --gamma2 1.0 \
26 | --mu 0.7 \
27 | --beta1 0.9 \
28 | --beta2 0.999 \
29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
30 | "
31 |
32 | echo "Training ..."
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
34 |
35 | echo "Evaluation ..."
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
--------------------------------------------------------------------------------
/scripts/glam-adam-m.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch glam \
6 | --architecture sasfsasfsasfsasfsasfsasf \
7 | --gate_name smoe \
8 | --nlayers 6 \
9 | --hid-sz 352 \
10 | --inner-hid-sz 352 \
11 | --nheads 8 \
12 | --block-sz 512 \
13 | --attn-span 2048 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.00007 \
18 | --lr-warmup 4000 \
19 | --niter 120 \
20 | --batch-sz 48 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --gamma1 0.8 \
25 | --gamma2 1.0 \
26 | --mu 0.7 \
27 | --beta1 0.9 \
28 | --beta2 0.999 \
29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
30 | "
31 |
32 | echo "Training ..."
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
34 |
35 | echo "Evaluation ..."
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
--------------------------------------------------------------------------------
/scripts/glam-mom-m.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch glam \
6 | --architecture smsfsmsfsmsfsmsfsmsfsmsf \
7 | --gate_name smoe \
8 | --nlayers 6 \
9 | --hid-sz 352 \
10 | --inner-hid-sz 352 \
11 | --nheads 8 \
12 | --block-sz 512 \
13 | --attn-span 2048 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.00007 \
18 | --lr-warmup 4000 \
19 | --niter 120 \
20 | --batch-sz 48 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --gamma1 0.8 \
25 | --gamma2 1.0 \
26 | --mu 0.7 \
27 | --beta1 0.9 \
28 | --beta2 0.999 \
29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
30 | "
31 |
32 | echo "Training ..."
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
34 |
35 | echo "Evaluation ..."
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
--------------------------------------------------------------------------------
/scripts/smoe-adam-m.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture sasasasasasa \
7 | --gate_name smoe \
8 | --nlayers 6 \
9 | --hid-sz 352 \
10 | --inner-hid-sz 352 \
11 | --nheads 8 \
12 | --block-sz 512 \
13 | --attn-span 1024 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 4000 \
19 | --niter 80 \
20 | --batch-sz 48 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --gamma1 1.0 \
25 | --gamma2 1.0 \
26 | --mu 0.7 \
27 | --beta1 0.9 \
28 | --beta2 0.999 \
29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
30 | "
31 |
32 | echo "Training ..."
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
34 |
35 | echo "Evaluation ..."
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
37 |
--------------------------------------------------------------------------------
/scripts/smoe-mom-m.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture smsmsmsmsmsm \
7 | --gate_name smoe \
8 | --nlayers 6 \
9 | --hid-sz 352 \
10 | --inner-hid-sz 352 \
11 | --nheads 8 \
12 | --block-sz 512 \
13 | --attn-span 1024 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 4000 \
19 | --niter 80 \
20 | --batch-sz 48 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --gamma1 1.0 \
25 | --gamma2 1.0 \
26 | --mu 0.7 \
27 | --beta1 0.9 \
28 | --beta2 0.999 \
29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
30 | "
31 |
32 | echo "Training ..."
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
34 |
35 | echo "Evaluation ..."
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
37 |
--------------------------------------------------------------------------------
/scripts/smoe-mom-l.sh:
--------------------------------------------------------------------------------
1 | mkdir -p /path/to/checkpoint/directory/
2 |
3 | args="
4 | --data /path/to/data/directory/wikitext-103/ \
5 | --base_arch transformer \
6 | --architecture smsmsmsmsmsmsmsmsmsmsmsm \
7 | --gate_name smoe \
8 | --nlayers 12 \
9 | --hid-sz 512 \
10 | --inner-hid-sz 512 \
11 | --nheads 8 \
12 | --block-sz 1024 \
13 | --attn-span 2048 \
14 | --dropout 0.1 \
15 | --load_balance 0.01 \
16 | --optim adam \
17 | --lr 0.0007 \
18 | --lr-warmup 5000 \
19 | --niter 80 \
20 | --batch-sz 24 \
21 | --batch-split 2 \
22 | --nbatches 1000 \
23 | --distributed \
24 | --gamma1 0.8 \
25 | --gamma2 1.0 \
26 | --mu 0.7 \
27 | --beta1 0.9 \
28 | --beta2 0.999 \
29 | --checkpoint /path/to/checkpoint/directory/smoe.pt \
30 | "
31 |
32 | echo "Training ..."
33 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args
34 |
35 | echo "Evaluation ..."
36 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 10013 --nproc_per_node=4 --use_env train.py $args --resume --full-eval-mode
37 |
--------------------------------------------------------------------------------
/custom_utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import torch.distributed as dist
6 |
7 |
8 | # pylint: disable=broad-except
9 | # pylint: disable=protected-access
10 | def get_torch_default_comm():
11 | r"""
12 | The NCCL communicator is needed so that Fast MoE can perform customized
13 | communication operators in the C code. However, it is not a publicly
14 | available variable. Therefore, a hacking class of the `ProcessGroupNCCL`
15 | in Fast MoE's C code takes the `_default_pg` and tries to dig the
16 | communicator out from the object. As PyTorch's private interface varies from
17 | time to time, different hacking techniques are tried one-by-one to be
18 | compatible with various versions of PyTorch.
19 | """
20 | try:
21 | comm = dist.distributed_c10d._get_default_group()
22 | return comm
23 | except Exception as _:
24 | pass
25 | try:
26 | comm = dist.distributed_c10d._default_pg
27 | if comm is not None:
28 | return comm
29 | except Exception as _:
30 | pass
31 | raise RuntimeError("Unsupported PyTorch version")
32 |
--------------------------------------------------------------------------------
/gates/naive_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from .base_gate import BaseGate
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | class NaiveGate(BaseGate):
9 | r"""
10 | A naive gate implementation that defines the standard behavior of the gate
11 | which determines which experts the tokens are going to.
12 | Both the indicies and the score, or confidence, are output to the parent
13 | module.
14 | The load-balance strategies are also designed to be implemented within the
15 | `Gate` module.
16 | """
17 |
18 | def __init__(self, d_model, num_expert, world_size, top_k=2):
19 | super().__init__(num_expert, world_size)
20 | self.gate = nn.Linear(d_model, self.tot_expert)
21 | self.top_k = top_k
22 |
23 | def forward(self, inp, return_all_scores=False):
24 | r"""
25 | The naive implementation simply calculates the top-k of a linear layer's
26 | output.
27 | """
28 | gate = self.gate(inp)
29 | gate_top_k_val, gate_top_k_idx = torch.topk(
30 | gate, k=self.top_k, dim=-1, largest=True, sorted=False
31 | ) # [.. x top_k]
32 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
33 |
34 | # (BxL) x 1 x top_k
35 | gate_score = F.softmax(gate_top_k_val, dim=-1)
36 |
37 | if return_all_scores:
38 | return gate_top_k_idx, gate_score, gate
39 | return gate_top_k_idx, gate_score
--------------------------------------------------------------------------------
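For reference, a small shape check of `NaiveGate` (a sketch only; importing the `gates` package requires a working fastmoe install with the compiled `fmoe_cuda` extension, because `gates/__init__.py` also pulls in the CUDA-backed gates):

```python
import torch
from gates.naive_gate import NaiveGate

gate = NaiveGate(d_model=16, num_expert=4, world_size=1, top_k=2)
x = torch.randn(8, 16)        # 8 tokens with d_model = 16
idx, score = gate(x)

print(idx.shape)              # torch.Size([8, 2]) -- top-2 expert indices per token
print(score.shape)            # torch.Size([8, 2]) -- softmax over the top-2 logits
print(score.sum(dim=-1))      # all ones: scores are normalized within the top-k
```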
/gates/swipe_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import math
3 | import torch
4 | import torch.distributed as dist
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from .naive_gate import NaiveGate
8 |
9 | from fmoe.functions import count_by_gate
10 | import fmoe_cuda as fmoe_native
11 |
12 | class SwipeGate(NaiveGate):
13 | def __init__(self, d_model, num_expert, world_size, top_k=2):
14 | super().__init__(d_model, num_expert, world_size, top_k)
15 |
16 | def swipe_once(self, idx, capacity, bias):
17 | with torch.no_grad():
18 | idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
19 | self.num_expert, self.world_size, bias)
20 | idx_new = idx_new.to(idx.device)
21 | return idx_new, capacity
22 |
23 | def forward(self, inp):
24 | score = self.gate(inp)
25 | orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
26 |
27 | if not self.training:
28 | topk_val = F.softmax(orig_score, dim=-1)
29 | return orig_idx, topk_val
30 |
31 | capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
32 | dtype=torch.long)
33 |
34 | topk_idxs = []
35 | topk_vals = []
36 | idx_x = torch.arange(inp.shape[0], device=inp.device)
37 | for k in range(self.top_k):
38 | idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
39 | k % self.num_expert)
40 | topk_vals.append(score[idx_x, idx])
41 | topk_idxs.append(idx)
42 | topk_idx = torch.stack(topk_idxs).transpose(0, 1)
43 | topk_val = torch.stack(topk_vals).transpose(0, 1)
44 | topk_val = F.softmax(topk_val, dim=-1)
45 | return topk_idx, topk_val
--------------------------------------------------------------------------------
/fastermoe/expert_utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 |
4 | def get_expert_param_size(e):
5 | return sum(map(lambda x: x.numel(), e.parameters()))
6 |
7 |
8 | def get_expert_params(e, out):
9 | offset = 0
10 | for n, p in e.named_parameters():
11 | seg = out[offset:offset + p.numel()]
12 | offset += p.numel()
13 | seg.copy_(p.data.flatten())
14 |
15 | def stash_expert_params(e, params):
16 | if not hasattr(e, 'expert_param_stash'):
17 | setattr(e, 'expert_param_stash', dict())
18 | offset = 0
19 | for n, p in e.named_parameters():
20 | if n not in e.expert_param_stash:
21 | e.expert_param_stash[n] = p.data.clone()
22 | with torch.no_grad():
23 | seg = params[offset:offset + p.numel()]
24 | offset += p.numel()
25 | p.copy_(seg.reshape(p.shape))
26 |
27 | def pop_expert_params(e):
28 | if not hasattr(e, 'expert_param_stash'):
29 | return
30 | for n, p in e.named_parameters():
31 | with torch.no_grad():
32 | p.copy_(e.expert_param_stash[n])
33 | e.expert_param_stash.clear()
34 |
35 | def collect_expert_grads(e, grads):
36 | offset = 0
37 | for _, p in e.named_parameters():
38 | seg = grads[offset:offset + p.numel()]
39 | offset += p.numel()
40 | if p.grad is not None:
41 | seg.copy_(p.grad.flatten())
42 | p.grad = None
43 | else:
44 | seg.zero_()
45 |
46 | def set_grads(e, grads):
47 | offset = 0
48 | for n, p in e.named_parameters():
49 | seg = grads[offset:offset + p.numel()]
50 | offset += p.numel()
51 | if p.grad is None:
52 | p.grad = seg.clone()
53 | else:
54 | p.grad += seg.reshape(p.shape)
--------------------------------------------------------------------------------
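A round-trip sketch of the helpers above, using a plain `nn.Linear` as a stand-in expert (the helpers only rely on `named_parameters()`, so no compiled extension is needed here; it assumes the repository root is on `PYTHONPATH`):

```python
import torch
import torch.nn as nn
from fastermoe.expert_utils import (
    get_expert_param_size, get_expert_params,
    stash_expert_params, pop_expert_params,
)

expert = nn.Linear(4, 4)
n = get_expert_param_size(expert)       # 4*4 weights + 4 biases = 20

flat = torch.empty(n)
get_expert_params(expert, flat)         # flatten the current params into `flat`

incoming = torch.zeros(n)               # e.g. parameters shadowed from a peer worker
stash_expert_params(expert, incoming)   # originals stashed, zeros loaded in place
assert expert.weight.abs().sum() == 0

pop_expert_params(expert)               # originals restored from the stash
assert torch.allclose(flat[:16], expert.weight.data.flatten())
```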
/gates/gshard_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from .naive_gate import NaiveGate
6 | from .utils import limit_by_capacity
7 |
8 | class GShardGate(NaiveGate):
9 | def __init__(self, d_model, num_expert, world_size,
10 | topk=2, capacity=(1.2, 2.4), random_routing=True):
11 | assert topk == 2, 'topk should be 2 in gshard'
12 | super().__init__(d_model, num_expert, world_size, top_k=2)
13 | self.capacity = capacity
14 | self.random_routing = random_routing
15 |
16 | def forward(self, x):
17 | naive_outs = super().forward(x, return_all_scores=True)
18 | topk_idx, topk_val, gate_score = naive_outs
19 |
20 | S = gate_score.shape[0]
21 | top_k = topk_idx.shape[0] // gate_score.shape[0]
22 | top1_idx = topk_idx.view((-1, top_k))[:, 0]
23 | c_e = torch.scatter_add(
24 | torch.zeros(self.tot_expert, device=top1_idx.device),
25 | 0,
26 | top1_idx,
27 | torch.ones_like(top1_idx, dtype=torch.float),
28 | ) / S
29 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
30 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
31 | self.set_loss(loss)
32 |
33 | cap_rate = self.capacity[0 if self.training else 1]
34 | capacity = math.ceil(cap_rate * x.shape[0])
35 | _new_lec, _new_gec, topk_idx = limit_by_capacity(
36 | topk_idx, self.num_expert, self.world_size, capacity)
37 |
38 | if self.random_routing:
39 | rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
40 | mask = (2 * topk_val[:, 1] < rand_routing_prob)
41 | topk_idx[:, 1].masked_fill_(mask, -1)
42 |
43 | return topk_idx, topk_val
--------------------------------------------------------------------------------
/.github/workflows/typecheck.yaml:
--------------------------------------------------------------------------------
1 | name: Typecheck
2 |
3 | # These checks will run if at least one file is outside of the `paths-ignore`
4 | # list, but will be skipped if *all* files are in the `paths-ignore` list.
5 | #
  6 | # For more info, see:
7 | # https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#example-excluding-paths
8 |
9 | on:
10 | push:
11 | branches:
12 | - 'main'
13 | paths-ignore:
14 | - '**.md'
15 |
16 | pull_request:
17 | branches:
18 | - 'main'
19 | paths-ignore:
20 | - '**.md'
21 |
22 | jobs:
23 | test:
24 | strategy:
25 | fail-fast: false
26 | matrix:
27 | os: [ 'ubuntu-24.04' ]
28 | python: [ '3.10' ]
29 |
30 | runs-on: ${{ matrix.os }}
31 | name: Python ${{ matrix.python }} on ${{ matrix.os }}
32 |
33 | steps:
34 | - name: Checkout the repo
35 | uses: actions/checkout@v4
36 |
37 | - name: Setup Python
38 | uses: actions/setup-python@v5
39 | with:
40 | python-version: ${{ matrix.python }}
41 | cache: 'pip'
42 |
43 | - name: Update pip
44 | run: python -m pip install --upgrade pip
45 |
46 | - name: Install Python deps
47 | run: python -m pip install -r requirements.txt
48 |
49 | - name: Install Mypy
50 | run: python -m pip install --upgrade mypy
51 |
52 | - name: Check types with Mypy
53 | run: python -m mypy --python-version=${{ matrix.python }} .
54 | # TODO: fix the type checking errors and remove this line to make errors
55 | # obvious by failing the test.
56 | continue-on-error: true
57 |
58 | - name: Install PyType
59 | run: python -m pip install --upgrade pytype
60 |
61 | - name: Check types with PyType
62 | run: python -m pytype --python-version=${{ matrix.python }} -k .
63 | # TODO: fix the type checking errors and remove this line to make errors
64 | # obvious by failing the test.
65 | continue-on-error: true
66 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MomentumSMoE: Integrating Momentum into Sparse Mixture of Experts
  2 |
3 |
4 | https://arxiv.org/abs/2410.14574
5 |
6 | ### Prerequisites
7 |
  8 | - PyTorch
  9 | - FastMoE: https://github.com/laekov/fastmoe
10 | - The toolkit supports [Weights & Biases](https://docs.wandb.ai/) for monitoring jobs. If you use it, also install `wandb`.
11 |
12 | ### Usage
13 |
14 |
15 | #### Prepare WikiText-103 Datasets:
16 |
 17 | - Download the WikiText-103 dataset from [here](https://github.com/laekov/fastmoe/blob/master/examples/transformer-xl/scripts/getdata.sh), then edit the bash scripts in `scripts/` to point at your local data paths.
18 | ```bash
19 | data_directory/
20 | └── wikitext-103
21 | ├── test.txt
22 | ├── train.txt
23 | └── valid.txt
24 | ```
25 |
26 | #### Pretraining SMoE (SwitchTransformers) on WikiText-103:
27 |
 28 | ```bash
29 | bash scripts/smoe-s.sh
30 | bash scripts/smoe-m.sh
31 | bash scripts/smoe-l.sh
32 | ```
33 |
34 | #### Pretraining *Momentum*SMoE on WikiText-103:
35 |
 36 | ```bash
37 | bash scripts/smoe-mom-s.sh
38 | bash scripts/smoe-mom-m.sh
39 | bash scripts/smoe-mom-l.sh
40 | ```
41 |
42 | #### Pretraining *Adam*SMoE on WikiText-103:
43 |
 44 | ```bash
45 | bash scripts/smoe-adam-m.sh
46 | ```
47 |
48 | #### Pretraining GLaM on WikiText-103:
49 |
 50 | ```bash
51 | bash scripts/glam-m.sh
52 | ```
53 |
54 | #### Pretraining *Momentum*GLaM on WikiText-103:
55 |
 56 | ```bash
57 | bash scripts/glam-mom-m.sh
58 | ```
59 |
60 | #### Pretraining *Adam*GLaM on WikiText-103:
61 |
 62 | ```bash
63 | bash scripts/glam-adam-m.sh
64 | ```
65 |
66 | #### Wandb support:
 67 | - Add these flags to the bash script, together with your project and job names:
 68 | ```bash
69 | --wandb
70 | --project-name test
71 | --job-name test
72 | ```
73 |
74 | ### Checkpoints
75 | Please find our SMoE checkpoints at .
76 |
77 |
78 |
--------------------------------------------------------------------------------
/gates/switch_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .naive_gate import NaiveGate
7 | from .utils import limit_by_capacity
8 |
9 | class SwitchGate(NaiveGate):
10 | r"""
11 | A switch gate implementation
12 | """
13 |
14 | def __init__(self, d_model, num_expert, world_size, topk=1,
15 | switch_eps=.1, capacity=(1.2, 2.4)):
16 | assert topk == 1, 'topk should be 1 in switch'
17 | super().__init__(d_model, num_expert, world_size, top_k=1)
18 | self.switch_eps = switch_eps
19 | self.capacity = capacity
20 |
21 | def forward(self, inp):
22 | r"""
 23 |         The switch gate first applies a softmax and then selects the top-1 expert.
24 | """
25 | score = self.gate(inp)
26 |
27 | if self.training:
28 | # random uniform number from [1-eps, 1+eps]
29 | noise = torch.rand_like(score)
30 | noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
31 | score += noise
32 |
33 | # fp32 softmax for numerical stability
34 | score = F.softmax(score.float(), dim=-1)
35 |
36 | top1_score, top1_idx = torch.topk(
37 | score, k=1, dim=-1, largest=True
38 | ) # [.. x top_k]
39 | top1_score = top1_score.to(dtype=inp.dtype)
40 |
41 | cap_rate = self.capacity[0 if self.training else 1]
42 | capacity = math.ceil(cap_rate * inp.shape[0])
43 | _new_lec, _new_gec, top1_idx = limit_by_capacity(
44 | top1_idx, self.num_expert, self.world_size, capacity)
45 |
46 | valid_idx = top1_idx[top1_idx > -1]
47 | fraction_expert = torch.scatter_add(
48 | torch.zeros(self.tot_expert, device=valid_idx.device),
49 | 0,
50 | valid_idx,
51 | torch.ones_like(valid_idx, dtype=torch.float),
52 | ) / valid_idx.numel()
53 | prob_expert = score.sum(dim=0) / valid_idx.numel()
54 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert
55 | self.set_loss(loss)
56 | return top1_idx, top1_score
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
2 | import setuptools
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 | import os
5 | import torch
6 |
7 | cxx_flags = []
8 | ext_libs = []
9 |
10 | authors = [
11 | 'Jiaao He',
12 | 'Jiezhong Qiu',
13 | 'Aohan Zeng',
14 | 'Tiago Antunes',
15 | 'Jinjun Peng',
16 | 'Qin Li',
17 | 'Mingshu Zhai'
18 | ]
19 |
20 | is_rocm_pytorch = False
21 | if torch.__version__ >= '1.5':
22 | from torch.utils.cpp_extension import ROCM_HOME
23 | is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
24 |
25 | if os.environ.get('USE_NCCL', '1') == '1':
26 | cxx_flags.append('-DFMOE_USE_NCCL')
27 | cxx_flags.append('-DUSE_C10D_NCCL')
28 | if is_rocm_pytorch:
29 | ext_libs.append('rccl')
30 | else:
31 | ext_libs.append('nccl')
32 |
33 | if os.environ.get('MOE_DEBUG', '0') == '1':
34 | cxx_flags.append('-DMOE_DEBUG')
35 |
36 | if is_rocm_pytorch:
37 | define_macros=[('FMOE_USE_HIP', None)]
38 | else:
39 | define_macros=[]
40 |
41 |
42 | if __name__ == '__main__':
43 | setuptools.setup(
44 | name='fastmoe',
45 | version='1.1.0',
46 | description='An efficient Mixture-of-Experts system for PyTorch',
47 | author=', '.join(authors),
48 | author_email='hja20@mails.tsinghua.edu.cn',
49 | license='Apache-2',
50 | # url='https://github.com/laekov/fastmoe',
51 | # packages=['fmoe', 'fmoe.megatron', 'fmoe.gates', 'fmoe.fastermoe'],
52 | ext_modules=[
53 | CUDAExtension(
54 | name='fmoe_cuda',
55 | sources=[
56 | # 'cuda/stream_manager.cpp',
57 | # 'cuda/local_exchange.cu',
58 | # 'cuda/balancing.cu',
59 | # 'cuda/global_exchange.cpp',
60 | # 'cuda/parallel_linear.cu',
61 | 'cuda/fmoe_cuda.cpp',
62 | # 'cuda/fastermoe/smart_schedule.cpp',
63 | ],
64 | define_macros=define_macros,
65 | extra_compile_args={
66 | 'cxx': cxx_flags,
67 | 'nvcc': cxx_flags
68 | },
69 | libraries=ext_libs
70 | )
71 | ],
72 | cmdclass={
73 | 'build_ext': BuildExtension
74 | })
75 |
--------------------------------------------------------------------------------
/fastermoe/shadow_policy.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torch.distributed as dist
4 |
5 | from .config import float_from_env, switch_from_env
6 | from fmoe.functions import get_moe_group
7 |
8 | def global_policy(local_expert_count, _gec, num_expert, world_size):
9 | r"""
10 | This is the policy for two-layer MLPs, using the formula in the PPoPP paper.
11 | A few parameters are used in this policy.
12 | * `d_model`: feature length of the MLP input and output.
13 | * `alpha`: the ratio of the MLP's hidden size to `d_model`.
 14 |     * `bw_net`: bandwidth of the network (bytes per second)
 15 |     * `bw_mm`: computation throughput of performing GeMM (FLOPs per second)
16 | """
17 | bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
18 | bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12)
19 | alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)
20 | d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048)
21 |
22 | moe_group = get_moe_group()
23 | local_expert_count = local_expert_count.cuda()
24 | agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)]
25 | dist.all_gather(agecs, local_expert_count, group=moe_group)
26 | all_global_expert_count = torch.stack(agecs)
27 |
28 | # TODO: data type other than float
29 | data_size = 4
30 |
31 | fwd_expert_counts = all_global_expert_count.sum(1).cpu()
32 | B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True)
33 |
34 | alphaH2 = alpha * (d_model ** 2)
35 | B_w = B_ws[0]
36 |
37 | comm = float('+inf')
38 | send_feature_time = d_model * data_size / bw_net
39 | send_model_time = 2 * alphaH2 * data_size / bw_net
40 | comp_time = 4 * alphaH2 / bw_mm
41 | lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
42 |
43 | res = torch.zeros(world_size * num_expert, dtype=torch.bool)
44 | shadow_time = 0
45 |
46 | for i, index in enumerate(indices):
47 | if i + 1 == indices.numel():
48 | break
49 | B_k = B_ws[i + 1]
50 | shadow_time += send_model_time
51 | lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time
52 |
53 | if lat_new < lat_base:
54 | lat_base = lat_new
55 | res[index] = True
56 | else:
57 | break
58 | return res
59 |
60 | def no_shadow_policy(_lec, _gec, num_expert, world_size):
61 | res = torch.zeros(world_size * num_expert, dtype=bool)
62 | return res
63 |
64 | def get_shadow_policy(d_model=None):
65 | if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ:
66 | os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model)
67 | if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'):
 68 |         return no_shadow_policy
69 | return global_policy
--------------------------------------------------------------------------------
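A worked instance of the latency comparison inside `global_policy` (the numbers are hypothetical and only illustrate the decision rule: shadowing the busiest expert pays off when the reduced compute and feature-transfer time outweighs the cost of broadcasting its parameters):

```python
d_model, alpha = 2048, 2
data_size = 4                         # float32, bytes per element
bw_net = 50 * 1e9 / 8                 # network bandwidth, bytes/s (default above)
bw_mm = 11.5e12                       # GeMM throughput, FLOPs/s (default above)

alphaH2 = alpha * d_model ** 2
send_feature_time = d_model * data_size / bw_net
send_model_time = 2 * alphaH2 * data_size / bw_net
comp_time = 4 * alphaH2 / bw_mm

B_w, B_k = 60_000, 20_000             # busiest vs. second-busiest expert load (tokens)
lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + send_model_time
print(lat_new < lat_base)             # True -> the busiest expert would be shadowed
```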
/gates/faster_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from .naive_gate import NaiveGate
3 |
4 | import os
5 | import sys
6 | import torch
7 | import torch.nn.functional as F
8 | from .utils import limit_by_capacity
9 | import fmoe_cuda
10 | from fmoe.functions import count_by_gate
11 |
12 | nw_per_node = 8
13 | try:
14 | nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE'])
15 | except Exception:
16 | pass
17 |
18 | class FasterGate(NaiveGate):
19 | def __init__(self, d_model, n_expert, world_size, node_rank):
20 | super().__init__(d_model, n_expert, world_size, top_k=2)
21 | self.ne_per_node = nw_per_node * n_expert
22 | self.ogn_ratio = .14
23 | try:
24 | self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION'])
25 | except Exception:
26 | pass
27 | self.node_rank = node_rank
28 |
29 | mask = [1] * world_size * n_expert
30 | for i in range(n_expert * world_size):
31 | if i // self.ne_per_node == self.node_rank:
32 | mask[i] = 0
33 | self.mask = torch.Tensor(mask).bool()
34 | self.policy_fn = None
35 | print('node rank {} mask {}'.format(node_rank, mask))
36 |
37 | def forward(self, inp):
38 | if self.mask.device != inp.device:
39 | self.mask = self.mask.to(inp.device)
40 |
41 | gate_score = self.gate(inp)
42 | lim_mask = self.mask
43 |
44 | top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1)
45 | S = gate_score.shape[0]
46 | top_k = 2
47 |
48 | with torch.no_grad():
49 | top1_idx = top2_idx.view((-1, top_k))[:, 0]
50 | top1_val = top2_val.view((-1, top_k))[:, 0]
51 | c_e = torch.scatter_add(
52 | torch.zeros(self.tot_expert, device=top1_idx.device),
53 | 0,
54 | top1_idx,
55 | torch.ones_like(top1_idx, dtype=torch.float),
56 | ) / S
57 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
58 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
59 | self.set_loss(loss)
60 |
61 | with torch.no_grad():
62 | if self.policy_fn is None:
63 | stored_models = torch.zeros(self.num_expert * self.world_size,
64 | dtype=torch.bool)
65 | else:
66 | # TODO: Fix this after expert shadowing is ported
67 | _, lec, aec, gec, agec = count_by_gate(top2_idx,
68 | self.num_expert, self.world_size, require_pos=False)
69 | stored_models = self.policy_fn(aec, agec,
70 | self.num_expert, self.world_size, inp.shape[-1], True)
71 | lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device)
72 |
73 | ogn_mask = lim_mask[top1_idx]
74 | ogn_thres = int(inp.shape[0] * self.ogn_ratio)
75 |
76 | if ogn_mask.sum().item() < ogn_thres:
77 | topk_val, topk_idx = torch.topk(gate_score, k=self.top_k)
78 | topk_val = F.softmax(topk_val, dim=-1)
79 | return topk_idx, topk_val
80 |
81 | with torch.no_grad():
82 | top1_val[~ogn_mask] = float('-inf')
83 | _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres)
84 | cand = gate_score.clone()
85 | cand[:, lim_mask] = float('-inf')
86 | _, topk_idx = torch.topk(cand, k=self.top_k)
87 | topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn]
88 |
89 | idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2)
90 | topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)
91 |
92 | topk_val = F.softmax(topk_val, dim=-1)
93 |
94 | return topk_idx, topk_val
95 |
96 | def gen_faster_gate(rank):
97 | def _gen(d_model, n_expert, world_size, top_k=2):
98 | assert top_k == 2
99 | return FasterGate(d_model, n_expert, world_size, rank // nw_per_node)
100 | return _gen
--------------------------------------------------------------------------------
/custom_transformer.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import torch.nn as nn
6 | from custom_layers import FMoE
7 | from custom_layers import FMoELinear
8 | from custom_layers_opt import FMoEOpt
9 |
10 |
11 | class _Expert(nn.Module):
12 | r"""
13 | An expert using 2 FMoELinear modules to speed up the computation of experts
14 | within one worker.
15 | """
16 |
17 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
18 | super().__init__()
19 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
20 | self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
21 | self.activation = activation
22 |
23 | def forward(self, inp, fwd_expert_count):
24 | r"""
25 | First expand input to 4h (the hidden size is variable, but is called h4
 26 |         for convenience). Then perform activation. Finally shrink back to h.
27 | """
28 | x = self.htoh4(inp, fwd_expert_count)
29 | x = self.activation(x)
30 | x = self.h4toh(x, fwd_expert_count)
31 | return x
32 |
33 |
34 | class FMoETransformerMLP(FMoE):
35 | r"""
36 | A complete MoE MLP module in a Transformer block.
37 | * `activation` is the activation function to be used in MLP in each expert.
38 | * `d_hidden` is the dimension of the MLP layer.
39 | """
40 |
41 | def __init__(
42 | self,
43 | num_expert=32,
44 | d_model=1024,
45 | d_hidden=4096,
46 | activation=torch.nn.GELU(),
47 | expert_dp_comm="none",
48 | expert_rank=0,
49 | moe_top_k=2,
50 | **kwargs
51 | ):
52 | super().__init__(
53 | num_expert=num_expert, d_model=d_model, moe_top_k=moe_top_k, **kwargs
54 | )
55 | self.experts = _Expert(
56 | num_expert, d_model, d_hidden, activation, rank=expert_rank
57 | )
58 | self.mark_parallel_comm(expert_dp_comm)
59 |
60 | def forward(self, inp: torch.Tensor):
61 | r"""
 62 |         This wraps the FMoE forward pass, flattening the input to
 63 |         `(tokens, d_model)` and restoring the original shape afterwards.
64 | """
65 | original_shape = inp.shape
66 | inp = inp.reshape(-1, self.d_model)
67 | output = super().forward(inp)
68 | return output.reshape(original_shape)
69 |
70 |
71 | class FMoETransformerMLPOpt(FMoEOpt):
72 | r"""
73 | A complete MoE MLP module in a Transformer block.
74 | * `activation` is the activation function to be used in MLP in each expert.
75 | * `d_hidden` is the dimension of the MLP layer.
76 | """
77 |
78 | def __init__(
79 | self,
80 | num_expert=32,
81 | d_model=1024,
82 | d_hidden=4096,
83 | activation=torch.nn.GELU(),
84 | expert_dp_comm="none",
85 | expert_rank=0,
86 | moe_top_k=2,
87 | freq=0.0,
88 | alpha=0.0,
89 | act_experts="shuffle",
90 | g_blance=False,
91 | opt_blance=False,
92 | combine_gate=False,
93 | opt_loss="mse",
94 | **kwargs
95 | ):
96 | super().__init__(
97 | num_expert=num_expert,
98 | d_model=d_model,
99 | moe_top_k=moe_top_k,
100 | freq=freq,
101 | alpha=alpha,
102 | act_experts=act_experts,
103 | g_blance=g_blance,
104 | opt_blance=opt_blance,
105 | combine_gate=combine_gate,
106 | opt_loss=opt_loss,
107 | **kwargs
108 | )
109 | self.experts = _Expert(
110 | num_expert, d_model, d_hidden, activation, rank=expert_rank
111 | )
112 | self.mark_parallel_comm(expert_dp_comm)
113 |
114 | def forward(self, inp: torch.Tensor):
115 | r"""
116 |         This wraps the FMoE forward pass, flattening the input to
117 |         `(tokens, d_model)` and restoring the original shape afterwards.
118 | """
119 | original_shape = inp.shape
120 | inp = inp.reshape(-1, self.d_model)
121 | output = super().forward(inp)
122 | return output.reshape(original_shape)
123 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import tqdm
6 |
7 |
8 | def _tokenize(text_path, dictionary_to_update):
9 | """Tokenizes a text file."""
10 | print("Tokenizing {}".format(text_path))
11 | assert os.path.exists(text_path)
12 |
13 | nb_tokens_in_dictionary = len(dictionary_to_update)
14 |
15 | # Count nb of tokens in text and update the dictionary
16 | with open(text_path, "r", encoding="utf8") as f:
17 | for line in f:
 18 |             tokens = line.split() + ["<eos>"]
19 | for token in tokens:
20 | if token not in dictionary_to_update:
21 | dictionary_to_update[token] = nb_tokens_in_dictionary
22 | nb_tokens_in_dictionary += 1
23 |
24 | # Assign to each token its identifier
25 | ids = []
26 | with open(text_path, "r", encoding="utf8") as f:
27 | for line in f:
 28 |             tokens = line.split() + ["<eos>"]
29 | for token in tokens:
30 | ids.append(dictionary_to_update[token])
31 | ids = torch.LongTensor(ids)
32 | return ids
33 |
34 |
35 | class Corpus:
36 | def __init__(self, data_path):
37 | self._dictionary = {}
38 | self.train = _tokenize(
39 | text_path=os.path.join(data_path, "train.txt"),
40 | dictionary_to_update=self._dictionary,
41 | )
42 | self.valid = _tokenize(
43 | text_path=os.path.join(data_path, "valid.txt"),
44 | dictionary_to_update=self._dictionary,
45 | )
46 | self.test = _tokenize(
47 | text_path=os.path.join(data_path, "test.txt"),
48 | dictionary_to_update=self._dictionary,
49 | )
50 |
51 | @property
52 | def vocab_size(self):
53 | return len(self._dictionary)
54 |
55 |
56 | def _batchify(data_tensor, batch_size):
57 | nb_batches = data_tensor.size(0) // batch_size
58 | # trim away some tokens to make whole batches
59 | data_tensor = data_tensor.narrow(0, 0, nb_batches * batch_size)
60 | data_tensor = data_tensor.view(batch_size, -1).contiguous()
61 | return data_tensor
62 |
63 |
64 | def _build_corpus(data_path, env_params, data_name=None):
65 | # save the corpus to a file so that it's faster next time
66 | corpus_path = os.path.join(data_path, "corpus.pt")
67 | if os.path.exists(corpus_path):
68 | print("Loading an existing corpus file from {}".format(corpus_path))
69 | corpus = torch.load(corpus_path)
70 | else:
71 | print("Creating a corpus file at {}".format(corpus_path))
72 | if env_params["distributed"]:
 73 |             # only one process needs to create the corpus file
74 | if env_params["rank"] == 0:
75 | corpus = Corpus(data_path)
76 | torch.save(corpus, corpus_path)
77 | # sync with other processes
78 | torch.distributed.broadcast(torch.zeros(1).cuda(), src=0)
79 | else:
80 | print("Waiting rank0 to create a corpus file.")
81 | # sync with rank0
82 | torch.distributed.broadcast(torch.zeros(1).cuda(), src=0)
83 | corpus = torch.load(corpus_path)
84 | else:
85 | corpus = Corpus(data_path)
86 | torch.save(corpus, corpus_path)
87 | return corpus
88 |
89 |
90 | def _get_train_val_test_data(corpus, batch_size):
91 | return [
92 | _batchify(corpus.train, batch_size),
93 | _batchify(corpus.valid, batch_size),
94 | _batchify(corpus.test, batch_size),
95 | ]
96 |
97 |
98 | def get_train_val_test_data(data_params, env_params, batch_size, device):
99 | corpus = _build_corpus(**data_params, env_params=env_params)
100 | data_params["vocab_size"] = corpus.vocab_size
101 | train_data, val_data, test_data = _get_train_val_test_data(
102 | corpus=corpus, batch_size=batch_size
103 | )
104 |
105 | if env_params["distributed"]:
106 | # split the data into equal parts
107 | assert batch_size % env_params["world_size"] == 0
108 | device_batch_size = batch_size // env_params["world_size"]
109 | slice_data = slice(
110 | device_batch_size * env_params["rank"],
111 | device_batch_size * (env_params["rank"] + 1),
112 | )
113 | train_data = train_data[slice_data]
114 | val_data = val_data[slice_data]
115 | test_data = test_data[slice_data]
116 |
117 | train_data = train_data.to(device)
118 | val_data = val_data.to(device)
119 | test_data = test_data.to(device)
120 | return train_data, val_data, test_data
121 |
--------------------------------------------------------------------------------
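A tiny illustration of how `_batchify` lays tokens out: with 10 tokens and `batch_size=3`, one trailing token is trimmed and each row of the resulting matrix is a contiguous stream:

```python
import torch

data = torch.arange(10)               # stand-in token ids
batch_size = 3
nb_batches = data.size(0) // batch_size
batched = data.narrow(0, 0, nb_batches * batch_size).view(batch_size, -1).contiguous()
print(batched)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
```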
/fastermoe/schedule.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | from torch.autograd.function import Function
4 |
5 | from fmoe.functions import prepare_forward, ensure_comm
6 | from fmoe.functions import _local_scatter, _local_gather
7 | import fmoe_cuda as fmoe_native
8 | from fmoe.fastermoe import expert_utils
9 |
10 | from .shadow_policy import get_shadow_policy
11 |
12 | class MoEForward(Function):
13 | @staticmethod
14 | def forward(
15 | ctx,
16 | expert_fn,
17 | experts,
18 | inp, # models,
19 | pos_s, pos_g,
20 | local_expert_count, global_expert_count,
21 | stored_models,
22 | fwd_batch_size, out_batch_size,
23 | world_size):
24 | local_input_buf = _local_scatter(inp, pos_s)
25 |
26 | ctx.gibs = [None] * (world_size * 2)
27 | ctx.gobs = [None] * (world_size * 2)
28 | def _expert_forward(x, y, idx):
29 | nothing = lambda a: a
30 | x = x.data
31 | with torch.enable_grad():
32 | x.requires_grad = True
33 | # To skip torch autograd's version check.
34 | with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
35 | y0 = expert_fn(x, [x.shape[0]])
36 | ctx.gibs[idx] = x
37 | ctx.gobs[idx] = y0
38 | y.copy_(y0)
39 |
40 | ctx.experts = experts
41 | if stored_models.any():
42 | ctx.expert_size = expert_utils.get_expert_param_size(experts)
43 | else:
44 | ctx.expert_size = 0
45 | get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
46 | pop_fn = lambda: expert_utils.pop_expert_params(experts)
47 | ctx.shadows = [None] * world_size
48 | def stash_fn(params, idx):
49 | expert_utils.stash_expert_params(experts, params)
50 | ctx.shadows[idx] = params
51 |
52 | local_output_buf, gib = fmoe_native.smart_sch_forward(
53 | local_input_buf,
54 | local_expert_count, global_expert_count,
55 | stored_models, fwd_batch_size, ctx.expert_size,
56 | world_size, _expert_forward, get_param_fn, stash_fn, pop_fn)
57 |
58 | out = _local_gather(local_output_buf, pos_g, out_batch_size,
59 | maybe_overlap=False)
60 |
61 | # gib and local_input_buf are necessary, because ctx.gibs are created
62 | # based on their memory
63 | variables = (pos_s, pos_g, local_expert_count, global_expert_count,
64 | stored_models, gib, local_input_buf)
65 |
66 | ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
67 | ctx.save_for_backward(*variables)
68 |
69 | return out
70 |
71 | @staticmethod
72 | def backward(ctx, grad_out):
73 | (pos_s, pos_g, local_expert_count, global_expert_count,
74 | stored_models, _1, _2) = ctx.saved_tensors
75 | (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
76 |
77 | def _expert_backward(grad_y, grad_x, idx):
78 | y = ctx.gobs[idx]
79 | x = ctx.gibs[idx]
80 | torch.autograd.backward([y], [grad_y])
81 | grad_x.copy_(x.grad)
82 |
83 | experts = ctx.experts
84 | def stash_fn(idx):
85 | expert_utils.stash_expert_params(experts, ctx.shadows[idx])
86 | pop_fn = lambda: expert_utils.pop_expert_params(experts)
87 | def collect_fn(idx, root):
88 | grad = ctx.shadows[idx]
89 | expert_utils.collect_expert_grads(experts, grad)
90 | fmoe_native.reduce_grad(grad, root, ctx.expert_size)
91 | set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
92 |
93 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
94 | grad_in_buf = fmoe_native.smart_sch_backward(
95 | grad_out_buf,
96 | local_expert_count, global_expert_count,
97 | stored_models,
98 | pos_s.shape[0], fwd_batch_size,
99 | world_size,
100 | _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
101 | grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
102 |
103 | return (None, None, grad_in, None, None, None, None, None, None, None, None)
104 |
105 | policy_fn = None
106 |
107 | def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
108 | # TODO: Using multiple tensors as input is to be supported.
109 | assert(isinstance(inp, torch.Tensor))
110 | # TODO: Support many experts on each process
111 | assert(n_expert == 1)
112 | (
113 | pos,
114 | local_expert_count,
115 | global_expert_count,
116 | fwd_expert_count,
117 | fwd_batch_size,
118 | ) = prepare_forward(gate, n_expert, world_size)
119 |
120 | global policy_fn
121 | if policy_fn is None:
122 | policy_fn = get_shadow_policy(d_model=inp.shape[-1])
123 |
124 | if stored_models is None:
125 | stored_models = policy_fn(local_expert_count, global_expert_count,
126 | n_expert, world_size)
127 |
128 | topk = 1
129 | if len(gate.shape) == 2:
130 | topk = gate.shape[1]
131 | out_batch_size = inp.shape[0] * topk
132 |
133 | return MoEForward.apply(expert_fn, experts, inp,
134 | torch.div(pos, topk, rounding_mode='floor'), pos,
135 | local_expert_count, global_expert_count, stored_models,
136 | fwd_batch_size, out_batch_size, world_size)
--------------------------------------------------------------------------------
/gates/noisy_gate.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from .base_gate import BaseGate
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.distributions.normal import Normal
8 | import math
9 |
10 | class NoisyGate(BaseGate):
11 | def __init__(self, d_model, num_expert, world_size, top_k=2):
12 | super().__init__(num_expert, world_size)
13 | self.w_gate = nn.Parameter(
14 | torch.zeros(d_model, self.tot_expert), requires_grad=True
15 | )
16 | self.w_noise = nn.Parameter(
17 | torch.zeros(d_model, self.tot_expert), requires_grad=True
18 | )
19 | self.top_k = top_k
20 | self.softplus = nn.Softplus()
21 | self.softmax = nn.Softmax(1)
22 |
23 | self.noise_epsilon = 1e-2
24 |
25 | self.reset_parameters()
26 |
27 | def reset_parameters(self):
28 | # Approach is the same as in torch.nn.Linear
29 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
30 |
31 | torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5))
32 | torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5))
33 |
34 | def _gates_to_load(self, gates):
35 | """Compute the true load per expert, given the gates.
36 | The load is the number of examples for which the corresponding gate is >0.
37 | Args:
38 | gates: a `Tensor` of shape [batch_size, n]
39 | Returns:
40 | a float32 `Tensor` of shape [n]
41 | """
42 | return (gates > 0).sum(0)
43 |
44 | def _prob_in_top_k(
45 | self, clean_values, noisy_values, noise_stddev, noisy_top_values
46 | ):
47 | """Helper function to NoisyTopKGating.
48 | Computes the probability that value is in top k, given different random noise.
49 | This gives us a way of backpropagating from a loss that balances the number
50 | of times each expert is in the top k experts per example.
51 | In the case of no noise, pass in None for noise_stddev, and the result will
52 | not be differentiable.
53 | Args:
54 | clean_values: a `Tensor` of shape [batch, n].
55 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
56 | normally distributed noise with standard deviation noise_stddev.
57 | noise_stddev: a `Tensor` of shape [batch, n], or None
58 | noisy_top_values: a `Tensor` of shape [batch, m].
59 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
60 | Returns:
61 | a `Tensor` of shape [batch, n].
62 | """
63 |
64 | batch = clean_values.size(0)
65 | m = noisy_top_values.size(1)
66 | top_values_flat = noisy_top_values.flatten()
67 | threshold_positions_if_in = (
68 | torch.arange(batch, device=clean_values.device) * m + self.top_k
69 | )
70 | threshold_if_in = torch.unsqueeze(
71 | torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
72 | )
73 | is_in = torch.gt(noisy_values, threshold_if_in)
74 | threshold_positions_if_out = threshold_positions_if_in - 1
75 | threshold_if_out = torch.unsqueeze(
76 | torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
77 | )
78 | # is each value currently in the top k.
79 | normal = Normal(
80 | torch.tensor([0.0], device=clean_values.device),
81 | torch.tensor([1.0], device=clean_values.device),
82 | )
83 | prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
84 | prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
85 | prob = torch.where(is_in, prob_if_in, prob_if_out)
86 | return prob
87 |
88 | def cv_squared(self, x):
89 | """The squared coefficient of variation of a sample.
90 | Useful as a loss to encourage a positive distribution to be more uniform.
91 | Epsilons added for numerical stability.
92 | Returns 0 for an empty Tensor.
93 | Args:
94 | x: a `Tensor`.
95 | Returns:
96 | a `Scalar`.
97 | """
98 | eps = 1e-10
99 |         # with a single expert there is nothing to balance, so return zero
100 |         if x.shape[0] == 1:
101 |             return torch.zeros(1, device=x.device)
102 | return x.float().var() / (x.float().mean() ** 2 + eps)
103 |
104 | def forward(self, inp):
105 | clean_logits = inp @ self.w_gate
106 | raw_noise_stddev = inp @ self.w_noise
107 | noise_stddev = (
108 | self.softplus(raw_noise_stddev) + self.noise_epsilon
109 | ) * self.training
110 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
111 | logits = noisy_logits
112 |
113 |         # take the top (k + 1) logits; the extra value is needed by _prob_in_top_k
114 | top_logits, top_indices = logits.topk(
115 | min(self.top_k + 1, self.tot_expert), dim=1
116 | )
117 | top_k_logits = top_logits[:, : self.top_k]
118 | top_k_indices = top_indices[:, : self.top_k]
119 | top_k_gates = self.softmax(top_k_logits)
120 |
121 | zeros = torch.zeros_like(logits, requires_grad=True)
122 | gates = zeros.scatter(1, top_k_indices, top_k_gates)
123 |
124 | if self.top_k < self.tot_expert:
125 | load = (
126 | self._prob_in_top_k(
127 | clean_logits, noisy_logits, noise_stddev, top_logits
128 | )
129 | ).sum(0)
130 | else:
131 | load = self._gates_to_load(gates)
132 |
133 | importance = gates.sum(0)
134 | loss = self.cv_squared(importance) + self.cv_squared(load)
135 | self.set_loss(loss)
136 |
137 | return (
138 | top_k_indices.contiguous().view(-1),
139 | top_k_gates.contiguous().unsqueeze(1),
140 | )
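141 | 
142 | # --- Added usage sketch (illustrative; not part of the original module) ---
143 | # A minimal single-worker example, assuming BaseGate(num_expert, world_size)
144 | # sets tot_expert and provides set_loss(), as the calls above imply:
145 | #
146 | #   gate = NoisyGate(d_model=16, num_expert=4, world_size=1, top_k=2)
147 | #   x = torch.randn(8, 16)                # 8 tokens with d_model = 16
148 | #   idx, score = gate(x)
149 | #   # idx:   shape [8 * 2], flattened top-k expert indices per token
150 | #   # score: shape [8, 1, 2], softmax weights over the selected experts
151 | #   # the cv_squared balance loss is stored on the gate via set_loss(...)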
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import tqdm
6 |
7 | from custom_gates import *
8 |
9 |
10 | def _train_step(model, load_balance, X, Y, h_cache, eval_only, loss_div=1):
11 | """Single training step."""
12 |
13 | out, h_cache = model(X, h_cache)
14 | out = out.view(-1, out.size(-1))
15 | loss = torch.nn.functional.nll_loss(out, Y.view(-1))
16 | loss_value = loss.item() / loss_div
17 |
18 | if not eval_only:
19 | # loss term from adaptive-span
20 | if model.module.layers[0].attn.attn.adapt_span_enabled:
21 | loss += sum(
22 | model.module.layers[layer_i].attn.attn.adaptive_span.get_loss()
23 | for layer_i in range(model.module.attn_layer_count)
24 | )
25 |
26 | if load_balance > 0:
27 | balance_loss = 0
28 | for name, m in model.named_modules():
29 | if isinstance(m, CustomNaiveGate_Balance_SMoE) or isinstance(
30 | m, CustomNaiveGate_Balance_XMoE
31 | ):
32 | if m.loss is not None:
33 | balance_loss += m.loss
34 | loss += load_balance * balance_loss
35 | (loss / loss_div).backward(retain_graph=True)
36 | return loss_value, h_cache
37 |
38 |
39 | def _train_batch(
40 | model, load_balance, optimizer, scheduler, X, Y, h_cache, eval_only, batch_split
41 | ):
42 | """Train on a batch."""
43 |
44 | optimizer.zero_grad()
45 |
46 | if batch_split == 1:
47 | # process a batch in a single step (default behaviour)
48 | loss_value, h_cache = _train_step(model, load_balance, X, Y, h_cache, eval_only)
49 | else:
50 | # split a batch into multiple pieces that each can fit in memory
51 | assert X.size(0) % batch_split == 0
52 | split_size = X.size(0) // batch_split
53 | loss_value = 0
54 | h_cache_list = []
55 | for split_ind in range(batch_split):
56 | split_slice = slice(split_ind * split_size, (split_ind + 1) * split_size)
57 | split_h_cache = [h[split_slice, :, :] for h in h_cache]
58 | split_loss_value, split_h_cache = _train_step(
59 | model,
60 | load_balance,
61 | X[split_slice, :],
62 | Y[split_slice],
63 | split_h_cache,
64 | eval_only,
65 | batch_split,
66 | )
67 | loss_value += split_loss_value
68 | h_cache_list.append(split_h_cache)
69 | h_cache = [
70 | torch.cat([h_cache_list[i][l] for i in range(batch_split)], dim=0)
71 | for l in range(len(h_cache))
72 | ]
73 | if not eval_only:
74 | if scheduler is not None:
75 | scheduler.step()
76 | optimizer.step()
77 |
78 | # make sure span parameters are in a correct range
79 | if model.module.layers[0].attn.attn.adapt_span_enabled:
80 | for layer in model.module.layers:
81 | if layer.use_attn:
82 | layer.attn.attn.adaptive_span.clamp_param()
83 | return loss_value, h_cache
84 |
85 |
86 | def train_iteration(
87 | model,
88 | load_balance,
89 | optimizer,
90 | scheduler,
91 | data,
92 | nb_batches_per_iter,
93 | block_size,
94 | eval_only,
95 | train_pos,
96 | h_cache,
97 | batch_split,
98 | checkpoint_path,
99 | ):
100 | """Single training iteration."""
101 | if eval_only:
102 | model.eval()
103 | else:
104 | model.train()
105 |
106 | nb_batches_per_iter_max = nb_batches_per_iter
107 | if eval_only:
108 | # eval on fewer batches during training for speed-up
109 | nb_batches_per_iter_max = max(1, nb_batches_per_iter // 10)
110 | nb_batches_per_iter_max = min(
111 | nb_batches_per_iter_max, math.ceil(data.size(1) / block_size)
112 | )
113 |
114 | loss_all = 0
115 | actual_nb_batches_per_iter = 0
116 | for _ in tqdm.tqdm(range(nb_batches_per_iter_max)):
117 | actual_nb_batches_per_iter += 1
118 | X = data[:, train_pos : train_pos + block_size].contiguous()
119 | Y = data[:, train_pos + 1 : train_pos + block_size + 1].contiguous()
120 |
121 | loss, h_cache = _train_batch(
122 | model=model,
123 | load_balance=load_balance,
124 | optimizer=optimizer,
125 | scheduler=scheduler,
126 | X=X,
127 | Y=Y,
128 | h_cache=h_cache,
129 | eval_only=eval_only,
130 | batch_split=batch_split,
131 | )
132 | loss_all += loss
133 | train_pos += block_size
134 | if train_pos >= data.size(1) - block_size:
135 | # reached the end. randomize the offset to reduce overfitting
136 | train_pos = random.randrange(block_size)
137 | # reset the cache
138 | for h in h_cache:
139 | h.fill_(0)
140 |
141 | loss_all = loss_all / actual_nb_batches_per_iter
142 | return loss_all, train_pos, h_cache
143 |
144 |
145 | # do full evaluation
146 | def full_eval(model, optimizer, scheduler, data, block_size, hidden_size):
147 | model.eval()
148 | train_pos = 0
149 | nb_batches_per_iter_max = math.ceil(data.size(1) / block_size)
150 | h_cache = [
151 | torch.zeros(
152 | data.size(0),
153 | model.module.layers[layer_i].attn.attn.get_cache_size(),
154 | hidden_size,
155 | ).to(data.device)
156 | for layer_i in range(model.module.attn_layer_count)
157 | ]
158 |
159 | loss_all = 0
160 | actual_nb_batches_per_iter = 0
161 | for _ in tqdm.tqdm(range(nb_batches_per_iter_max)):
162 | actual_nb_batches_per_iter += 1
163 | X = data[:, train_pos : train_pos + block_size].contiguous()
164 | Y = data[:, train_pos + 1 : train_pos + block_size + 1].contiguous()
165 |
166 | loss, h_cache = _train_batch(
167 | model=model,
168 | load_balance=0,
169 | optimizer=optimizer,
170 | scheduler=scheduler,
171 | X=X,
172 | Y=Y,
173 | h_cache=h_cache,
174 | eval_only=True,
175 | batch_split=1,
176 | )
177 | loss_all += loss
178 | train_pos += block_size
179 | if train_pos >= data.size(1) - block_size:
180 |             # Skip the remaining tokens as they can't make a whole block.
181 |             # The effect on performance should be negligible for large datasets.
182 | break
183 |
184 | loss_all = loss_all / actual_nb_batches_per_iter
185 | return loss_all
186 |
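187 | # --- Added note (illustrative) ---
188 | # With batch_split = k, each of the k sub-batches calls backward() on
189 | # loss / k, so the accumulated gradients match a single full-batch step
190 | # while only 1/k of the batch is resident in memory at a time.  E.g. a
191 | # batch of 32 with batch_split = 4 runs four backward passes on slices
192 | # of 8 samples each before the single optimizer.step().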
--------------------------------------------------------------------------------
/custom_functions.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import fmoe_cuda
6 | from torch.autograd import Function
7 | from custom_utils import get_torch_default_comm
8 |
9 | _moe_group = None
10 |
11 |
12 | def ensure_comm(t, comm):
13 | if comm is None:
14 | comm = get_torch_default_comm()
15 | global _moe_group
16 | _moe_group = comm
17 | fmoe_cuda.ensure_nccl(comm, t)
18 |
19 |
20 | def get_moe_group():
21 | return _moe_group
22 |
23 |
24 | def count_by_gate(gate, num_expert, world_size, require_pos=True):
25 | with torch.no_grad():
26 | local_expert_count = torch.zeros(
27 | num_expert * world_size, device=gate.device, dtype=torch.int32
28 | )
29 | fmoe_cuda.expert_count(gate, local_expert_count)
30 | local_expert_count = local_expert_count.long()
31 |
32 | if world_size > 1:
33 | global_expert_count = fmoe_cuda.expert_exchange(
34 | local_expert_count, num_expert, world_size
35 | )
36 | else:
37 | global_expert_count = local_expert_count
38 | if not require_pos:
39 | pos = None
40 | else:
41 | lec_cum = torch.cumsum(local_expert_count, dim=0).int()
42 | pos_size = lec_cum[-1].item()
43 | pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
44 | fmoe_cuda.assign_pos(lec_cum, gate, pos)
45 | return pos, local_expert_count, global_expert_count
46 |
47 |
48 | def prepare_forward(gate, num_expert, world_size):
49 | r"""
50 | Prepare necessary information from gate output for MoE computation.
51 |
52 | Args:
53 | gate: a 1-d Long Tensor representing the target expert of each input
54 | sample.
55 | num_expert: number of experts on each worker.
56 | world_size: number of workers that hold different experts.
57 |     The NCCL communicator used is the one registered beforehand via ensure_comm.
58 | """
59 | pos, local_expert_count, global_expert_count = count_by_gate(
60 | gate, num_expert, world_size
61 | )
62 | with torch.no_grad():
63 | fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
64 | fwd_batch_size = int(fwd_expert_count.sum().item())
65 | return (
66 | pos,
67 | local_expert_count.cpu(),
68 | global_expert_count.cpu(),
69 | fwd_expert_count.cpu(),
70 | fwd_batch_size,
71 | )
72 |
73 |
74 | def _local_scatter(inp, pos):
75 | inp_buf = torch.index_select(inp, 0, pos)
76 | return inp_buf
77 |
78 |
79 | def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
80 | inp_buf = torch.zeros(
81 | out_batch_size, inp.shape[-1], dtype=inp.dtype, device=inp.device
82 | )
83 | if maybe_overlap:
84 | inp_buf.index_add_(0, pos, inp)
85 | else:
86 | inp_buf.index_copy_(0, pos, inp)
87 | return inp_buf
88 |
89 |
90 | class MOEScatter(Function):
91 | r"""
92 |     Scatter input samples from [batch x sequences] to be contiguous along experts.
93 | If `world_size` is greater than 1, the samples will first be locally
94 | scattered, and then exchanged across workers.
95 | """
96 |
97 | @staticmethod
98 | def forward(
99 | ctx,
100 | inp,
101 | pos,
102 | local_expert_count,
103 | global_expert_count,
104 | fwd_batch_size,
105 | world_size,
106 | ):
107 | local_input_buf = _local_scatter(inp, pos)
108 | if world_size > 1:
109 | global_input_buf = fmoe_cuda.global_scatter(
110 | local_input_buf,
111 | local_expert_count,
112 | global_expert_count,
113 | fwd_batch_size,
114 | world_size,
115 | )
116 | else:
117 | global_input_buf = local_input_buf
118 | ctx.moe_args = inp.shape[0], pos.shape[0], world_size
119 | variables = (pos, local_expert_count, global_expert_count)
120 | ctx.save_for_backward(*variables)
121 | return global_input_buf
122 |
123 | @staticmethod
124 | def backward(ctx, global_grad_in):
125 | (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
126 | (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args
127 |
128 | if world_size > 1:
129 | local_grad_in = fmoe_cuda.global_gather(
130 | global_grad_in,
131 | local_expert_count,
132 | global_expert_count,
133 | buf_batch_size,
134 | world_size,
135 | )
136 | else:
137 | local_grad_in = global_grad_in
138 | grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
139 | return grad_in, None, None, None, None, None
140 |
141 |
142 | class MOEGather(Function):
143 | r"""
144 |     Gather output samples that are contiguous along experts back to [batch x
145 |     sequences]. Works symmetrically with MOEScatter.
146 | """
147 |
148 | @staticmethod
149 | def forward(
150 | ctx,
151 | global_output_buf,
152 | pos,
153 | local_expert_count,
154 | global_expert_count,
155 | local_batch_size,
156 | world_size,
157 | ):
158 | if world_size > 1:
159 | local_output_buf = fmoe_cuda.global_gather(
160 | global_output_buf,
161 | local_expert_count,
162 | global_expert_count,
163 | pos.shape[0],
164 | world_size,
165 | )
166 | else:
167 | local_output_buf = global_output_buf
168 | output = _local_gather(
169 | local_output_buf, pos, local_batch_size, maybe_overlap=False
170 | )
171 |
172 | ctx.moe_args = (global_output_buf.shape[0], world_size)
173 | variables = (pos, local_expert_count, global_expert_count)
174 | ctx.save_for_backward(*variables)
175 | return output
176 |
177 | @staticmethod
178 | def backward(ctx, grad_out):
179 | pos, local_expert_count, global_expert_count = ctx.saved_tensors
180 | fwd_batch_size, world_size = ctx.moe_args
181 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
182 | if world_size > 1:
183 | global_grad_out_buf = fmoe_cuda.global_scatter(
184 | grad_out_buf,
185 | local_expert_count,
186 | global_expert_count,
187 | fwd_batch_size,
188 | world_size,
189 | )
190 | else:
191 | global_grad_out_buf = grad_out_buf
192 | return global_grad_out_buf, None, None, None, None, None
193 |
194 |
195 | class AllGather(Function):
196 | r"""
197 | A wrapper for the All-Gather function to support auto-differentiation.
198 | """
199 |
200 | @staticmethod
201 | def forward(ctx, inp, rank, world_size, group):
202 | tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
203 | torch.distributed.all_gather(tensor_list, inp, group=group)
204 | torch.cuda.synchronize()
205 | output = torch.cat(tensor_list, dim=0)
206 | ctx.args = rank, inp.shape[0]
207 | return output
208 |
209 | @staticmethod
210 | def backward(ctx, grad_out):
211 | rank, dim0 = ctx.args
212 | return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None
213 |
214 |
215 | class Slice(Function):
216 | r"""
217 | A wrapper for the Slice function to support auto-differentiation.
218 | """
219 |
220 | @staticmethod
221 | def forward(ctx, inp, rank, world_size, group):
222 | B: int = inp.shape[0]
223 | local_batch_size = B // world_size
224 | batch_start = local_batch_size * rank
225 | batch_end = min(batch_start + local_batch_size, B)
226 | inp = inp[batch_start:batch_end]
227 | ctx.args = world_size, group
228 | return inp
229 |
230 | @staticmethod
231 | def backward(ctx, grad_out):
232 | world_size, group = ctx.args
233 | tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
234 | torch.distributed.all_gather(tensor_list, grad_out, group=group)
235 | torch.cuda.synchronize()
236 | grad_out = torch.cat(tensor_list, dim=0)
237 | return grad_out, None, None, None
238 |
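239 | # --- Added sketch (illustrative; single-process case, no fmoe_cuda calls) ---
240 | # _local_scatter / _local_gather are inverse permutations when the positions
241 | # are unique.  A minimal worked example:
242 | #
243 | #   inp = torch.arange(6.).view(3, 2)          # rows 0, 1, 2
244 | #   pos = torch.tensor([2, 0, 1])
245 | #   buf = _local_scatter(inp, pos)             # rows reordered to 2, 0, 1
246 | #   out = _local_gather(buf, pos, 3, maybe_overlap=False)
247 | #   assert torch.equal(out, inp)               # round trip restores the order
248 | #
249 | # With maybe_overlap=True, index_add_ is used instead, so repeated positions
250 | # (one token routed to several experts) accumulate rather than overwrite.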
--------------------------------------------------------------------------------
/custom_gates.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import pdb
9 | import numpy as np
10 | from fmoe.gates.base_gate import BaseGate
11 |
12 | __all__ = [
13 | "CustomNaiveGate_Balance_SMoE",
14 | "CustomNaiveGate_Balance_XMoE",
15 | "CustomNaiveGate_Balance_StableMoE",
16 | ]
17 |
18 |
19 | class CustomNaiveGate_Balance_SMoE(BaseGate):
20 | def __init__(self, d_model, num_expert, world_size, top_k=2, g_blance=False):
21 | super().__init__(num_expert, world_size)
22 | self.gate = nn.Linear(d_model, self.tot_expert)
23 | self.top_k = top_k
24 | self.dense_moe_flag = False
25 | self.g_blance = g_blance
26 | self.loss = None
27 |
28 | def set_load_balance(self, gate, gate_top_k_idx):
29 |
30 | score = F.softmax(gate, dim=-1)
31 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1]
32 | fraction_expert = (
33 | torch.scatter_add(
34 | torch.zeros(self.tot_expert, device=valid_idx.device),
35 | 0,
36 | valid_idx,
37 | torch.ones_like(valid_idx, dtype=torch.float),
38 | )
39 | / valid_idx.numel()
40 | )
41 | prob_expert = score.sum(dim=0) / valid_idx.numel()
42 |
43 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert
44 | self.loss = loss
45 |
46 | def forward(self, inp, return_all_scores=False):
47 |
48 | gate = self.gate(inp)
49 |
50 | if self.dense_moe_flag:
51 | gate = torch.ones_like(gate) # average the importance of all experts
52 | gate_top_k_val, gate_top_k_idx = torch.topk(
53 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False
54 | )
55 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert)
56 | else:
57 | gate_top_k_val, gate_top_k_idx = torch.topk(
58 | gate, k=self.top_k, dim=-1, largest=True, sorted=False
59 | ) # [.. x top_k]
60 |             gate_top_k_val = gate_top_k_val.view(-1, self.top_k)  # (B x L) x top_k
61 |
62 | gate_score = F.softmax(gate_top_k_val, dim=-1)
63 | if self.g_blance:
64 | self.set_load_balance(gate, gate_top_k_idx)
65 |
66 | if return_all_scores:
67 | return gate_top_k_idx, gate_score, gate
68 | return gate_top_k_idx, gate_score
69 |
70 |
71 | class CustomNaiveGate_Balance_XMoE(BaseGate):
72 | def __init__(self, d_model, num_expert, world_size, top_k=2, g_balance=False):
73 | super().__init__(num_expert, world_size)
74 | self.gate = nn.Linear(d_model, self.tot_expert)
75 | self.top_k = top_k
76 | self.dense_moe_flag = False
77 | self.g_balance = g_balance
78 | self.loss = 0.0
79 |
80 | expert_embeddings = torch.empty(num_expert, 8)
81 | torch.nn.init.orthogonal_(expert_embeddings, gain=0.32)
82 | self.register_parameter(
83 | "expert_embeddings", torch.nn.Parameter(expert_embeddings)
84 | )
85 |
86 | self.inp_reduction = torch.nn.Linear(d_model, 8, bias=False)
87 |
88 | def set_load_balance(self, gate, gate_top_k_idx):
89 | # gate_top_k_idx (tokens_number, top-k)
90 | # gate_top_k_val (tokens_number, top-k)
91 |
92 | score = F.softmax(gate / 0.3, dim=-1)
93 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1]
94 | fraction_expert = (
95 | torch.scatter_add(
96 | torch.zeros(self.tot_expert, device=valid_idx.device),
97 | 0,
98 | valid_idx,
99 | torch.ones_like(valid_idx, dtype=torch.float),
100 | )
101 | / valid_idx.numel()
102 | )
103 | prob_expert = score.sum(dim=0) / valid_idx.numel()
104 |
105 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert
106 | self.loss = loss
107 |
108 | def forward(self, inp, return_all_scores=False):
109 |
110 | reduced_inp = self.inp_reduction(inp)
111 | with torch.no_grad():
112 | expert_embeddings_norm = self.expert_embeddings.norm(
113 | p=2.0, dim=1, keepdim=True
114 | )
115 | self.expert_embeddings.mul_(1.5 / expert_embeddings_norm)
116 |
117 | gate = self._cosine(reduced_inp, self.expert_embeddings)
118 | gate = self._make_finite(gate)
119 |
120 | if self.dense_moe_flag:
121 | gate = torch.ones_like(gate) # average the importance of all experts
122 | gate_top_k_val, gate_top_k_idx = torch.topk(
123 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False
124 | )
125 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert)
126 | else:
127 | gate_top_k_val, gate_top_k_idx = torch.topk(
128 | gate, k=self.top_k, dim=-1, largest=True, sorted=False
129 | ) # [.. x top_k]
130 |             gate_top_k_val = gate_top_k_val.view(-1, self.top_k)  # (B x L) x top_k
131 |
132 | gate_score = F.softmax(gate_top_k_val, dim=-1)
133 | if self.g_balance:
134 | self.set_load_balance(gate, gate_top_k_idx)
135 |
136 | if return_all_scores:
137 | return gate_top_k_idx, gate_score, gate
138 | return gate_top_k_idx, gate_score
139 |
140 | def _cosine(self, mat1, mat2, eps=1e-4):
141 | assert mat1.dim() == 2
142 | assert mat2.dim() == 2
143 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
144 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
145 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
146 |
147 | def _make_finite(self, scores):
148 | ok = scores.isfinite()
149 | if not ok.all():
150 | # NaNs here can break the assignment algorithm
151 | scores[~ok] = scores[ok].min()
152 | return scores
153 |
154 |
155 | class CustomNaiveGate_Balance_StableMoE(BaseGate):
156 | r"""
157 | Naive Gate StableMoE
158 | """
159 |
160 | def __init__(self, d_model, num_expert, world_size, top_k=2, g_balance=False):
161 | super().__init__(num_expert, world_size)
162 | self.top_k = top_k
163 | self.dense_moe_flag = False
164 | self.g_balance = g_balance
165 | self.loss = 0.0
166 |
167 | expert_embeddings = torch.empty(num_expert, d_model)
168 | torch.nn.init.orthogonal_(expert_embeddings, gain=0.32)
169 | self.register_parameter(
170 | "expert_embeddings", torch.nn.Parameter(expert_embeddings)
171 | )
172 |
173 | def set_load_balance(self, gate, gate_top_k_idx):
174 |
175 | score = F.softmax(gate / 0.3, dim=-1)
176 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1]
177 | fraction_expert = (
178 | torch.scatter_add(
179 | torch.zeros(self.tot_expert, device=valid_idx.device),
180 | 0,
181 | valid_idx,
182 | torch.ones_like(valid_idx, dtype=torch.float),
183 | )
184 | / valid_idx.numel()
185 | )
186 | prob_expert = score.sum(dim=0) / valid_idx.numel()
187 |
188 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert
189 | self.loss = loss
190 |
191 | def forward(self, inp, return_all_scores=False):
192 |
193 | gate = self._cosine(inp, self.expert_embeddings)
194 | gate = self._make_finite(gate)
195 |
196 | if self.dense_moe_flag:
197 | gate = torch.ones_like(gate) # average the importance of all experts
198 | gate_top_k_val, gate_top_k_idx = torch.topk(
199 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False
200 | )
201 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert)
202 | else:
203 | gate_top_k_val, gate_top_k_idx = torch.topk(
204 | gate, k=self.top_k, dim=-1, largest=True, sorted=False
205 | ) # [.. x top_k]
206 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
207 |             # (B x L) x top_k
208 |
209 | gate_score = F.softmax(gate_top_k_val, dim=-1)
210 | if self.g_balance:
211 | self.set_load_balance(gate, gate_top_k_idx)
212 |
213 | if return_all_scores:
214 | return gate_top_k_idx, gate_score, gate
215 | return gate_top_k_idx, gate_score
216 |
217 | def _cosine(self, mat1, mat2, eps=1e-4):
218 | assert mat1.dim() == 2
219 | assert mat2.dim() == 2
220 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
221 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
222 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
223 |
224 | def _make_finite(self, scores):
225 | ok = scores.isfinite()
226 | if not ok.all():
227 | # NaNs here can break the assignment algorithm
228 | scores[~ok] = scores[ok].min()
229 | return scores
230 |
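231 | # --- Added usage sketch (illustrative; not part of the original module) ---
232 | # Single-worker example for the SMoE gate, assuming fmoe's BaseGate sets
233 | # tot_expert = num_expert * world_size:
234 | #
235 | #   gate = CustomNaiveGate_Balance_SMoE(d_model=16, num_expert=4,
236 | #                                       world_size=1, top_k=2, g_blance=True)
237 | #   idx, score = gate(torch.randn(8, 16))
238 | #   # idx: [8, 2] expert indices, score: [8, 2] softmax routing weights
239 | #
240 | # gate.loss then holds the auxiliary balance term
241 | #   tot_expert * sum_e fraction_expert[e] * prob_expert[e]   (see set_load_balance),
242 | # the standard load-balancing loss that penalizes routing distributions which
243 | # concentrate tokens on a few experts.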
--------------------------------------------------------------------------------
/finetune_trainer.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import tqdm
6 | import pdb
7 | import torch.nn as nn
8 | from custom_gates import *
9 |
10 |
11 | def _train_step(model, load_balance, X, Y, h_cache, eval_only, loss_div=1):
12 | """Single training step."""
13 | acc_num, acc_value = 0, 0.0
14 | out, h_cache = model(X, h_cache)
15 | # print(model.module.layers[0].attn.proj_key.weight)
16 | out = out.view(-1, out.size(-1))
17 | # loss = torch.nn.functional.nll_loss(out, Y.view(-1))
18 | criterion = nn.CrossEntropyLoss()
19 | # pdb.set_trace()
20 | # print(out, Y)
21 | loss = criterion(out, Y)
22 | loss_value = loss.item() / loss_div
23 | # loss = loss.float()
24 | # loss = loss.item() / loss_div
25 | acc_value += (out.argmax(-1) == Y).sum().item()
26 | acc_num += Y.shape[0]
27 |
28 | if not eval_only:
29 | # loss term from adaptive-span
30 | if model.module.layers[0].attn.attn.adapt_span_enabled:
31 | loss += sum(
32 | model.module.layers[layer_i].attn.attn.adaptive_span.get_loss()
33 | for layer_i in range(model.module.attn_layer_count)
34 | )
35 |
36 | if load_balance > 0:
37 | balance_loss = 0
38 | for name, m in model.named_modules():
39 | if isinstance(m, CustomNaiveGate_Balance_SMoE) or isinstance(
40 | m, CustomNaiveGate_Balance_XMoE
41 | ):
42 | balance_loss += m.loss
43 | loss += load_balance * balance_loss
44 |
45 | (loss / loss_div).backward()
46 | # print(torch.norm(model.module.layers[0].attn.proj_key.weight.grad))
47 | return loss_value, acc_num, acc_value, h_cache
48 |
49 |
50 | def _train_batch(
51 | model, load_balance, optimizer, scheduler, X, Y, h_cache, eval_only, batch_split
52 | ):
53 | """Train on a batch."""
54 |
55 | optimizer.zero_grad()
56 | total_len, total_acc = 0, 0.0
57 | if batch_split == 1:
58 | # process a batch in a single step (default behaviour)
59 | loss_value, total_len, total_acc, h_cache = _train_step(
60 | model, load_balance, X, Y, h_cache, eval_only
61 | )
62 | else:
63 | # split a batch into multiple pieces that each can fit in memory
64 | assert X.size(0) % batch_split == 0
65 | split_size = X.size(0) // batch_split
66 | loss_value = 0
67 | h_cache_list = []
68 | for split_ind in range(batch_split):
69 | split_slice = slice(split_ind * split_size, (split_ind + 1) * split_size)
70 | split_h_cache = [h[split_slice, :, :] for h in h_cache]
71 | tmp_len, tmp_acc = 0, 0.0
72 | # pdb.set_trace()
73 | split_loss_value, tmp_len, tmp_acc, split_h_cache = _train_step(
74 | model=model,
75 | load_balance=load_balance,
76 | X=X[split_slice, :],
77 | Y=Y[split_slice],
78 | h_cache=split_h_cache,
79 | eval_only=eval_only,
80 | loss_div=batch_split,
81 | )
82 | loss_value += split_loss_value
83 | total_len += tmp_len
84 | total_acc += tmp_acc
85 | h_cache_list.append(split_h_cache)
86 | h_cache = [
87 | torch.cat([h_cache_list[i][l] for i in range(batch_split)], dim=0)
88 | for l in range(len(h_cache))
89 | ]
90 | if not eval_only:
91 | if scheduler is not None:
92 | scheduler.step()
93 | optimizer.step()
94 |
95 | # make sure span parameters are in a correct range
96 | if model.module.layers[0].attn.attn.adapt_span_enabled:
97 | for layer in model.module.layers:
98 | if layer.use_attn:
99 | layer.attn.attn.adaptive_span.clamp_param()
100 | return loss_value, total_len, total_acc, h_cache
101 |
102 |
103 | def train_iteration(
104 | model,
105 | load_balance,
106 | optimizer,
107 | scheduler,
108 | data,
109 | nb_batches_per_iter,
110 | block_size,
111 | eval_only,
112 | train_pos,
113 | h_cache,
114 | batch_split,
115 | checkpoint_path,
116 | ):
117 | """Single training iteration."""
118 | if eval_only:
119 | model.eval()
120 | else:
121 | model.train()
122 |
123 | # nb_batches_per_iter_max = nb_batches_per_iter
124 | # if eval_only:
125 | # # eval on fewer batches during training for speed-up
126 | # nb_batches_per_iter_max = max(1, nb_batches_per_iter // 10)
127 | # for _temp, _, _ in data:
128 | # ch_block = _temp.size(1) # data.size(1)
129 | # break
130 | # nb_batches_per_iter_max = min(
131 | # nb_batches_per_iter_max, math.ceil(ch_block / block_size)
132 | # )
133 | nb_batches_per_iter_max = nb_batches_per_iter
134 | if eval_only:
135 | # eval on fewer batches during training for speed-up
136 | nb_batches_per_iter_max = max(1, nb_batches_per_iter // 10)
137 | nb_batches_per_iter_max = min(
138 | nb_batches_per_iter_max, math.ceil(data.n_step / block_size)
139 | )
140 |
141 | loss_all = 0
142 | # all accuracy and number
143 | total_len, total_acc = 0, 0.0
144 | actual_nb_batches_per_iter = 0
145 | for _data, _att_mask, _target in tqdm.tqdm(data):
146 | actual_nb_batches_per_iter += 1
147 | _data = _data.cuda()
148 | _target = _target.cuda()
149 | X = _data.permute(
150 | 1, 0
151 | ).contiguous() # data[:, train_pos: train_pos + block_size].contiguous()
152 | Y = _target # .permute(1,0) #.contiguous() #data[:, train_pos + 1: train_pos + block_size + 1].contiguous()
153 | # print(Y)
154 |
155 | loss, tmp_len, tmp_acc, h_cache = _train_batch(
156 | model=model,
157 | load_balance=load_balance,
158 | optimizer=optimizer,
159 | scheduler=scheduler,
160 | X=X,
161 | Y=Y,
162 | h_cache=h_cache,
163 | eval_only=eval_only,
164 | batch_split=batch_split,
165 | )
166 | # print(tmp_acc)
167 | loss_all += loss
168 | total_len += tmp_len
169 | total_acc += tmp_acc
170 | train_pos += block_size
171 | if train_pos >= _data.size(1) - block_size:
172 | # reached the end. randomize the offset to reduce overfitting
173 | train_pos = random.randrange(block_size)
174 | # reset the cache
175 | for h in h_cache:
176 | h.fill_(0)
177 |
178 | loss_all = loss_all / actual_nb_batches_per_iter
179 | acc_all = 100 * total_acc / total_len
180 | return loss_all, acc_all, train_pos, h_cache
181 |
182 |
183 | # do full evaluation
184 | def full_eval(model, optimizer, scheduler, data, block_size, hidden_size):
185 | model.eval()
186 | train_pos = 0
187 | # nb_batches_per_iter_max = math.ceil(data.encoded.size(1) / block_size)
188 | h_cache = [
189 | torch.zeros(
190 | data.bsz,
191 | model.module.layers[layer_i].attn.attn.get_cache_size(),
192 | hidden_size,
193 | ).cuda()
194 | for layer_i in range(model.module.attn_layer_count)
195 | ]
196 |
197 | loss_all = 0
198 | actual_nb_batches_per_iter = 0
199 | total_len, total_acc = 0, 0.0
200 |
201 | for _data, _att_mask, _target in tqdm.tqdm(data):
202 | actual_nb_batches_per_iter += 1
203 | _data = _data.cuda()
204 | _target = _target.cuda()
205 | X = _data.permute(
206 | 1, 0
207 | ).contiguous() # data[:, train_pos: train_pos + block_size].contiguous()
208 | Y = _target # .permute(1,0) #.contiguous() #data[:, train_pos + 1: train_pos + block_size + 1].contiguous()
209 | # pdb.set_trace()
210 | # for _ in tqdm.tqdm(range(nb_batches_per_iter_max)):
211 | # actual_nb_batches_per_iter += 1
212 | # X = data[:, train_pos: train_pos + block_size].contiguous()
213 | # Y = data[:, train_pos + 1: train_pos + block_size + 1].contiguous()
214 |
215 | loss, tmp_len, tmp_acc, h_cache = _train_batch(
216 | model=model,
217 | load_balance=0,
218 | optimizer=optimizer,
219 | scheduler=scheduler,
220 | X=X,
221 | Y=Y,
222 | h_cache=h_cache,
223 | eval_only=True,
224 | batch_split=1,
225 | )
226 | loss_all += loss
227 | total_len += tmp_len
228 | total_acc += tmp_acc
229 | train_pos += block_size
230 | if train_pos >= _data.size(1) - block_size:
231 |             # Skip the remaining tokens as they can't make a whole block.
232 |             # The effect on performance should be negligible for large datasets.
233 | break
234 |
235 | loss_all = loss_all / actual_nb_batches_per_iter
236 | acc_all = 100 * total_acc / total_len
237 | return loss_all, acc_all
238 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import functools
5 | import os, shutil
6 | import torch
7 | import tqdm
8 | from models import CustomizedMoEPositionwiseFFOpt
9 |
10 |
11 | def logging(s, log_path, print_=True, log_=True):
12 | if print_:
13 | print(s)
14 | if log_:
15 | with open(log_path, "a+") as f_log:
16 | f_log.write(s + "\n")
17 |
18 |
19 | def get_logger(log_path, **kwargs):
20 | return functools.partial(logging, log_path=log_path, **kwargs)
21 |
22 |
23 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
24 | if debug:
25 | print("Debug Mode : no experiment dir created")
26 | return functools.partial(logging, log_path=None, log_=False)
27 |
28 | if not os.path.exists(dir_path):
29 | os.makedirs(dir_path)
30 |
31 | print("Experiment dir : {}".format(dir_path))
32 | if scripts_to_save is not None:
33 | script_path = os.path.join(dir_path, "scripts")
34 | if not os.path.exists(script_path):
35 | os.makedirs(script_path)
36 | for script in scripts_to_save:
37 | dst_file = os.path.join(dir_path, "scripts", os.path.basename(script))
38 | shutil.copyfile(script, dst_file)
39 |
40 | return get_logger(log_path=os.path.join(dir_path, "log.txt"))
41 |
42 |
43 | def freeze_gate_weight(model):
44 | print("* Freeze Router")
45 | for name, p in model.named_parameters():
46 | if "gate.gate" in name:
47 | print("Freeze: ", name)
48 | p.requires_grad = False
49 |
50 |
51 | def set_freq_optimal_search(model, threshold):
52 |     print("* Set Freq Optimal Search:")
53 | for name, m in model.named_modules():
54 | if isinstance(m, CustomizedMoEPositionwiseFFOpt):
55 | if random.random() > (1 - threshold):
56 | print(f"* Set Freq of {name} to 1.0")
57 | m.freq = 1.0
58 | else:
59 | print(f"* Set Freq of {name} to 0.0")
60 | m.freq = 0.0
61 |
62 |
63 | def _parse_args(params_config, args):
64 | parser = argparse.ArgumentParser()
65 | for params_category in params_config: # e.g., 'model_params'
66 | for param_flag, param_config in params_config[params_category].items():
67 | # e.g., param_flag = '--block-sz'
68 | parser.add_argument(param_flag, **param_config)
69 | return parser.parse_args(args)
70 |
71 |
72 | def get_params(params_config, args=None):
73 | namespace = _parse_args(params_config, args)
74 | return {
75 | params_category: {
76 | param_config["dest"]: namespace.__getattribute__(param_config["dest"])
77 | for param_config in params_config[params_category].values()
78 | }
79 | for params_category in params_config
80 | }
81 |
82 |
83 | ##############################################################################
84 | # ENVIRONMENT
85 | ##############################################################################
86 |
87 |
88 | def _torch_distributed_init_process_group(local_rank):
89 | torch.distributed.init_process_group(backend="nccl", init_method="env://")
90 | rank = torch.distributed.get_rank()
91 | world_size = torch.distributed.get_world_size()
92 | print("my rank={} local_rank={}".format(rank, local_rank))
93 | torch.cuda.set_device(local_rank)
94 | return {
95 | "rank": rank,
96 | "world_size": world_size,
97 | }
98 |
99 |
100 | def set_up_env(env_params):
101 | assert torch.cuda.is_available()
102 | if env_params["distributed"]:
103 | env_params.update(
104 | _torch_distributed_init_process_group(local_rank=env_params["local_rank"])
105 | )
106 | env_params["device"] = torch.device("cuda")
107 |
108 |
109 | ##############################################################################
110 | # OPTIMIZER AND SCHEDULER
111 | ##############################################################################
112 |
113 |
114 | def _get_grad_requiring_params(model):
115 | nb_parameters = 0
116 | grad_requiring_params = []
117 | for param in model.parameters():
118 | if param.requires_grad:
119 | nb_parameters += param.numel()
120 | grad_requiring_params.append(param)
121 | print("nb_parameters={:.2f}M".format(nb_parameters / 1e6))
122 | return grad_requiring_params
123 |
124 |
125 | def _get_optimizer(model, optim, lr: float, momentum: float, grad_clip: float):
126 | if optim == "sgd":
127 | return torch.optim.SGD(
128 | _get_grad_requiring_params(model), lr=lr, momentum=momentum
129 | )
130 | elif optim == "adam":
131 | return torch.optim.Adam(
132 | _get_grad_requiring_params(model),
133 | lr=lr,
134 | )
135 | else:
136 | raise RuntimeError("wrong type of optimizer - must be 'sgd' or 'adam'")
137 |
138 |
139 | def _get_scheduler(optimizer, lr_warmup):
140 | if lr_warmup > 0:
141 | return torch.optim.lr_scheduler.LambdaLR(
142 | optimizer, lambda ep: min(1, ep / lr_warmup)
143 | )
144 | return None
145 |
146 |
147 | def get_optimizer_and_scheduler(model, optim_params):
148 | optimizer = _get_optimizer(
149 | model=model,
150 | optim=optim_params["optim"],
151 | lr=optim_params["lr"],
152 | momentum=optim_params["momentum"],
153 | grad_clip=optim_params["grad_clip"],
154 | )
155 | scheduler = _get_scheduler(optimizer=optimizer, lr_warmup=optim_params["lr_warmup"])
156 | return optimizer, scheduler
157 |
158 |
159 | ##############################################################################
160 | # CHECKPOINT
161 | ##############################################################################
162 |
163 |
164 | def _load_checkpoint(checkpoint_path, model, optimizer, scheduler, logger, distributed):
165 | print("loading from a checkpoint at {}".format(checkpoint_path))
166 | if distributed:
167 | # the model is saved from gpu0 so we need to map it to CPU first
168 | checkpoint_state = torch.load(
169 | checkpoint_path, map_location=lambda storage, loc: storage
170 | )
171 | else:
172 | checkpoint_state = torch.load(checkpoint_path)
173 | iter_init = checkpoint_state["nb_batches_per_iter"] + 1 # next iteration
174 | # del checkpoint_state["model"]["module.in_emb.weight"]
175 | # del checkpoint_state["model"]["module.out_emb.weight"]
176 | # del checkpoint_state["model"]["module.out_emb.bias"]
177 | model.load_state_dict(checkpoint_state["model"])
178 | optimizer.load_state_dict(checkpoint_state["optimizer"])
179 | if "scheduler_iter" in checkpoint_state:
180 | # we only need the step count
181 | scheduler.step(checkpoint_state["scheduler_iter"])
182 | return iter_init
183 |
184 |
185 | def load_checkpoint(checkpoint_path, model, optimizer, scheduler, logger, distributed, resume):
186 | print(checkpoint_path)
187 | if resume and os.path.exists(checkpoint_path):
188 | return _load_checkpoint(
189 | checkpoint_path=checkpoint_path,
190 | model=model,
191 | optimizer=optimizer,
192 | scheduler=scheduler,
193 | logger=logger,
194 | distributed=distributed,
195 | )
196 | return 0
197 |
198 |
199 | def save_checkpoint(
200 | checkpoint_path, nb_batches_per_iter, model, optimizer, scheduler, logger
201 | ):
202 | if checkpoint_path:
203 | checkpoint_state = {
204 | "nb_batches_per_iter": nb_batches_per_iter, # last completed iteration
205 | "model": model.state_dict(),
206 | "optimizer": optimizer.state_dict(),
207 | }
208 | if scheduler is not None:
209 | checkpoint_state["scheduler_iter"] = scheduler.last_epoch
210 | torch.save(checkpoint_state, checkpoint_path)
211 |
212 |
213 | ##############################################################################
214 | # LOGGER
215 | ##############################################################################
216 |
217 |
218 | class Logger:
219 | def __init__(self):
220 | self._state_dict = dict()
221 |
222 | def load_state_dict(self, state_dict):
223 | self._state_dict = state_dict
224 |
225 | def state_dict(self):
226 | return self._state_dict
227 |
228 | def _log(self, title, value):
229 | if title not in self._state_dict:
230 | self._state_dict[title] = []
231 | self._state_dict[title].append(value)
232 |
233 | def log_iter(
234 | self, iter_no, nb_batches_per_iter, loss_train, loss_val, elapsed, model
235 | ):
236 | step = (iter_no + 1) * nb_batches_per_iter
237 | train_bpc = float(loss_train / math.log(2))
238 | val_bpc = float(loss_val / math.log(2))
239 | msg = "steps: {}".format(step)
240 | msg += "\ttrain: {:.3f}bpc\tval: {:.3f}bpc".format(train_bpc, val_bpc)
241 | msg += "\tms/batch: {:.1f}".format(elapsed)
242 | self._log(title="step", value=step)
243 | self._log(title="train_bpc", value=train_bpc)
244 | self._log(title="val_bpc", value=val_bpc)
245 |
246 | if model.module.layers[0].attn.attn.adapt_span_enabled:
247 | avg_spans = []
248 | max_spans = []
249 | for layer in model.module.layers:
250 | if layer.use_attn:
251 | avg_spans.append(
252 | layer.attn.attn.adaptive_span.get_current_avg_span()
253 | )
254 | max_spans.append(
255 | layer.attn.attn.adaptive_span.get_current_max_span()
256 | )
257 | span_avg = float(sum(avg_spans)) / len(avg_spans)
258 | span_max = float(max(max_spans))
259 | self._log("span_avg", span_avg)
260 | self._log("span_max", span_max)
261 | msg += "\tspan_avg: {:.0f}\tspan_max: {:.0f}".format(span_avg, span_max)
262 |
263 | print(msg)
264 |
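265 | # --- Added sketch (illustrative) ---
266 | # How get_params() turns a nested flag spec into a nested value dict; the
267 | # config below is a made-up example, the real one lives in config.py:
268 | #
269 | #   cfg = {
270 | #       "model_params": {
271 | #           "--hid-sz": {"type": int, "default": 256,
272 | #                        "help": "hidden size", "dest": "hidden_size"},
273 | #       },
274 | #   }
275 | #   get_params(cfg, args=["--hid-sz", "512"])
276 | #   # -> {"model_params": {"hidden_size": 512}}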
--------------------------------------------------------------------------------
/finetune_train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import warnings
3 |
4 | warnings.filterwarnings("ignore")
5 |
6 | import argparse
7 | import math, random
8 | import torch
9 | import time
10 | import pdb
11 | from config import PARAMS_CONFIG
12 |
13 | from finetune_data import get_lm_corpus
14 | from finetune_models import TransformerSeq
15 | from finetune_trainer import train_iteration, full_eval
16 | import datetime
17 | from utils import (
18 | get_params,
19 | set_up_env,
20 | get_optimizer_and_scheduler,
21 | load_checkpoint,
22 | save_checkpoint,
23 | create_exp_dir,
24 | freeze_gate_weight,
25 | Logger,
26 | )
27 |
28 |
29 | def launch(
30 | env_params,
31 | model_params,
32 | adapt_span_params,
33 | optim_params,
34 | data_params,
35 | trainer_params,
36 | ):
37 | # global val
38 | best_val_loss = None
39 | # ENVIRONMENT (device, distributed, etc.)
40 | set_up_env(env_params)
41 | device = env_params["device"]
42 | distributed = env_params["distributed"]
43 |
44 |     if not distributed or env_params["rank"] == 0:
45 | print("data_params:\t", data_params)
46 | print("model_params:\t", model_params)
47 | print("optim_params:\t", optim_params)
48 | print("trainer_params:\t", trainer_params)
49 | print("adapt_span_params:\t", adapt_span_params)
50 |
51 | # DATA
52 | corpus = get_lm_corpus(data_params["data_path"], data_params["data_name"])
53 | ntokens = len(corpus.vocab)
54 |
55 | if data_params["data_name"] in ["sst2", "imdb"]:
56 | num_classes = 2
57 | elif data_params["data_name"] == "sst5":
58 | num_classes = 5
59 | elif data_params["data_name"] == "banking77":
60 | num_classes = 77
61 |
62 | eval_batch_size = 10
63 | train_data = corpus.get_iterator("train", trainer_params["batch_size"])
64 | val_data = corpus.get_iterator("valid", eval_batch_size)
65 | test_data = val_data # corpus.get_iterator('test', eval_batch_size)
66 |
67 | # MODEL data_params['vocab_size']
68 | model = TransformerSeq(
69 | vocab_size=ntokens,
70 | **model_params,
71 | num_classes=num_classes,
72 | adapt_span_params=adapt_span_params,
73 | )
74 | print(model)
75 | if distributed:
76 | local_rank = env_params["local_rank"]
77 | model = model.to(device)
78 | model = torch.nn.parallel.DistributedDataParallel(
79 | model,
80 | device_ids=[local_rank],
81 | output_device=local_rank,
82 | find_unused_parameters=True,
83 | )
84 | else:
85 | model = torch.nn.DataParallel(model)
86 | model = model.to(device)
87 |
88 | # OPTIMIZER AND SCHEDULER
89 | optimizer, scheduler = get_optimizer_and_scheduler(
90 | model=model, optim_params=optim_params
91 | )
92 |
93 | # create logger
94 | logger = Logger()
95 | fold_name = trainer_params["checkpoint_path"].split("/")[-1].split(".")[0]
96 | folder_path = "/".join(trainer_params["checkpoint_path"].split("/")[:-1])
97 | logging = create_exp_dir(f"{folder_path}/experiments/{fold_name}")
98 |     # log parameters
99 | logging(f"Training Parameters:\n {trainer_params}")
100 | # logging time
101 | current_time = datetime.datetime.now()
102 | logging(str(current_time))
103 | # log model
104 | logging(str(model))
105 | logging(f"Total of Parameters: {sum(p.numel() for p in model.parameters())}")
106 | logging(
107 | f"Total of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
108 | )
109 | # load check points
110 |
111 | logging("=" * 100)
112 | logging(
113 | "==== loading pretrained model from {} ====".format(
114 | trainer_params["pretrained_weight"]
115 | )
116 | )
117 | logging("=" * 100)
118 |
119 | # Load the best saved model.
120 | if not trainer_params["full_eval_mode"]:
121 | with open(trainer_params["pretrained_weight"], "rb") as f:
122 | pretrained_model = torch.load(f)
123 | # pdb.set_trace()
124 | pretrained_model_checkpoint = pretrained_model["model"] # .state_dict()
125 | filtered_checkpoint = {}
126 | for key in pretrained_model_checkpoint.keys():
127 | if not key in model.state_dict():
128 | logging("Can not load {}".format(key))
129 | elif (
130 | not pretrained_model_checkpoint[key].shape
131 | == model.state_dict()[key].shape
132 | ):
133 | logging("Can not load {}, shape do not match".format(key))
134 | else:
135 | filtered_checkpoint[key] = pretrained_model_checkpoint[key]
136 |
137 | model.load_state_dict(filtered_checkpoint, strict=False)
138 | iter_init = 0
139 | else:
140 | # resume training from last checkpoint if exists
141 | iter_init = load_checkpoint(
142 | trainer_params["checkpoint_path"],
143 | model,
144 | optimizer,
145 | scheduler,
146 | logger,
147 |             distributed, resume=True,  # utils.load_checkpoint also expects a resume flag
148 | )
149 |
150 | # fix gate
151 | if model_params["smoe_dropout"]:
152 | freeze_gate_weight(model)
153 | # calculate time
154 | start_time = time.time()
155 | # eval model
156 | if trainer_params["full_eval_mode"]:
157 | # evaluate the model on test data
158 | with torch.no_grad():
159 | loss_val, acc_val = full_eval(
160 | model,
161 | optimizer,
162 | scheduler,
163 | val_data,
164 | model_params["block_size"],
165 | model_params["hidden_size"],
166 | )
167 | loss_test, acc_test = full_eval(
168 | model,
169 | optimizer,
170 | scheduler,
171 | test_data,
172 | model_params["block_size"],
173 | model_params["hidden_size"],
174 | )
175 | if distributed:
176 | # collect results into rank0
177 | stats = torch.tensor([loss_val, loss_test]).to(device)
178 | torch.distributed.reduce(stats, 0)
179 | if env_params["rank"] == 0:
180 | loss_val = stats[0] / env_params["world_size"]
181 | loss_test = stats[1] / env_params["world_size"]
182 | else:
183 | return
184 |
185 | # log accuracy score
186 | logging("Val: {:.3f} Acc".format(acc_val))
187 | logging("Test: {:.3f} Acc".format(acc_test))
188 |
189 | return
190 |
191 | # position of current batch
192 | data_pos = [0] * 2
193 | # initialize caches for train and valid
194 | hid_cache = [
195 | [
196 | torch.zeros(
197 | train_data.bsz,
198 | model.module.layers[layer_i].attn.attn.get_cache_size(),
199 | model_params["hidden_size"],
200 | ).to(device)
201 | for layer_i in range(model.module.attn_layer_count)
202 | ]
203 | for _ in range(2)
204 | ]
205 |
206 | nb_batches_per_iter = trainer_params["nb_batches_per_iter"]
207 | for iter_no in range(iter_init, trainer_params["nb_iter"]):
208 | t_sta = time.time()
209 | loss_train, acc_train, data_pos[0], hid_cache[0] = train_iteration(
210 | model,
211 | model_params["load_balance"],
212 | optimizer,
213 | scheduler,
214 | train_data,
215 | nb_batches_per_iter,
216 | model_params["block_size"],
217 | False,
218 | data_pos[0],
219 | hid_cache[0],
220 | trainer_params["batch_split"],
221 | trainer_params["checkpoint_path"],
222 | )
223 | elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter
224 | with torch.no_grad():
225 | loss_val, acc_val, data_pos[1], hid_cache[1] = train_iteration(
226 | model,
227 | model_params["load_balance"],
228 | optimizer,
229 | scheduler,
230 | val_data,
231 | nb_batches_per_iter,
232 | model_params["block_size"],
233 | True,
234 | data_pos[1],
235 | hid_cache[1],
236 | trainer_params["batch_split"],
237 | trainer_params["checkpoint_path"],
238 | )
239 |
240 | if distributed:
241 | # collect results into rank0
242 | stats = torch.tensor([loss_train, loss_val]).to(device)
243 | torch.distributed.reduce(stats, 0)
244 | if env_params["rank"] == 0:
245 | loss_train = stats[0] / env_params["world_size"]
246 | loss_val = stats[1] / env_params["world_size"]
247 | else:
248 | continue
249 | logging(f"=================== EPOCHS {iter_no} ======================")
250 | # if ('enwik8' in data_params['data_path']) or ('text8' in data_params['data_path']):
251 | msg_result = "Epochs: {} | loss_train: {:.3f} ~ {:.3f} Acc | loss_val: {:.3f} ~ {:.3f} Acc | elapsed: {:.1f}".format(
252 | iter_no, loss_train, acc_train, loss_val, acc_val, elapsed
253 | )
254 | logging(msg_result)
255 | # Save the model if the validation loss is the best we've seen so far.
256 | if (best_val_loss is None) or loss_val < best_val_loss:
257 | best_val_loss = loss_val
258 | save_checkpoint(
259 | trainer_params["checkpoint_path"],
260 | iter_no,
261 | model,
262 | optimizer,
263 | scheduler,
264 | logger,
265 | )
266 | # save_checkpoint(trainer_params['checkpoint_path'], nb_batches_per_iter, model, optimizer, scheduler, logger)
267 | end_time = time.time()
268 | logging(f"Training time total: {(end_time - start_time)/3600} h")
269 |
270 |
271 | if __name__ == "__main__":
272 | launch(**get_params(params_config=PARAMS_CONFIG))
273 |
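274 | # --- Added note (illustrative) ---
275 | # When full_eval_mode is off, launch() initialises the model from the
276 | # pretrained checkpoint selectively: parameters that are missing from the
277 | # finetuning model, or whose shapes differ (e.g. the classification head),
278 | # are skipped, and the remainder is loaded with strict=False.  This lets
279 | # finetuning start from a language-model checkpoint with a different output
280 | # layer.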
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import warnings
3 |
4 | warnings.filterwarnings("ignore")
5 |
6 | import argparse
7 | import math, random
8 | import torch
9 | import time
10 |
11 | from config import PARAMS_CONFIG
12 | from data import get_train_val_test_data
13 | from models import TransformerSeq
14 | from trainer import train_iteration, full_eval
15 | import datetime
16 | import wandb
17 | import os
18 | from utils import (
19 | get_params,
20 | set_up_env,
21 | get_optimizer_and_scheduler,
22 | load_checkpoint,
23 | save_checkpoint,
24 | create_exp_dir,
25 | freeze_gate_weight,
26 | Logger,
27 | set_freq_optimal_search,
28 | )
29 |
30 |
31 | def launch(
32 | env_params,
33 | model_params,
34 | adapt_span_params,
35 | optim_params,
36 | data_params,
37 | trainer_params,
38 | wandb_params,
39 | ):
40 | wandb_flag = wandb_params["wandb_flag"]
41 | if wandb_flag:
42 | wandb.init(project=wandb_params["project_name"])
43 | wandb.run.name = wandb_params["job_name"]
44 | wandb.config.update(model_params)
45 | # global val
46 | best_val_loss = None
47 | # ENVIRONMENT (device, distributed, etc.)
48 | set_up_env(env_params)
49 | device = env_params["device"]
50 | distributed = env_params["distributed"]
51 | resume = trainer_params["resume"]
52 |
53 |     if not distributed or env_params["rank"] == 0:
54 | print("data_params:\t", data_params)
55 | print("model_params:\t", model_params)
56 | print("optim_params:\t", optim_params)
57 | print("trainer_params:\t", trainer_params)
58 | print("adapt_span_params:\t", adapt_span_params)
59 |
60 | # DATA
61 | train_data, val_data, test_data = get_train_val_test_data(
62 | data_params=data_params,
63 | env_params=env_params,
64 | batch_size=trainer_params["batch_size"],
65 | device=device,
66 | )
67 |
68 | # MODEL
69 | model = TransformerSeq(
70 | vocab_size=data_params["vocab_size"],
71 | **model_params,
72 | adapt_span_params=adapt_span_params,
73 | )
74 | print(model)
75 | if distributed:
76 | local_rank = env_params["local_rank"]
77 | model = model.to(device)
78 | model = torch.nn.parallel.DistributedDataParallel(
79 | model,
80 | device_ids=[local_rank],
81 | output_device=local_rank,
82 | find_unused_parameters=True,
83 | )
84 | else:
85 | model = torch.nn.DataParallel(model)
86 | model = model.to(device)
87 |
88 | # OPTIMIZER AND SCHEDULER
89 | optimizer, scheduler = get_optimizer_and_scheduler(
90 | model=model, optim_params=optim_params
91 | )
92 |
93 | # create logger
94 | logger = Logger()
95 | fold_name = trainer_params["checkpoint_path"].split("/")[-1].split(".")[0]
96 | folder_path = "/".join(trainer_params["checkpoint_path"].split("/")[:-1])
97 | logging = create_exp_dir(f"{folder_path}/experiments/{fold_name}")
98 |     # log parameters
99 | logging(f"Training Parameters:\n {trainer_params}")
100 | logging(f"Models Parameters:\n {model_params}")
101 | # logging time
102 | current_time = datetime.datetime.now()
103 | logging(str(current_time))
104 | # log model
105 | logging(str(model))
106 | logging(f"Total of Parameters: {sum(p.numel() for p in model.parameters())}")
107 | logging(
108 | f"Total of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
109 | )
110 | # resume training from last checkpoint if exists
111 | iter_init = load_checkpoint(
112 | trainer_params["checkpoint_path"],
113 | model,
114 | optimizer,
115 | scheduler,
116 | logger,
117 | distributed,
118 | resume,
119 | )
120 | # fix gate
121 | if model_params["smoe_dropout"]:
122 | freeze_gate_weight(model)
123 | # calculate time
124 | start_time = time.time()
125 | # eval model
126 | if trainer_params["full_eval_mode"]:
127 | # evaluate the model on test data
128 | with torch.no_grad():
129 | loss_val = full_eval(
130 | model,
131 | optimizer,
132 | scheduler,
133 | val_data,
134 | model_params["block_size"],
135 | model_params["hidden_size"],
136 | )
137 | loss_test = full_eval(
138 | model,
139 | optimizer,
140 | scheduler,
141 | test_data,
142 | model_params["block_size"],
143 | model_params["hidden_size"],
144 | )
145 | if distributed:
146 | # collect results into rank0
147 | stats = torch.tensor([loss_val, loss_test]).to(device)
148 | torch.distributed.reduce(stats, 0)
149 | if env_params["rank"] == 0:
150 | loss_val = stats[0] / env_params["world_size"]
151 | loss_test = stats[1] / env_params["world_size"]
152 | else:
153 | return
154 |
155 | # print('Test BPC: {:.4f}'.format(loss_test / math.log(2)))
156 | if ("enwik8" in data_params["data_path"]) or (
157 | "text8" in data_params["data_path"]
158 | ):
159 | logging("Val: {:.3f} BPC".format(loss_val / math.log(2)))
160 | logging("Test: {:.3f} BPC".format(loss_test / math.log(2)))
161 | else:
162 | logging("Val: {:.3f} PPL".format(math.exp(loss_val)))
163 | logging("Test: {:.3f} PPL".format(math.exp(loss_test)))
164 | return
165 |
166 | # position of current batch
167 | data_pos = [0] * 2
168 | # initialize caches for train and valid
169 | hid_cache = [
170 | [
171 | torch.zeros(
172 | train_data.size(0),
173 | model.module.layers[layer_i].attn.attn.get_cache_size(),
174 | model_params["hidden_size"],
175 | ).to(device)
176 | for layer_i in range(model.module.attn_layer_count)
177 | ]
178 | for _ in range(2)
179 | ]
180 |
181 | nb_batches_per_iter = trainer_params["nb_batches_per_iter"]
182 | for iter_no in range(iter_init, trainer_params["nb_iter"]):
183 | # freq type
184 | if model_params["freq_type"] == "function":
185 | _threshold = 2.0 / (2.0 + math.sqrt((iter_no + 1)))
186 | set_freq_optimal_search(model, _threshold)
187 |
188 | # time storing
189 | t_sta = time.time()
190 | loss_train, data_pos[0], hid_cache[0] = train_iteration(
191 | model,
192 | model_params["load_balance"],
193 | optimizer,
194 | scheduler,
195 | train_data,
196 | nb_batches_per_iter,
197 | model_params["block_size"],
198 | False,
199 | data_pos[0],
200 | hid_cache[0],
201 | trainer_params["batch_split"],
202 | trainer_params["checkpoint_path"],
203 | )
204 | elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter
205 | with torch.no_grad():
206 | loss_val, data_pos[1], hid_cache[1] = train_iteration(
207 | model,
208 | model_params["load_balance"],
209 | optimizer,
210 | scheduler,
211 | val_data,
212 | nb_batches_per_iter,
213 | model_params["block_size"],
214 | True,
215 | data_pos[1],
216 | hid_cache[1],
217 | trainer_params["batch_split"],
218 | trainer_params["checkpoint_path"],
219 | )
220 |
221 | if distributed:
222 | # collect results into rank0
223 | stats = torch.tensor([loss_train, loss_val]).to(device)
224 | torch.distributed.reduce(stats, 0)
225 | if env_params["rank"] == 0:
226 | loss_train = stats[0] / env_params["world_size"]
227 | loss_val = stats[1] / env_params["world_size"]
228 | else:
229 | continue
230 | logging(f"=================== EPOCHS {iter_no} ======================")
231 | if ("enwik8" in data_params["data_path"]) or (
232 | "text8" in data_params["data_path"]
233 | ):
234 | msg_result = "Epochs: {} | loss_train: {:.3f} ~ {:.3f} BPC | loss_val: {:.3f} ~ {:.3f} BPC | elapsed: {:.1f}".format(
235 | iter_no,
236 | loss_train,
237 | float(loss_train / math.log(2)),
238 | loss_val,
239 | float(loss_val / math.log(2)),
240 | elapsed,
241 | )
242 | else:
243 | msg_result = "Epochs: {} | loss_train: {:.3f} ~ {:.3f} PPL | loss_val: {:.3f} ~ {:.3f} PPL | elapsed: {:.1f}".format(
244 | iter_no,
245 | loss_train,
246 | float(math.exp(loss_train)),
247 | loss_val,
248 | float(math.exp(loss_val)),
249 | elapsed,
250 | )
251 | logging(msg_result)
252 | if wandb_flag:
253 | wandb.log({'train_ppl':float(math.exp(loss_train)),'Epoch':iter_no,'valid_ppl':float(math.exp(loss_val))})
254 | logger.log_iter(iter_no, nb_batches_per_iter, loss_train, loss_val, elapsed, model)
255 | # Save the model if the validation loss is the best we've seen so far.
256 | if (best_val_loss is None) or loss_val < best_val_loss:
257 | best_val_loss = loss_val
258 | save_checkpoint(
259 | trainer_params["checkpoint_path"],
260 | iter_no,
261 | model,
262 | optimizer,
263 | scheduler,
264 | logger,
265 | )
266 | # save_checkpoint(trainer_params['checkpoint_path'], nb_batches_per_iter, model, optimizer, scheduler, logger)
267 | end_time = time.time()
268 | logging(f"Training time total: {(end_time - start_time)/3600} h")
269 |
270 |
271 | if __name__ == "__main__":
272 | launch(**get_params(params_config=PARAMS_CONFIG))
273 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import tqdm
6 |
7 | PARAMS_CONFIG = {
8 | # env-specific
9 | "env_params": {
10 | "--distributed": {
11 | "action": "store_true",
12 | "default": False,
13 |             "help": "enable distributed training "
14 |             "(otherwise all available GPUs are used with DataParallel)",
15 | "dest": "distributed",
16 | },
17 | "--local_rank": {
18 | "type": int,
19 |             "default": int(os.environ.get("LOCAL_RANK", 0)),
20 | # "default": 0,
21 | "help": "used in distributed training",
22 | "dest": "local_rank",
23 | },
24 | },
25 | # data-specific
26 | "data_params": {
27 | "--data": {
28 | "type": str,
29 | "default": "data/text8",
30 | "help": "data location " "(must contain train.txt, valid.txt and test.txt)",
31 | "dest": "data_path",
32 | },
33 | "--data_name": {
34 | "type": str,
35 | "default": "text8",
36 | "help": "The name of dataset",
37 | "dest": "data_name",
38 | },
39 | },
40 | # model-specific
41 | "model_params": {
42 | "--hid-sz": {
43 | "type": int,
44 | "default": 256,
45 | "help": "hidden size (i.e. model size)",
46 | "dest": "hidden_size",
47 | },
48 | "--inner-hid-sz": {
49 | "type": int,
50 | "default": 1024,
51 | "help": "inner hidden size of FF layer",
52 | "dest": "inner_hidden_size",
53 | },
54 | "--nlayers": {
55 | "type": int,
56 | "default": 8,
57 | "help": "number of layers",
58 | "dest": "nb_layers",
59 | },
60 | "--block-sz": {
61 | "type": int,
62 | "default": 64,
63 | "help": "block size " "(the length of sequence to process in parallel)",
64 | "dest": "block_size",
65 | },
66 | "--nheads": {
67 | "type": int,
68 | "default": 2,
69 | "help": "number of self-attention heads",
70 | "dest": "nb_heads",
71 | },
72 | "--attn-span": {
73 | "type": int,
74 | "default": 32,
75 | "help": "length of the attention span",
76 | "dest": "attn_span",
77 | },
78 | "--dropout": {
79 | "type": float,
80 | "default": 0.2,
81 | "help": "dropout rate of ReLU and attention",
82 | "dest": "dropout",
83 | },
84 | "--architecture": {
85 | "type": str,
86 | "default": None,
87 | "help": "arch",
88 | "dest": "architecture",
89 | },
90 | "--base_arch": {
91 | "type": str,
92 | "default": None,
93 | "help": "arch",
94 | "dest": "base_arch",
95 | },
96 | "--smoe_dropout": {
97 | "action": "store_true",
98 | "default": False,
99 |             "help": "enable SMoE-Dropout (freeze the gate)",
100 | "dest": "smoe_dropout",
101 | },
102 | "--optimal_policy": {
103 | "action": "store_true",
104 | "default": False,
105 | "help": "Searching the best routing policy",
106 | "dest": "optimal_policy",
107 | },
108 | "--load_balance": {
109 | "type": float,
110 | "default": 1.0,
111 |             "help": "Ratio of balance loss",
112 | "dest": "load_balance",
113 | },
114 | "--moe_top_k": {
115 | "type": int,
116 | "default": 2,
117 |             "help": "Number of experts activated per token (top-k)",
118 | "dest": "moe_top_k",
119 | },
120 | "--freq": {
121 | "type": float,
122 | "default": 0.03,
123 |             "help": "Frequency of searching for the optimal policy",
124 | "dest": "freq",
125 | },
126 | "--freq_type": {
127 | "type": str,
128 | "default": "fix",
129 |             "help": "Type of frequency schedule for optimal policy search. Choices: fix or function",
130 | "dest": "freq_type",
131 | },
132 | "--alpha": {
133 | "type": float,
134 | "default": 1.0,
135 | "help": "Impact of optimal loss",
136 | "dest": "alpha",
137 | },
138 | "--gate_name": {
139 | "type": str,
140 | "default": "smoe",
141 | "help": "Names of gates: smoe, smoe-dropout, xmoe, stablemoe",
142 | "dest": "gate_name",
143 | },
144 | "--act_experts": {
145 | "type": str,
146 | "default": "shuffle",
147 |             "help": "Method used to activate all experts: shuffle or linear",
148 | "dest": "act_experts",
149 | },
150 | "--g_blance": {
151 | "action": "store_true",
152 | "default": False,
153 | "help": "Activate balance loss for router",
154 | "dest": "g_blance",
155 | },
156 | "--opt_blance": {
157 | "action": "store_true",
158 | "default": False,
159 |             "help": "Activate balancing for the optimal router",
160 | "dest": "opt_blance",
161 | },
162 | "--combine_gate": {
163 | "action": "store_true",
164 | "default": False,
165 |             "help": "Utilize previous information for better consistency",
166 | "dest": "combine_gate",
167 | },
168 | "--opt_loss": {
169 | "type": str,
170 | "default": "mse",
171 | "help": "Type of loss for optimal policy searching",
172 | "dest": "opt_loss",
173 | },
174 | "--gamma1": {
175 | "type": float,
176 | "default": 1.0,
177 | "help": "Adam decay parameter",
178 | "dest": "gamma1",
179 | },
180 | "--gamma2": {
181 | "type": float,
182 | "default": 1.0,
183 | "help": "Momentum learning rate",
184 | "dest": "gamma2",
185 | },
186 | "--mu": {
187 | "type": float,
188 | "default": 0.9,
189 | "help": "Momentum parameter",
190 | "dest": "mu",
191 | },
192 | "--beta1": {
193 | "type": float,
194 | "default": 0.9,
195 | "help": "ADAM parameter",
196 | "dest": "beta1",
197 | },
198 | "--beta2": {
199 | "type": float,
200 | "default": 0.999,
201 | "help": "ADAM parameter",
202 | "dest": "beta2",
203 | },
204 | },
205 | # optimization-specific
206 | "optim_params": {
207 | "--lr": {"type": float, "default": 0.03, "help": "learning rate", "dest": "lr"},
208 | "--momentum": {
209 | "type": float,
210 | "default": 0.9,
211 | "help": "SGD momentum",
212 | "dest": "momentum",
213 | },
214 | "--optim": {
215 | "type": str,
216 | "default": "sgd",
217 | "help": "optimization method: sgd | adagrad",
218 | "dest": "optim",
219 | },
220 | "--lr-warmup": {
221 | "type": int,
222 | "default": 0,
223 | "help": "linearly increase LR from 0 " "during first lr_warmup updates",
224 | "dest": "lr_warmup",
225 | },
226 | "--grad-clip": {
227 | "type": float,
228 | "default": 0,
229 | "help": "[only works with adagrad!] "
230 | "clip gradient of each module parameters by a given "
231 | "value",
232 | "dest": "grad_clip",
233 | },
234 | },
235 | # trainer-specific
236 | "trainer_params": {
237 | "--batch-sz": {
238 | "type": int,
239 | "default": 64,
240 | "help": "batch size",
241 | "dest": "batch_size",
242 | },
243 | "--batch-split": {
244 | "type": int,
245 | "default": 1,
246 | "help": "split a batch into smaller parts to fit in GPU memory",
247 | "dest": "batch_split",
248 | },
249 | "--nbatches": {
250 | "type": int,
251 | "default": 1000,
252 | "help": "number of batches in each iteration",
253 | "dest": "nb_batches_per_iter",
254 | },
255 | "--niter": {
256 | "type": int,
257 | "default": 1000,
258 | "help": "number of iterations to train",
259 | "dest": "nb_iter",
260 | },
261 | "--checkpoint": {
262 | "type": str,
263 | "default": "",
264 | "help": "path to save/load model",
265 | "dest": "checkpoint_path",
266 | },
267 | "--resume": {
268 | "action": "store_true",
269 | "default": False,
270 | "help": "resume training",
271 | "dest": "resume",
272 | },
273 | "--pretrained_weight": {
274 | "type": str,
275 | "default": "",
276 | "help": "path to save/load model",
277 | "dest": "pretrained_weight",
278 | },
279 | "--full-eval-mode": {
280 | "action": "store_true",
281 | "default": False,
282 | "help": "do evaluation on the whole validation and the test data",
283 | "dest": "full_eval_mode",
284 | },
285 | },
286 | # adaptive attention span specific params
287 | "adapt_span_params": {
288 | "--adapt-span": {
289 | "action": "store_true",
290 | "default": False,
291 | "help": "enable adaptive attention span",
292 | "dest": "adapt_span_enabled",
293 | },
294 | "--adapt-span-loss": {
295 | "type": float,
296 | "default": 0,
297 | "help": "the loss coefficient for span lengths",
298 | "dest": "adapt_span_loss",
299 | },
300 | "--adapt-span-ramp": {
301 | "type": int,
302 | "default": 32,
303 | "help": "ramp length of the soft masking function",
304 | "dest": "adapt_span_ramp",
305 | },
306 | "--adapt-span-init": {
307 | "type": float,
308 | "default": 0,
309 | "help": "initial attention span ratio",
310 | "dest": "adapt_span_init",
311 | },
312 | "--adapt-span-cache": {
313 | "action": "store_true",
314 | "default": False,
315 | "help": "adapt cache size as well to reduce memory usage",
316 | "dest": "adapt_span_cache",
317 | },
318 | },
319 | "wandb_params": {
320 | "--project-name": {
321 | "type": str,
322 | "default": "project_name",
323 | "help": "wandb project name",
324 | "dest": "project_name",
325 | },
326 | "--job-name": {
327 | "type": str,
328 | "default": "job_name",
329 | "help": "wandb job name",
330 | "dest": "job_name",
331 | },
332 | "--wandb-flag": {
333 | "action": "store_true",
334 | "default": False,
335 | "help": "use wandb",
336 | "dest": "wandb_flag",
337 | },
338 | },
339 | }
340 |
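# Hedged sketch (illustration only, not this repo's implementation): the trainer
# calls get_params(params_config=PARAMS_CONFIG), defined elsewhere in the repo.
# A hypothetical helper like the one below shows how each entry above is expected
# to map onto argparse, with the outer keys (env_params, model_params, ...) acting
# as parameter groups.
def _build_parser_sketch(params_config=PARAMS_CONFIG):
    parser = argparse.ArgumentParser()
    for group in params_config.values():
        for flag, options in group.items():
            # each options dict already uses argparse keyword names
            # (type/action/default/help/dest), so it can be splatted directly
            parser.add_argument(flag, **options)
    return parser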
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/custom_layers.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import math, random
4 | import torch
5 | import torch.nn as nn
6 | import tree
7 | from custom_functions import prepare_forward, ensure_comm
8 | from custom_functions import MOEScatter, MOEGather
9 | from custom_functions import AllGather, Slice
10 | from gates import NaiveGate
11 |
12 | from fastermoe.config import switch_from_env
13 |
14 |
15 | def mark_module_parallel_comm(module, comm):
16 | r"""
17 | Mark all parameters in `module` as doing data parallel in `comm`, where
18 | `comm` may be one of `'world', 'dp', 'none'`.
19 | """
20 | for p in module.parameters():
21 | setattr(p, "dp_comm", comm)
22 |
23 |
24 | def _fmoe_general_global_forward(
25 | inp, gate, expert_fn, num_expert, world_size, **kwargs
26 | ):
27 | r"""
28 | A private function that performs the following steps to complete the MoE
29 | computation.
30 | * Count the number of tokens from each worker to each expert.
31 | * Send the features to their target position so that input features to each
32 | expert are contiguous in memory.
33 | * Perform the forward computation of the experts using `expert_fn`
34 | * Gather the output features of experts back, and reorder them as sentences.
35 | Intermediate results like expert counts are hidden from users by this
36 | function.
37 | """
38 | (
39 | pos,
40 | local_expert_count,
41 | global_expert_count,
42 | fwd_expert_count,
43 | fwd_batch_size,
44 | ) = prepare_forward(gate, num_expert, world_size)
45 | topk = 1
46 | if len(gate.shape) == 2:
47 | topk = gate.shape[1]
48 |
49 | def scatter_func(tensor):
50 | return MOEScatter.apply(
51 | tensor,
52 | torch.div(pos, topk, rounding_mode="floor"),
53 | local_expert_count,
54 | global_expert_count,
55 | fwd_batch_size,
56 | world_size,
57 | )
58 |
59 | x = tree.map_structure(scatter_func, inp)
60 |
61 | x = expert_fn(x, fwd_expert_count)
62 |
63 | out_batch_size = tree.flatten(inp)[0].shape[0]
64 | if len(gate.shape) == 2:
65 | out_batch_size *= gate.shape[1]
66 |
67 | def gather_func(tensor):
68 | return MOEGather.apply(
69 | tensor,
70 | pos,
71 | local_expert_count,
72 | global_expert_count,
73 | out_batch_size,
74 | world_size,
75 | )
76 |
77 | outp = tree.map_structure(gather_func, x)
78 | return outp
79 |
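# Hedged illustration of the layout assumed above: `gate` carries the chosen
# expert index per token, either shape (batch,) for top-1 or (batch, top_k), and
# the per-expert token counts drive the scatter/gather. A pure-torch sketch of
# the counting step (the real path goes through prepare_forward / fmoe_cuda):
#   gate = torch.tensor([[0, 2], [1, 2], [0, 1]])      # 3 tokens, top_k = 2
#   torch.bincount(gate.view(-1), minlength=3)          # -> tensor([2, 2, 2])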
80 |
81 | fmoe_faster_schedule = False
82 | if switch_from_env("FMOE_FASTER_SCHEDULE_ENABLE", False):
83 | fmoe_faster_schedule = True
84 |     from fastermoe.schedule import _fmoe_general_global_forward
85 |
86 |
87 | class FMoE(nn.Module):
88 | r"""
89 | A general moe implementation that supports an arbitrary module as the
90 | expert.
91 | * `num_expert` stands for the number of experts on **each** worker.
92 | * `world_size` stands for the total number of workers that contains
93 | different experts.
94 |     * `slice_group` can be a torch communication group, indicating that
95 |     model parallelism is applied across the group: workers in the
96 |     group hold the same copy of the input feature and require the same copy of
97 | the output. For each worker, FMoE only computes the output of a certain
98 | slice of the input batch, and will all-gather the outputs after
99 | computation.
100 | * `top_k` stands for the number of experts each token is going to.
101 |     * `gate` is a gate class which can be found in `fmoe.gates`.
102 |     * `expert` can be specified as a module class; it is used to generate
103 |     `num_expert` expert modules.
104 | """
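    # A minimal usage sketch (hedged: assumes fmoe_cuda is built and a CUDA
    # device is available; `MyExpert` is a hypothetical nn.Module constructed
    # from d_model):
    #   moe = FMoE(num_expert=4, d_model=512, moe_top_k=2, expert=MyExpert)
    #   out = moe(x)   # x: (n_tokens, 512) -> out: (n_tokens, 512)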
105 |
106 | def __init__(
107 | self,
108 | num_expert=32,
109 | d_model=1024,
110 | world_size=1,
111 | mp_group=None, # being deprecated
112 | slice_group=None,
113 | moe_group=None,
114 | moe_top_k=2,
115 | gate=NaiveGate,
116 | expert=None,
117 | gate_hook=None,
118 | mask=None,
119 | mask_dict=None,
120 | ):
121 | super().__init__()
122 | self.num_expert = num_expert
123 | self.d_model = d_model
124 | self.world_size = world_size
125 |
126 | self.slice_group = slice_group
127 | if mp_group is not None:
128 | print("[Warning] mp_group is being deprecated")
129 | self.slice_group = mp_group
130 | if self.slice_group is None:
131 | self.slice_size = 1
132 | self.slice_rank = 0
133 | else:
134 | self.slice_size = self.slice_group.size()
135 | self.slice_rank = self.slice_group.rank()
136 |
137 | self.top_k = moe_top_k
138 | if type(expert) is list:
139 | self.experts = nn.ModuleList([e(d_model) for e in expert])
140 | self.experts_fused = False
141 | self.num_expert = num_expert = len(expert)
142 | elif expert is not None:
143 | self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
144 | self.experts_fused = False
145 | else:
146 | self.experts_fused = True
147 |
148 | self.gate = gate(d_model, num_expert, world_size, moe_top_k)
149 | self.gate_hook = gate_hook
150 | self.mask = mask
151 | self.mask_dict = mask_dict
152 | self.moe_group = moe_group
153 |
154 | def expert_fn(self, inp, fwd_expert_count):
155 | r"""
156 | The default expert function which either calls the experts as a whole
157 | or as separate experts.
158 | """
159 | if self.experts_fused:
160 | return self.experts(inp, fwd_expert_count)
161 | if isinstance(fwd_expert_count, torch.Tensor):
162 | fwd_expert_count = fwd_expert_count.cpu().numpy()
163 | outputs = []
164 | base_idx = 0
165 | for i in range(self.num_expert):
166 | batch_size = fwd_expert_count[i]
167 | inp_slice = inp[base_idx : base_idx + batch_size]
168 | outputs.append(self.experts[i](inp_slice))
169 | base_idx += batch_size
170 | return torch.cat(outputs, dim=0)
171 |
172 | def mark_parallel_comm(self, expert_dp_comm="none"):
173 | r"""
174 | Automatically mark the data parallel comms of the parameters within the
175 |         module. This is typically called at the end of the __init__ function
176 | in child classes.
177 | """
178 | if self.experts is not None:
179 | comm = expert_dp_comm
180 | if isinstance(self.experts, list):
181 | for e in self.experts:
182 | mark_module_parallel_comm(e, comm)
183 | else:
184 | mark_module_parallel_comm(self.experts, comm)
185 | mark_module_parallel_comm(self.gate, "gate")
186 |
187 | def forward(self, moe_inp):
188 | r"""
189 |         The FMoE module first computes the gate output, and then conducts the
190 |         MoE forward pass according to the gate. The scores of the selected experts
191 |         given by the gate are multiplied with the experts' output tensors as weights.
192 | """
193 |
194 | moe_inp_batch_size = tree.flatten(
195 | tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
196 | )
197 | assert all(
198 | [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
199 | ), "MoE inputs must have the same batch size"
200 |
201 | if self.world_size > 1:
202 |
203 | def ensure_comm_func(tensor):
204 | ensure_comm(tensor, self.moe_group)
205 |
206 | tree.map_structure(ensure_comm_func, moe_inp)
207 | if self.slice_size > 1:
208 |
209 | def slice_func(tensor):
210 | return Slice.apply(
211 | tensor, self.slice_rank, self.slice_size, self.slice_group
212 | )
213 |
214 | moe_inp = tree.map_structure(slice_func, moe_inp)
215 |
216 | gate_top_k_idx, gate_score = self.gate(moe_inp)
217 |
218 | if hasattr(self.gate, "dynamic_top_k"):
219 | self.top_k = self.gate.dynamic_top_k
220 |
221 | if self.gate_hook is not None:
222 | self.gate_hook(gate_top_k_idx, gate_score, None)
223 |
224 | # delete masked tensors
225 | if self.mask is not None and self.mask_dict is not None:
226 | # TODO: to fix
227 | def delete_mask_func(tensor):
228 | # to: (BxL') x d_model
229 | tensor = tensor[mask == 0, :]
230 | return tensor
231 |
232 | mask = self.mask.view(-1)
233 | moe_inp = tree.map_structure(delete_mask_func, moe_inp)
234 | gate_top_k_idx = gate_top_k_idx[mask == 0, :]
235 |
236 | fwd = _fmoe_general_global_forward(
237 | moe_inp,
238 | gate_top_k_idx,
239 | self.expert_fn,
240 | self.num_expert,
241 | self.world_size,
242 | experts=self.experts,
243 | )
244 |
245 | # recover deleted tensors
246 | if self.mask is not None and self.mask_dict is not None:
247 |
248 | def recover_func(tensor):
249 | # to: (BxL') x top_k x dim
250 | dim = tensor.shape[-1]
251 | tensor = tensor.view(-1, self.top_k, dim)
252 | # to: (BxL) x top_k x d_model
253 | x = torch.zeros(
254 | mask.shape[0],
255 | self.top_k,
256 | dim,
257 | device=tensor.device,
258 | dtype=tensor.dtype,
259 | )
260 | # recover
261 | x[mask == 0] = tensor
262 | for k, v in self.mask_dict.items():
263 | x[mask == k] = v
264 | return x
265 |
266 | moe_outp = tree.map_structure(recover_func, fwd)
267 | else:
268 |
269 | def view_func(tensor):
270 | dim = tensor.shape[-1]
271 | tensor = tensor.view(-1, self.top_k, dim)
272 | return tensor
273 |
274 | moe_outp = tree.map_structure(view_func, fwd)
275 |
276 | gate_score = gate_score.view(-1, 1, self.top_k)
277 |
278 | def bmm_func(tensor):
279 | dim = tensor.shape[-1]
280 | tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
281 | return tensor
282 |
283 | moe_outp = tree.map_structure(bmm_func, moe_outp)
284 |
285 | if self.slice_size > 1:
286 |
287 | def all_gather_func(tensor):
288 | return AllGather.apply(
289 | tensor, self.slice_rank, self.slice_size, self.slice_group
290 | )
291 |
292 | moe_outp = tree.map_structure(all_gather_func, moe_outp)
293 |
294 | moe_outp_batch_size = tree.flatten(
295 | tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
296 | )
297 | assert all(
298 | [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
299 | ), "MoE outputs must have the same batch size"
300 | return moe_outp
301 |
302 |
303 | ##############################################################################
304 |
305 | import torch
306 | import torch.nn as nn
307 | import math
308 | import fmoe_cuda
309 | from torch.autograd import Function
310 |
311 |
312 | class MOELinear(Function):
313 | r"""
314 |     Computes linear operators within one GPU on different experts simultaneously.
315 | """
316 |
317 | @staticmethod
318 | def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
319 | global_output_buf = fmoe_cuda.linear_forward(
320 | global_input_buf, fwd_expert_count, weight, bias
321 | )
322 | variables = (global_input_buf, fwd_expert_count, weight, bias)
323 | ctx.save_for_backward(*variables)
324 | return global_output_buf
325 |
326 | @staticmethod
327 | def backward(ctx, grad_out):
328 | (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
329 | grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
330 | grad_out, input_buf, fwd_expert_count, weight, bias
331 | )
332 |
333 | if not torch.is_tensor(bias):
334 | grad_bias = None
335 |
336 | return grad_inp_buf, None, grad_weight, grad_bias
337 |
338 |
339 | class FMoELinear(nn.Module):
340 | r"""
341 | A linear layer that contains multiple experts.
342 | As multiple experts can be placed on the same worker, the computation can be
343 |     performed in parallel to increase performance.
344 |     The FMoELinear module provides this functionality.
345 | """
346 |
347 | def __init__(
348 | self,
349 | num_expert: int,
350 | in_feat: int,
351 | out_feat: int,
352 | bias: bool = True,
353 | rank: int = 0,
354 | ):
355 | super().__init__()
356 | self.num_expert = num_expert
357 | self.in_feat = in_feat
358 | self.out_feat = out_feat
359 | self.rank = rank
360 | self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
361 | if bias:
362 | self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
363 | else:
364 | self.register_parameter("bias", None)
365 |
366 | self.reset_parameters()
367 |
368 | def forward(self, inp, fwd_expert_count):
369 | r"""
370 | Call MOE function
371 | """
372 | x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
373 | return x
374 |
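    # Hedged shape note: `weight` is (num_expert, out_feat, in_feat) and `inp`
    # is the scattered buffer with sum(fwd_expert_count) rows; fmoe_cuda is
    # expected to apply expert i's weight to its contiguous block of
    # fwd_expert_count[i] rows, i.e. roughly
    #   out_i = inp_i @ weight[i].t() + bias[i]
    # for each expert i.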
375 | def extra_repr(self) -> str:
376 | return "num_expert={}, in_features={}, \
377 | out_features={}, bias={}, rank={}".format(
378 | self.num_expert,
379 | self.in_feat,
380 | self.out_feat,
381 | self.bias is not None,
382 | self.rank,
383 | )
384 |
385 | def reset_parameters(self):
386 | # Approach is the same as in torch.nn.Linear
387 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
388 | # bias is left to zero, similar as megatron
389 |
390 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
391 |
--------------------------------------------------------------------------------
/vocabulary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import csv
4 | import json
5 | import torch
6 | from collections import Counter, OrderedDict
7 | import torch.nn.functional as F
8 | from torch.nn.utils.rnn import pad_sequence
9 |
10 |
11 | class Vocab(object):
12 | def __init__(
13 | self,
14 | special=[],
15 | min_freq=0,
16 | max_size=None,
17 | lower_case=True,
18 | delimiter=None,
19 | vocab_file=None,
20 | ):
21 | self.counter = Counter()
22 | self.special = special
23 | self.min_freq = min_freq
24 | self.max_size = max_size
25 | self.lower_case = lower_case
26 | self.delimiter = delimiter
27 | self.vocab_file = vocab_file
28 |
29 | def tokenize(
30 | self,
31 | line,
32 | add_eos=False,
33 | add_double_eos=False,
34 | add_cls_token=False,
35 | add_s=False,
36 | add_cls_token_last=False,
37 | ):
38 | line = line.strip()
39 | # convert to lower case
40 | if self.lower_case:
41 | line = line.lower()
42 |
43 | # empty delimiter '' will evaluate False
44 | if self.delimiter == "":
45 | symbols = line
46 | else:
47 | symbols = line.split(self.delimiter)
48 |
49 | if add_cls_token:
50 | return [""] + symbols + [""]
51 | elif add_cls_token_last:
52 | return [""] + symbols + [""]
53 |         elif add_double_eos:  # lm1b
54 |             return ["<S>"] + symbols + ["<S>"]
55 |         elif add_eos:
56 |             return symbols + ["<eos>"]
57 | elif add_s:
58 | return symbols + [""]
59 | else:
60 | return symbols
61 |
62 | def count_file(self, path, verbose=False, add_eos=False):
63 | if verbose:
64 | print("counting file {} ...".format(path))
65 | assert os.path.exists(path)
66 |
67 | sents = []
68 | with open(path, "r", encoding="utf-8") as f:
69 | for idx, line in enumerate(f):
70 | if verbose and idx > 0 and idx % 500000 == 0:
71 | print(" line {}".format(idx))
72 | symbols = self.tokenize(line, add_eos=add_eos)
73 | self.counter.update(symbols)
74 | sents.append(symbols)
75 |
76 | return sents
77 |
78 | def count_csqa(
79 | self,
80 | path,
81 | num_classes=5,
82 | verbose=False,
83 | add_eos=False,
84 | add_double_eos=False,
85 | add_cls_token=False,
86 | ):
87 | if verbose:
88 | print("counting file {} ...".format(path))
89 | assert os.path.exists(path)
90 |
91 | sents = []
92 | with open(path, "r", encoding="utf-8") as f:
93 | for idx, line in enumerate(f):
94 | if verbose and idx > 0 and idx % 500000 == 0:
95 | print(" line {}".format(idx))
96 | example = json.loads(line.strip())
97 | question = example["question"]["stem"]
98 | assert len(example["question"]["choices"]) == num_classes
99 | # format: ` Q: Where would I not want a fox? A: hen house `
100 | question = "Q: " + question
101 | question_toks = self.tokenize(
102 | question,
103 | add_eos=add_eos,
104 | add_double_eos=add_double_eos,
105 | add_cls_token=add_cls_token,
106 | )
107 | for i, choice in enumerate(example["question"]["choices"]):
108 | src = "A: " + choice["text"]
109 | assert (ord(choice["label"]) - ord("A")) == i
110 | src_bin = self.tokenize(src, add_eos=add_eos)
111 | question_toks.extend(src_bin)
112 | self.counter.update(question_toks)
113 | sents.append(question_toks)
114 | return sents
115 |
116 | def count_sst2(
117 | self,
118 | path,
119 | verbose=False,
120 | add_eos=False,
121 | add_double_eos=False,
122 | add_cls_token=False,
123 | ):
124 | if verbose:
125 | print("counting file {} ...".format(path))
126 | assert os.path.exists(path)
127 | sents = []
128 | with open(path, "r", encoding="utf-8") as f:
129 | tsv_file = csv.reader(f, delimiter="\t")
130 | for line in tsv_file:
131 | if not line[1] in ["0", "1"]:
132 | print('* Ignore ', line)
133 | continue
134 | sentence, label = line[0], int(line[1])
135 | assert label in [0, 1]
136 | sentence_toks = self.tokenize(
137 | sentence,
138 | add_eos=add_eos,
139 | add_double_eos=add_double_eos,
140 | add_cls_token=add_cls_token,
141 | )
142 | self.counter.update(sentence_toks)
143 | sents.append(sentence_toks)
144 | return sents
145 |
146 | def count_sst5(
147 | self,
148 | dataset,
149 | verbose=False,
150 | add_eos=False,
151 | add_double_eos=False,
152 | add_cls_token=False,
153 | ):
154 | # if verbose:
155 | # print("counting file {} ...".format(path))
156 | # assert os.path.exists(path)
157 | sents = []
158 | for sample in dataset:
159 | # print("here:", sample)
160 | sample = sample.to_labeled_lines()[0]
161 | if not sample[0] in [0, 1, 2, 3, 4]:
162 | print("* Ignore ", sample)
163 | continue
164 | sentence, label = sample[1], sample[0]
165 | # assert label in [0, 1, 2, 3, 4]
166 | sentence_toks = self.tokenize(
167 | sentence,
168 | add_eos=add_eos,
169 | add_double_eos=add_double_eos,
170 | add_cls_token=add_cls_token,
171 | )
172 | self.counter.update(sentence_toks)
173 | sents.append(sentence_toks)
174 | return sents
175 |
176 | def count_banking77(
177 | self,
178 | path,
179 | verbose=False,
180 | add_eos=False,
181 | add_double_eos=False,
182 | add_cls_token=False,
183 | ):
184 | if verbose:
185 | print("counting file {} ...".format(path))
186 | assert os.path.exists(path)
187 | sents = []
188 | with open(path, "r", encoding="utf-8") as f:
189 | tsv_file = csv.reader(f, delimiter="\t")
190 | for line in tsv_file:
191 | if not line[1] in [str(x) for x in list(range(77))]:
192 | # print('* Ignore ', line)
193 | continue
194 | sentence, label = line[0], int(line[1])
195 | assert label in list(range(77))
196 | sentence_toks = self.tokenize(
197 | sentence,
198 | add_eos=add_eos,
199 | add_double_eos=add_double_eos,
200 | add_cls_token=add_cls_token,
201 | )
202 | self.counter.update(sentence_toks)
203 | sents.append(sentence_toks)
204 | return sents
205 |
206 | def count_sents(self, sents, verbose=False):
207 | """
208 | sents : a list of sentences, each a list of tokenized symbols
209 | """
210 | if verbose:
211 | print("counting {} sents ...".format(len(sents)))
212 | for idx, symbols in enumerate(sents):
213 | if verbose and idx > 0 and idx % 500000 == 0:
214 | print(" line {}".format(idx))
215 | self.counter.update(symbols)
216 |
217 | def _build_from_file(self, vocab_file):
218 | self.idx2sym = []
219 | self.sym2idx = OrderedDict()
220 |
221 | with open(vocab_file, "r", encoding="utf-8") as f:
222 | for line in f:
223 | symb = line.strip().split()[0]
224 | self.add_symbol(symb)
225 |         self.unk_idx = self.sym2idx["<UNK>"]
226 |
227 | def build_vocab(self):
228 | if self.vocab_file:
229 | print("building vocab from {}".format(self.vocab_file))
230 | self._build_from_file(self.vocab_file)
231 | print("final vocab size {}".format(len(self)))
232 | else:
233 | print(
234 | "building vocab with min_freq={}, max_size={}".format(
235 | self.min_freq, self.max_size
236 | )
237 | )
238 | self.idx2sym = []
239 | self.sym2idx = OrderedDict()
240 |
241 | for sym in self.special:
242 | self.add_special(sym)
243 |
244 | for sym, cnt in self.counter.most_common(self.max_size):
245 | if cnt < self.min_freq:
246 | break
247 | self.add_symbol(sym)
248 |
249 | print(
250 | "final vocab size {} from {} unique tokens".format(
251 | len(self), len(self.counter)
252 | )
253 | )
254 |
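    # Hedged usage sketch (paths and special tokens are placeholders):
    #   vocab = Vocab(special=["<eos>"], lower_case=True, delimiter="")
    #   vocab.count_file("data/text8/train.txt")
    #   vocab.build_vocab()
    #   ids = vocab.encode_file("data/text8/train.txt", ordered=True, add_eos=False)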
255 | def encode_file(
256 | self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False
257 | ):
258 | if verbose:
259 | print("encoding file {} ...".format(path))
260 | assert os.path.exists(path)
261 | encoded = []
262 | with open(path, "r", encoding="utf-8") as f:
263 | for idx, line in enumerate(f):
264 | if verbose and idx > 0 and idx % 500000 == 0:
265 | print(" line {}".format(idx))
266 | symbols = self.tokenize(
267 | line, add_eos=add_eos, add_double_eos=add_double_eos
268 | )
269 | encoded.append(self.convert_to_tensor(symbols))
270 |
271 | if ordered:
272 | encoded = torch.cat(encoded)
273 |
274 | return encoded
275 |
276 | def encode_csqa_file(
277 | self,
278 | path,
279 | ordered=False,
280 | num_classes=5,
281 | verbose=False,
282 | add_eos=False,
283 | add_double_eos=False,
284 | add_cls_token=False,
285 | ):
286 | if verbose:
287 | print("encoding file {} ...".format(path))
288 | assert os.path.exists(path)
289 | encoded = [[] for i in range(num_classes)]
290 | labels = []
291 |
292 | with open(path, "r", encoding="utf-8") as f:
293 | for idx, line in enumerate(f):
294 | if verbose and idx > 0 and idx % 500000 == 0:
295 | print(" line {}".format(idx))
296 | example = json.loads(line.strip())
297 | if "answerKey" in example:
298 | label = ord(example["answerKey"]) - ord("A")
299 | labels.append(label)
300 | question = example["question"]["stem"]
301 | assert len(example["question"]["choices"]) == num_classes
302 | # format: ` Q: Where would I not want a fox? A: hen house `
303 | question = "Q: " + question
304 | question_bin = self.tokenize(
305 | question,
306 | add_eos=add_eos,
307 | add_double_eos=add_double_eos,
308 | add_cls_token=add_cls_token,
309 | )
310 | for i, choice in enumerate(example["question"]["choices"]):
311 | src = " A: " + choice["text"]
312 | assert (ord(choice["label"]) - ord("A")) == i
313 | src_bin = question_bin + self.tokenize(src, add_s=True)
314 | encoded[i].append(self.convert_to_tensor(src_bin))
315 |
316 | labels = torch.LongTensor(labels)
317 |
318 | # pdb.set_trace()
319 |
320 | # if ordered:
321 | # for idx in range(num_classes):
322 | # encoded[idx] = pad_sequence(encoded[idx])
323 |
324 | # encoded = pad_sequence(encoded)
325 | # print(encoded.shape)
326 |
327 | return [encoded, labels]
328 |
329 | def encode_sst2_file(
330 | self,
331 | path,
332 | verbose=False,
333 | add_eos=False,
334 | add_double_eos=False,
335 | add_cls_token=False,
336 | ):
337 | if verbose:
338 | print("encoding file {} ...".format(path))
339 | assert os.path.exists(path)
340 | encoded = []
341 | labels = []
342 | with open(path, "r", encoding="utf-8") as f:
343 | tsv_file = csv.reader(f, delimiter="\t")
344 | for line in tsv_file:
345 | if not line[1] in ["0", "1"]:
346 | print(line[0])
347 | print("* Ignore ", line)
348 | continue
349 | sentence, label = line[0], int(line[1])
350 | assert label in [0, 1]
351 | sentence_toks = self.tokenize(
352 | sentence,
353 | add_eos=add_eos,
354 | add_double_eos=add_double_eos,
355 | add_cls_token=add_cls_token,
356 | )
357 | encoded.append(self.convert_to_tensor(sentence_toks))
358 | labels.append(label)
359 |
360 | labels = torch.LongTensor(labels)
361 | return [encoded, labels]
362 |
363 | def encode_sst5_file(
364 | self,
365 | dataset,
366 | verbose=False,
367 | add_eos=False,
368 | add_double_eos=False,
369 | add_cls_token=False,
370 | ):
371 | # if verbose:
372 | # print("encoding file {} ...".format(path))
373 | # assert os.path.exists(path)
374 | encoded = []
375 | labels = []
376 | for sample in dataset:
377 | sample = sample.to_labeled_lines()[0]
378 | if not sample[0] in [0, 1, 2, 3, 4]:
379 | print("* Ignore ", sample)
380 | continue
381 | sentence, label = sample[1], sample[0]
382 | # assert label in [0, 1, 2, 3, 4]
383 | sentence_toks = self.tokenize(
384 | sentence,
385 | add_eos=add_eos,
386 | add_double_eos=add_double_eos,
387 | add_cls_token=add_cls_token,
388 | )
389 | encoded.append(self.convert_to_tensor(sentence_toks))
390 | labels.append(label)
391 |
392 | labels = torch.LongTensor(labels)
393 | return [encoded, labels]
394 |
395 | def encode_banking77_file(
396 | self,
397 | path,
398 | verbose=False,
399 | add_eos=False,
400 | add_double_eos=False,
401 | add_cls_token=False,
402 | ):
403 | if verbose:
404 | print("encoding file {} ...".format(path))
405 | assert os.path.exists(path)
406 | encoded = []
407 | labels = []
408 | with open(path, "r", encoding="utf-8") as f:
409 | tsv_file = csv.reader(f, delimiter="\t")
410 | for line in tsv_file:
411 | if not line[1] in [str(x) for x in list(range(77))]:
412 | print("* Ignore ", line)
413 | continue
414 | sentence, label = line[0], int(line[1])
415 | assert label in list(range(77))
416 | sentence_toks = self.tokenize(
417 | sentence,
418 | add_eos=add_eos,
419 | add_double_eos=add_double_eos,
420 | add_cls_token=add_cls_token,
421 | )
422 | encoded.append(self.convert_to_tensor(sentence_toks))
423 | labels.append(label)
424 |
425 | labels = torch.LongTensor(labels)
426 | return [encoded, labels]
427 |
428 | def encode_sents(self, sents, ordered=False, verbose=False):
429 | if verbose:
430 | print("encoding {} sents ...".format(len(sents)))
431 | encoded = []
432 | for idx, symbols in enumerate(sents):
433 | if verbose and idx > 0 and idx % 500000 == 0:
434 | print(" line {}".format(idx))
435 | encoded.append(self.convert_to_tensor(symbols))
436 |
437 | if ordered:
438 | encoded = torch.cat(encoded)
439 |
440 | return encoded
441 |
442 | def add_special(self, sym):
443 | if sym not in self.sym2idx:
444 | self.idx2sym.append(sym)
445 | self.sym2idx[sym] = len(self.idx2sym) - 1
446 | setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym])
447 |
448 | def add_symbol(self, sym):
449 | if sym not in self.sym2idx:
450 | self.idx2sym.append(sym)
451 | self.sym2idx[sym] = len(self.idx2sym) - 1
452 |
453 | def get_sym(self, idx):
454 | assert 0 <= idx < len(self), "Index {} out of range".format(idx)
455 | return self.idx2sym[idx]
456 |
457 | def get_idx(self, sym):
458 | if sym in self.sym2idx:
459 | return self.sym2idx[sym]
460 | else:
461 | # print('encounter unk {}'.format(sym))
462 | print(sym)
463 |             assert "<eos>" not in sym
464 | assert hasattr(self, "unk_idx")
465 | return self.sym2idx.get(sym, self.unk_idx)
466 |
467 | def get_symbols(self, indices):
468 | return [self.get_sym(idx) for idx in indices]
469 |
470 | def get_indices(self, symbols):
471 | return [self.get_idx(sym) for sym in symbols]
472 |
473 | def convert_to_tensor(self, symbols):
474 | return torch.LongTensor(self.get_indices(symbols))
475 |
476 | def convert_to_sent(self, indices, exclude=None):
477 | if exclude is None:
478 | return " ".join([self.get_sym(idx) for idx in indices])
479 | else:
480 | return " ".join(
481 | [self.get_sym(idx) for idx in indices if idx not in exclude]
482 | )
483 |
484 | def __len__(self):
485 | return len(self.idx2sym)
486 |
--------------------------------------------------------------------------------
/finetune_data.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import glob
3 | import pdb
4 | from collections import Counter, OrderedDict
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | from vocabulary import Vocab
9 | import torch.nn.functional as F
10 | from torch.nn.utils.rnn import pad_sequence
11 | import pytreebank
12 |
13 |
14 | def pad_sequence_reverse(data):
15 | # data should be a list of 1D tensors
16 |
17 | assert data[0].dim() == 1
18 | device = data[0].device
19 | length_list = []
20 | for item in data:
21 | length_list.append(item.shape[0])
22 | max_length = max(length_list)
23 |
24 | # padding
25 | padded_data_list = []
26 | for item in data:
27 | padded_item = torch.cat(
28 | [torch.zeros(max_length - item.shape[0], dtype=item.dtype).to(device), item]
29 | ).reshape(-1, 1)
30 | padded_data_list.append(padded_item)
31 | padded_data_list = torch.cat(padded_data_list, dim=1)
32 | return padded_data_list
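# Hedged toy example of the helper above:
#   pad_sequence_reverse([torch.tensor([1, 2, 3]), torch.tensor([7])])
# left-pads each sequence with zeros up to the max length and stacks the
# sequences as columns, giving
#   tensor([[1, 0],
#           [2, 0],
#           [3, 7]])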
33 |
34 |
35 | class LMOrderedIterator(object):
36 | def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
37 | """
38 | data -- LongTensor -- the LongTensor is strictly ordered
39 | """
40 | self.bsz = bsz
41 | self.bptt = bptt
42 | self.ext_len = ext_len if ext_len is not None else 0
43 |
44 | self.device = device
45 |
46 | # Work out how cleanly we can divide the dataset into bsz parts.
47 | self.n_step = data.size(0) // bsz
48 |
49 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
50 | data = data.narrow(0, 0, self.n_step * bsz)
51 |
52 | # Evenly divide the data across the bsz batches.
53 | self.data = data.view(bsz, -1).t().contiguous().to(device)
54 |
55 | # Number of mini-batches
56 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
57 |
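    # Hedged illustration: with data of length 10 and bsz = 2, n_step = 5 and
    # self.data has shape (5, 2). With ext_len = 0, get_batch(i) then returns the
    # input block data[i : i + seq_len] and the next-token targets
    # data[i + 1 : i + 1 + seq_len], both shaped (seq_len, bsz).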
58 | def get_batch(self, i, bptt=None):
59 | if bptt is None:
60 | bptt = self.bptt
61 | seq_len = min(bptt, self.data.size(0) - 1 - i)
62 |
63 | end_idx = i + seq_len
64 | beg_idx = max(0, i - self.ext_len)
65 |
66 | data = self.data[beg_idx:end_idx]
67 | target = self.data[i + 1 : i + 1 + seq_len]
68 |
69 | return data, target, seq_len
70 |
71 | def get_fixlen_iter(self, start=0):
72 | for i in range(start, self.data.size(0) - 1, self.bptt):
73 | yield self.get_batch(i)
74 |
75 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
76 | max_len = self.bptt + max_deviation * std
77 | i = start
78 | while True:
79 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
80 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
81 | data, target, seq_len = self.get_batch(i, bptt)
82 | i += seq_len
83 | yield data, target, seq_len
84 | if i >= self.data.size(0) - 2:
85 | break
86 |
87 | def __iter__(self):
88 | return self.get_fixlen_iter()
89 |
90 |
91 | class SST2Iterator(object):
92 | def __init__(self, data, bsz):
93 | """
94 | data: [encoded, labels]
95 | """
96 |
97 | self.bsz = bsz
98 |
99 | self.encoded = data[0]
100 | self.labels = data[1] # Tensor
101 |
102 | self.n_step = self.labels.size(0) // bsz
103 | self.cur_step = 0
104 | self.n_samples = self.labels.size(0)
105 | self.sequence_array = np.arange(self.n_samples)
106 |
107 | def get_batch(self, index_list):
108 |
109 | subencoded = []
110 | mask_idx_pre = []
111 | sublabels = []
112 |
113 | for idx in index_list:
114 | subencoded.append(self.encoded[idx])
115 | sublabels.append(self.labels[idx])
116 | mask_idx_pre.append(torch.ones(self.encoded[idx].shape[0]))
117 |
118 | subencoded = pad_sequence_reverse(subencoded)
119 | mask_idx = 1 - pad_sequence_reverse(mask_idx_pre)
120 | length = mask_idx.shape[0]
121 |
122 | expand_mask_idx = mask_idx.unsqueeze(1).repeat(
123 | 1, length, 1
124 | ) # length, length, batch-size
125 | expand_mask_idx = ((expand_mask_idx + mask_idx) > 0).byte()
126 |
127 | # mask_idx = pad_sequence(mask_idx)
128 | sublabels = torch.LongTensor(sublabels)
129 |
130 | return subencoded, expand_mask_idx, sublabels
131 |
132 | def get_varlen_iter(self, start=0):
133 | sample_array = np.random.permutation(self.n_samples)
134 | for i in range(self.n_step):
135 | sub_index = sample_array[i * self.bsz : i * self.bsz + self.bsz]
136 | yield self.get_batch(sub_index)
137 |
138 | def get_fixlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
139 | # print(self.n_step)
140 | for i in range(self.cur_step, self.cur_step+(self.n_step//5)):
141 | # for i in range(self.n_step):
142 | sub_index = self.sequence_array[i * self.bsz : i * self.bsz + self.bsz]
143 | yield self.get_batch(sub_index)
144 | self.cur_step = i + 1
145 |
146 | def __iter__(self):
147 | return self.get_fixlen_iter()
148 |
149 |
150 | class LMShuffledIterator(object):
151 | def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
152 | """
153 | data -- list[LongTensor] -- there is no order among the LongTensors
154 | """
155 | self.data = data
156 |
157 | self.bsz = bsz
158 | self.bptt = bptt
159 | self.ext_len = ext_len if ext_len is not None else 0
160 |
161 | self.device = device
162 | self.shuffle = shuffle
163 |
164 | def get_sent_stream(self):
165 | # index iterator
166 | epoch_indices = (
167 | np.random.permutation(len(self.data))
168 | if self.shuffle
169 | else np.array(range(len(self.data)))
170 | )
171 |
172 | # sentence iterator
173 | for idx in epoch_indices:
174 | yield self.data[idx]
175 |
176 | def stream_iterator(self, sent_stream):
177 | # streams for each data in the batch
178 | streams = [None] * self.bsz
179 |
180 | data = torch.LongTensor(self.bptt, self.bsz)
181 | target = torch.LongTensor(self.bptt, self.bsz)
182 |
183 | n_retain = 0
184 |
185 | while True:
186 | # data : [n_retain+bptt x bsz]
187 | # target : [bptt x bsz]
188 | data[n_retain:].fill_(-1)
189 | target.fill_(-1)
190 |
191 | valid_batch = True
192 |
193 | for i in range(self.bsz):
194 | n_filled = 0
195 | try:
196 | while n_filled < self.bptt:
197 | if streams[i] is None or len(streams[i]) <= 1:
198 | streams[i] = next(sent_stream)
199 | # number of new tokens to fill in
200 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
201 | # first n_retain tokens are retained from last batch
202 | data[n_retain + n_filled : n_retain + n_filled + n_new, i] = (
203 | streams[i][:n_new]
204 | )
205 | target[n_filled : n_filled + n_new, i] = streams[i][
206 | 1 : n_new + 1
207 | ]
208 | streams[i] = streams[i][n_new:]
209 | n_filled += n_new
210 | except StopIteration:
211 | valid_batch = False
212 | break
213 |
214 | if not valid_batch:
215 | return
216 |
217 | data = data.to(self.device)
218 | target = target.to(self.device)
219 |
220 | yield data, target, self.bptt
221 |
222 | n_retain = min(data.size(0), self.ext_len)
223 | if n_retain > 0:
224 | data[:n_retain] = data[-n_retain:]
225 | data.resize_(n_retain + self.bptt, data.size(1))
226 |
227 | def __iter__(self):
228 | # sent_stream is an iterator
229 | sent_stream = self.get_sent_stream()
230 |
231 | for batch in self.stream_iterator(sent_stream):
232 | yield batch
233 |
234 |
235 | class LMMultiFileIterator(LMShuffledIterator):
236 | def __init__(
237 | self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False
238 | ):
239 |
240 | self.paths = paths
241 | self.vocab = vocab
242 |
243 | self.bsz = bsz
244 | self.bptt = bptt
245 | self.ext_len = ext_len if ext_len is not None else 0
246 |
247 | self.device = device
248 | self.shuffle = shuffle
249 |
250 | def get_sent_stream(self, path):
251 | sents = self.vocab.encode_file(path, add_double_eos=True)
252 | if self.shuffle:
253 | np.random.shuffle(sents)
254 | sent_stream = iter(sents)
255 |
256 | return sent_stream
257 |
258 | def __iter__(self):
259 | if self.shuffle:
260 | np.random.shuffle(self.paths)
261 |
262 | for path in self.paths:
263 | # sent_stream is an iterator
264 | sent_stream = self.get_sent_stream(path)
265 | for batch in self.stream_iterator(sent_stream):
266 | yield batch
267 |
268 |
269 | class Corpus(object):
270 | def __init__(self, path, dataset, *args, **kwargs):
271 | self.dataset = dataset
272 | self.vocab = Vocab(*args, **kwargs)
273 |
274 | if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
275 | self.vocab.count_file(os.path.join(path, "train.txt"))
276 | self.vocab.count_file(os.path.join(path, "valid.txt"))
277 | self.vocab.count_file(os.path.join(path, "test.txt"))
278 | elif self.dataset == "wt103":
279 | self.vocab.count_file(os.path.join(path, "train.txt"))
280 | elif self.dataset == "lm1b":
281 | train_path_pattern = os.path.join(
282 | path,
283 | "1-billion-word-language-modeling-benchmark-r13output",
284 | "training-monolingual.tokenized.shuffled",
285 | "news.en-*",
286 | )
287 | train_paths = glob.glob(train_path_pattern)
288 | # the vocab will load from file when build_vocab() is called
289 |
290 | elif self.dataset == "csqa":
291 | self.vocab.count_csqa(
292 | os.path.join(path, "train_rand_split.jsonl"), add_cls_token=True
293 | )
294 | self.vocab.count_csqa(
295 | os.path.join(path, "dev_rand_split.jsonl"), add_cls_token=True
296 | )
297 | self.vocab.count_csqa(
298 | os.path.join(path, "test_rand_split_no_answers.jsonl"),
299 | add_cls_token=True,
300 | )
301 |
302 | elif self.dataset in ["sst2", "imdb"]:
303 | self.vocab.count_sst2(os.path.join(path, "train.tsv"), add_cls_token=True)
304 | self.vocab.count_sst2(os.path.join(path, "dev.tsv"), add_cls_token=True)
305 |
306 | elif self.dataset == "sst5":
307 | dataset = pytreebank.load_sst(path)
308 | train = dataset['train']
309 | val = dataset['dev']
310 | test = dataset['test']
311 | # self.vocab.count_sst5(os.path.join(path, "train.tsv"), add_cls_token=True)
312 | # self.vocab.count_sst5(os.path.join(path, "dev.tsv"), add_cls_token=True)
313 | self.vocab.count_sst5(train, add_cls_token=True)
314 | self.vocab.count_sst5(val, add_cls_token=True)
315 |
316 | elif self.dataset == "banking77":
317 | self.vocab.count_banking77(
318 | os.path.join(path, "train.tsv"), add_cls_token=True
319 | )
320 | self.vocab.count_banking77(
321 | os.path.join(path, "dev.tsv"), add_cls_token=True
322 | )
323 |
324 | self.vocab.build_vocab()
325 |
326 | if self.dataset in ["ptb", "wt2", "wt103"]:
327 | self.train = self.vocab.encode_file(
328 | os.path.join(path, "train.txt"), ordered=True
329 | )
330 | self.valid = self.vocab.encode_file(
331 | os.path.join(path, "valid.txt"), ordered=True
332 | )
333 | self.test = self.vocab.encode_file(
334 | os.path.join(path, "test.txt"), ordered=True
335 | )
336 | elif self.dataset in ["enwik8", "text8"]:
337 | self.train = self.vocab.encode_file(
338 | os.path.join(path, "train.txt"), ordered=True, add_eos=False
339 | )
340 | self.valid = self.vocab.encode_file(
341 | os.path.join(path, "valid.txt"), ordered=True, add_eos=False
342 | )
343 | self.test = self.vocab.encode_file(
344 | os.path.join(path, "test.txt"), ordered=True, add_eos=False
345 | )
346 | elif self.dataset == "lm1b":
347 | self.train = train_paths
348 | self.valid = self.vocab.encode_file(
349 | os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True
350 | )
351 | self.test = self.vocab.encode_file(
352 | os.path.join(path, "test.txt"), ordered=False, add_double_eos=True
353 | )
354 | elif self.dataset == "csqa":
355 | self.train = self.vocab.encode_csqa_file(
356 | os.path.join(path, "train_rand_split.jsonl"),
357 | ordered=True,
358 | add_cls_token=True,
359 | )
360 | self.valid = self.vocab.encode_csqa_file(
361 | os.path.join(path, "dev_rand_split.jsonl"),
362 | ordered=True,
363 | add_cls_token=True,
364 | )
365 | elif self.dataset in ["sst2", "imdb"]:
366 | self.train = self.vocab.encode_sst2_file(
367 | os.path.join(path, "train.tsv"), add_cls_token=True
368 | )
369 | self.valid = self.vocab.encode_sst2_file(
370 | os.path.join(path, "dev.tsv"), add_cls_token=True
371 | )
372 |
373 | elif self.dataset == "sst5":
374 | self.train = self.vocab.encode_sst5_file(
375 | train, add_cls_token=True
376 | )
377 | self.valid = self.vocab.encode_sst5_file(
378 | val, add_cls_token=True
379 | )
380 | # self.test = self.vocab.encode_sst5_file(
381 | # test, add_cls_token=True
382 | # )
383 | elif self.dataset == "banking77":
384 | self.train = self.vocab.encode_banking77_file(
385 | os.path.join(path, "train.tsv"), add_cls_token=True
386 | )
387 | self.valid = self.vocab.encode_banking77_file(
388 | os.path.join(path, "dev.tsv"), add_cls_token=True
389 | )
390 |
391 | def get_iterator(self, split, *args, **kwargs):
392 |
393 | if split == "train":
394 | if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
395 | data_iter = LMOrderedIterator(self.train, *args, **kwargs)
396 | elif self.dataset == "lm1b":
397 | kwargs["shuffle"] = True
398 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
399 | elif self.dataset == "csqa":
400 | data_iter = CSQAIterator(self.train, *args, **kwargs)
401 | elif self.dataset in ["sst2", "imdb", "sst5", "banking77"]:
402 | data_iter = SST2Iterator(self.train, *args, **kwargs)
403 | # dataset = CSQADataset(self.train)
404 | # data_iter = DataLoader(dataset, *args, shuffle=True,
405 | # num_workers=4, drop_last=False, pin_memory=True)
406 |
407 | elif split in ["valid", "test"]:
408 | data = self.valid if split == "valid" else self.test
409 | if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
410 | data_iter = LMOrderedIterator(data, *args, **kwargs)
411 | elif self.dataset == "lm1b":
412 | data_iter = LMShuffledIterator(data, *args, **kwargs)
413 | elif self.dataset == "csqa":
414 | data_iter = CSQAIterator(self.valid, *args, **kwargs)
415 | elif self.dataset in ["sst2", "imdb", "sst5", "banking77"]:
416 | data_iter = SST2Iterator(self.valid, *args, **kwargs)
417 |
418 | # dataset = CSQADataset(self.valid)
419 | # data_iter = DataLoader(dataset, *args, shuffle=False,
420 | # num_workers=4, drop_last=False, pin_memory=True)
421 | return data_iter
422 |
423 |
424 | def get_lm_corpus(datadir, dataset):
425 | fn = os.path.join(datadir, "cache.pt")
426 | # print(fn)
427 | if os.path.exists(fn):
428 | print("Loading cached dataset...")
429 | corpus = torch.load(fn)
430 | else:
431 | print("Producing dataset {}...".format(dataset))
432 | kwargs = {}
433 | if dataset in ["wt103", "wt2"]:
434 |         kwargs["special"] = ["<eos>"]
435 | kwargs["lower_case"] = False
436 | elif dataset == "ptb":
437 |         kwargs["special"] = ["<unk>"]
438 | kwargs["lower_case"] = True
439 | elif dataset == "lm1b":
440 | kwargs["special"] = []
441 | kwargs["lower_case"] = False
442 | kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
443 | elif dataset in ["csqa", "sst2", "imdb", "sst5", "banking77"]:
444 | kwargs["special"] = [""]
445 | elif dataset in ["enwik8", "text8"]:
446 | pass
447 |
448 | corpus = Corpus(datadir, dataset, **kwargs)
449 | torch.save(corpus, fn)
450 |
451 | return corpus
452 |
453 |
454 | if __name__ == "__main__":
455 | import argparse
456 |
457 | parser = argparse.ArgumentParser(description="unit test")
458 | parser.add_argument(
459 | "--datadir", type=str, default="./enwik8", help="location of the data corpus"
460 | )
461 | parser.add_argument(
462 | "--dataset",
463 | type=str,
464 | default="enwik8",
465 | choices=["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"],
466 | help="dataset name",
467 | )
468 | args = parser.parse_args()
469 |
470 | corpus = get_lm_corpus(args.datadir, args.dataset)
471 | print("Vocab size : {}".format(len(corpus.vocab.idx2sym)))
472 |
--------------------------------------------------------------------------------
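Note: a minimal sketch of how the corpus and iterators above are typically wired together, assuming the usual transformer-xl style `LMOrderedIterator` signature `(data, bsz, bptt, device, ext_len)` and a local `./enwik8` directory containing train.txt/valid.txt/test.txt; the paths and batch sizes are illustrative only.

    corpus = get_lm_corpus("./enwik8", "enwik8")   # builds the corpus or loads cache.pt
    train_iter = corpus.get_iterator("train", bsz=32, bptt=512, device="cpu", ext_len=0)
    for data, target, seq_len in train_iter:       # each batch: input ids, target ids, segment length
        break

For the classification datasets (csqa, sst2, imdb, sst5, banking77) the same call returns the corresponding `CSQAIterator`/`SST2Iterator` instead.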
/custom_layers_opt.py:
--------------------------------------------------------------------------------
1 | r"""
2 | FMoE core layer
3 | """
4 |
5 | import tree
6 | import os
7 | import torch
8 | import torch.nn as nn
9 |
10 | from custom_functions import prepare_forward, ensure_comm
11 | from custom_functions import MOEScatter, MOEGather
12 | from custom_functions import AllGather, Slice
13 | from gates import NaiveGate
14 |
15 | from fastermoe.config import switch_from_env
16 | import random
17 | import torch.nn.functional as F
18 | from torchmetrics.regression import KLDivergence
19 |
20 |
21 | def kl_divergence(softmax_1, softmax_2):
22 | kl_divergence = KLDivergence(log_prob=False).cuda()
23 | return kl_divergence(softmax_1, softmax_2)
24 |
25 |
26 | def cal_mse_loss(input, target):
27 | mse_loss = nn.MSELoss()
28 | _loss = mse_loss(input, target)
29 | return _loss
30 |
31 |
32 | def mark_module_parallel_comm(module, comm):
33 | r"""
34 | Mark all parameters in `module` as doing data parallel in `comm`, where
35 | `comm` may be one of `'world', 'dp', 'none'`.
36 | """
37 | for p in module.parameters():
38 | setattr(p, "dp_comm", comm)
39 |
40 |
41 | def _fmoe_general_global_forward(
42 | inp, gate, expert_fn, num_expert, world_size, **kwargs
43 | ):
44 | r"""
45 | A private function that performs the following steps to complete the MoE
46 | computation.
47 | * Count the number of tokens from each worker to each expert.
48 | * Send the features to their target position so that input features to each
49 | expert are contiguous in memory.
50 | * Perform the forward computation of the experts using `expert_fn`
51 | * Gather the output features of experts back, and reorder them as sentences.
52 | Intermediate results like expert counts are hidden from users by this
53 | function.
54 | """
55 | (
56 | pos,
57 | local_expert_count,
58 | global_expert_count,
59 | fwd_expert_count,
60 | fwd_batch_size,
61 | ) = prepare_forward(gate, num_expert, world_size)
62 | topk = 1
63 | if len(gate.shape) == 2:
64 | topk = gate.shape[1]
65 |
66 | def scatter_func(tensor):
67 | return MOEScatter.apply(
68 | tensor,
69 | torch.div(pos, topk, rounding_mode="floor"),
70 | local_expert_count,
71 | global_expert_count,
72 | fwd_batch_size,
73 | world_size,
74 | )
75 |
76 | x = tree.map_structure(scatter_func, inp)
77 |
78 | x = expert_fn(x, fwd_expert_count)
79 |
80 | out_batch_size = tree.flatten(inp)[0].shape[0]
81 | if len(gate.shape) == 2:
82 | out_batch_size *= gate.shape[1]
83 |
84 | def gather_func(tensor):
85 | return MOEGather.apply(
86 | tensor,
87 | pos,
88 | local_expert_count,
89 | global_expert_count,
90 | out_batch_size,
91 | world_size,
92 | )
93 |
94 | outp = tree.map_structure(gather_func, x)
95 | return outp
96 |
97 |
98 | fmoe_faster_schedule = False
99 | if switch_from_env("FMOE_FASTER_SCHEDULE_ENABLE", False):
100 | fmoe_faster_schedule = True
101 |     from fastermoe.schedule import _fmoe_general_global_forward
102 |
103 |
104 | class FMoEOpt(nn.Module):
105 | r"""
106 |     A general MoE implementation that supports an arbitrary module as the
107 |     expert.
108 |     * `num_expert` stands for the number of experts on **each** worker.
109 |     * `world_size` stands for the total number of workers that contain
110 |     different experts.
111 |     * `slice_group` can be a torch communication group, indicating that
112 |     a specific model parallelism is applied across the group; workers in the
113 |     group hold the same copy of the input feature and require the same copy of
114 |     the output. For each worker, FMoE only computes the output of a certain
115 |     slice of the input batch, and will all-gather the outputs after
116 |     computation.
117 |     * `top_k` stands for the number of experts each token is routed to.
118 |     * `gate` is a gate class which can be found in `gates`.
119 |     * `expert` can be specified as a module class; it is used to generate
120 |     `num_expert` expert modules.
121 | """
122 |
123 | def __init__(
124 | self,
125 | num_expert=32,
126 | d_model=1024,
127 | world_size=1,
128 | mp_group=None, # being deprecated
129 | slice_group=None,
130 | moe_group=None,
131 | moe_top_k=2,
132 | gate=NaiveGate,
133 | expert=None,
134 | gate_hook=None,
135 | mask=None,
136 | mask_dict=None,
137 | freq=0.0,
138 | alpha=0.0,
139 | act_experts="shuffle",
140 | g_blance=False,
141 | opt_blance=False,
142 | combine_gate=False,
143 | opt_loss="mse",
144 | ):
145 | super().__init__()
146 | self.num_expert = num_expert
147 | self.d_model = d_model
148 | self.world_size = world_size
149 | self.freq = freq
150 | self.alpha = alpha
151 | self.act_experts = act_experts
152 | self.opt_blance = opt_blance
153 | self.combine_gate = combine_gate
154 | self.opt_loss = opt_loss
155 | self.slice_group = slice_group
156 | if mp_group is not None:
157 | print("[Warning] mp_group is being deprecated")
158 | self.slice_group = mp_group
159 | if self.slice_group is None:
160 | self.slice_size = 1
161 | self.slice_rank = 0
162 | else:
163 | self.slice_size = self.slice_group.size()
164 | self.slice_rank = self.slice_group.rank()
165 |
166 | self.top_k = moe_top_k
167 | if type(expert) is list:
168 | self.experts = nn.ModuleList([e(d_model) for e in expert])
169 | self.experts_fused = False
170 | self.num_expert = num_expert = len(expert)
171 | elif expert is not None:
172 | self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
173 | self.experts_fused = False
174 | else:
175 | self.experts_fused = True
176 |
177 | self.gate = gate(d_model, num_expert, world_size, moe_top_k, g_blance)
178 | self.gate_hook = gate_hook
179 | self.mask = mask
180 | self.mask_dict = mask_dict
181 | self.moe_group = moe_group
182 |
183 | def expert_fn(self, inp, fwd_expert_count):
184 | r"""
185 | The default expert function which either calls the experts as a whole
186 | or as separate experts.
187 | """
188 | if self.experts_fused:
189 | return self.experts(inp, fwd_expert_count)
190 | if isinstance(fwd_expert_count, torch.Tensor):
191 | fwd_expert_count = fwd_expert_count.cpu().numpy()
192 | outputs = []
193 | base_idx = 0
194 | for i in range(self.num_expert):
195 | batch_size = fwd_expert_count[i]
196 | inp_slice = inp[base_idx : base_idx + batch_size]
197 | outputs.append(self.experts[i](inp_slice))
198 | base_idx += batch_size
199 | return torch.cat(outputs, dim=0)
200 |
201 | def mark_parallel_comm(self, expert_dp_comm="none"):
202 | r"""
203 | Automatically mark the data parallel comms of the parameters within the
204 | module. This can be typically called at the end of the __init__ function
205 | in child classes.
206 | """
207 | if self.experts is not None:
208 | comm = expert_dp_comm
209 | if isinstance(self.experts, list):
210 | for e in self.experts:
211 | mark_module_parallel_comm(e, comm)
212 | else:
213 | mark_module_parallel_comm(self.experts, comm)
214 | mark_module_parallel_comm(self.gate, "gate")
215 |
216 | def cal_load_balance(self, gate, gate_top_k_idx):
217 |
218 | score = F.softmax(gate, dim=-1)
219 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1]
220 | fraction_expert = (
221 | torch.scatter_add(
222 | torch.zeros(self.num_expert, device=valid_idx.device),
223 | 0,
224 | valid_idx,
225 | torch.ones_like(valid_idx, dtype=torch.float),
226 | )
227 | / valid_idx.numel()
228 | )
229 | prob_expert = score.sum(dim=0) / valid_idx.numel()
230 |
231 | loss = (fraction_expert * prob_expert).sum() * self.num_expert
232 | return loss
233 |
234 | def forward(self, moe_inp):
235 | r"""
236 | The FMoE module first computes gate output, and then conduct MoE forward
237 | according to the gate. The score of the selected gate given by the
238 | expert is multiplied to the experts' output tensors as a weight.
239 | """
240 |
241 | moe_inp_batch_size = tree.flatten(
242 | tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
243 | )
244 | assert all(
245 | [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
246 | ), "MoE inputs must have the same batch size"
247 |
248 | if self.world_size > 1:
249 |
250 | def ensure_comm_func(tensor):
251 | ensure_comm(tensor, self.moe_group)
252 |
253 | tree.map_structure(ensure_comm_func, moe_inp)
254 | if self.slice_size > 1:
255 |
256 | def slice_func(tensor):
257 | return Slice.apply(
258 | tensor, self.slice_rank, self.slice_size, self.slice_group
259 | )
260 |
261 | moe_inp = tree.map_structure(slice_func, moe_inp)
262 | flip_ = random.random()
263 | gate_top_k_idx, gate_score, gate_ = self.gate(moe_inp, return_all_scores=True)
264 |
265 | if self.training:
266 | if flip_ > (1 - self.freq):
267 |                 # original gate scores over all experts
268 | gate_top_k_val_org, _ = torch.topk(
269 | gate_, k=self.num_expert, dim=-1, largest=True, sorted=False
270 | )
271 | gate_top_k_val_org = gate_top_k_val_org.view(
272 | -1, self.num_expert
273 |                 )  # (BxL) x num_expert
274 | gate_score_org = F.softmax(gate_top_k_val_org, dim=-1)
275 | # activate all experts with shuffle index
276 | if self.act_experts == "shuffle":
277 |                     # search for the best routing by activating every expert
278 | gate_dense = torch.ones_like(
279 | gate_
280 |                     )  # give every expert equal importance
281 | gate_top_k_val_opt, gate_top_k_idx_opt = torch.topk(
282 | gate_dense,
283 | k=self.num_expert,
284 | dim=-1,
285 | largest=True,
286 | sorted=False,
287 |                     )  # [.. x num_expert]
288 | gate_top_k_val_opt = gate_top_k_val_opt.view(
289 | -1, self.num_expert
290 |                     )  # (BxL) x num_expert
291 | gate_score_opt = F.softmax(gate_top_k_val_opt, dim=-1)
292 |
293 | if hasattr(self.gate, "dynamic_top_k"):
294 | self.top_k = self.gate.dynamic_top_k
295 |
296 | if self.gate_hook is not None:
297 | self.gate_hook(gate_top_k_idx_opt, gate_score_opt, None)
298 |
299 | # delete masked tensors
300 | if self.mask is not None and self.mask_dict is not None:
301 | # TODO: to fix
302 | def delete_mask_func(tensor):
303 | # to: (BxL') x d_model
304 | tensor = tensor[mask == 0, :]
305 | return tensor
306 |
307 | mask = self.mask.view(-1)
308 | moe_inp = tree.map_structure(delete_mask_func, moe_inp)
309 | gate_top_k_idx_opt = gate_top_k_idx_opt[mask == 0, :]
310 | bs = moe_inp.shape[0]
311 | fwd_tmp = _fmoe_general_global_forward(
312 | moe_inp,
313 | gate_top_k_idx_opt,
314 | self.expert_fn,
315 | self.num_expert,
316 | self.world_size,
317 | experts=self.experts,
318 | ).reshape(bs, self.num_expert, -1)
319 |                 # compute the norm of each expert's output per token
320 | fwd_norm = torch.norm(fwd_tmp, dim=2)
321 | # activate all experts without shuffle index
322 | else:
323 | # activate with grad
324 | if self.opt_blance:
325 | fwd_tmp = None
326 | for i in range(self.num_expert):
327 |
328 | temp_ = (
329 | moe_inp @ self.experts.htoh4.weight[i].T
330 | + self.experts.htoh4.bias[i]
331 | )
332 | temp_ = F.relu(temp_)
333 | temp_ = (
334 | temp_ @ self.experts.h4toh.weight[i].T
335 | + self.experts.h4toh.bias[i]
336 | )
337 | temp_ = torch.unsqueeze(temp_, -1)
338 | if fwd_tmp is None:
339 | fwd_tmp = temp_.clone()
340 | else:
341 | fwd_tmp = torch.concat([fwd_tmp, temp_], dim=-1)
342 | # activate without grad
343 | else:
344 | fwd_tmp = None
345 | for i in range(self.num_expert):
346 |
347 | temp_ = (
348 | moe_inp @ self.experts.htoh4.weight[i].T
349 | + self.experts.htoh4.bias[i]
350 | )
351 | temp_ = F.relu(temp_)
352 | temp_ = (
353 | temp_ @ self.experts.h4toh.weight[i].T
354 | + self.experts.h4toh.bias[i]
355 | )
356 | temp_ = torch.unsqueeze(temp_, -1)
357 | if fwd_tmp is None:
358 | fwd_tmp = temp_.clone()
359 | else:
360 | fwd_tmp = torch.concat([fwd_tmp, temp_], dim=-1)
361 |                 # compute the norm of each expert's output per token
362 | fwd_norm = torch.norm(fwd_tmp, dim=1)
363 | # ensemble with gate information
364 | if self.combine_gate:
365 | fwd_norm = fwd_norm * 0.5 + gate_ * 0.5
366 | gate_top_k_val_optim, gate_top_k_idx_optim = torch.topk(
367 | fwd_norm, k=self.top_k, dim=-1, largest=True, sorted=False
368 | )
369 |                 # optional load-balance loss on the norm-based routing
370 | if self.opt_blance:
371 | opt_bl_loss = self.cal_load_balance(fwd_norm, gate_top_k_idx_optim)
372 | # get output
373 | gate_top_k_val_optim = gate_top_k_val_optim.view(-1, self.top_k)
374 |                 # push non-top-k scores toward zero probability
375 | gate_score2 = torch.zeros(
376 | (gate_top_k_val_optim.shape[0], self.num_expert)
377 | ).cuda()
378 | gate_score2.fill_(-10e9)
379 | # fill with topk score value
380 | gate_score2 = gate_score2.scatter(
381 | 1, gate_top_k_idx_optim, gate_top_k_val_optim
382 | )
383 | gate_score_optimal = F.softmax(gate_score2, dim=1)
384 |                 # distillation loss between the gate's softmax and the norm-based target
385 | if self.opt_loss == "mse":
386 | add_loss = cal_mse_loss(
387 | gate_score_org,
388 | gate_score_optimal,
389 | )
390 | else:
391 | add_loss = kl_divergence(
392 | gate_score_org,
393 | gate_score_optimal,
394 | )
395 |                 # optionally fold the load-balance term into the auxiliary loss
396 |                 if self.opt_blance:
397 |                     add_loss += opt_bl_loss
398 |                     # store the scaled auxiliary loss on the gate
399 |                     self.gate.loss = add_loss * self.alpha
400 |                 else:
401 |                     self.gate.loss = add_loss * self.alpha
402 |                 # route with the norm-based top-k instead of the gate's choice
403 |                 gate_top_k_idx = gate_top_k_idx_optim
404 |                 gate_score = F.softmax(gate_top_k_val_optim, dim=1)
405 |
406 |
407 | if hasattr(self.gate, "dynamic_top_k"):
408 | self.top_k = self.gate.dynamic_top_k
409 |
410 | if self.gate_hook is not None:
411 | self.gate_hook(gate_top_k_idx, gate_score, None)
412 |
413 | # delete masked tensors
414 | if self.mask is not None and self.mask_dict is not None:
415 | # TODO: to fix
416 | def delete_mask_func(tensor):
417 | # to: (BxL') x d_model
418 | tensor = tensor[mask == 0, :]
419 | return tensor
420 |
421 | mask = self.mask.view(-1)
422 | moe_inp = tree.map_structure(delete_mask_func, moe_inp)
423 | gate_top_k_idx = gate_top_k_idx[mask == 0, :]
424 |
425 | fwd = _fmoe_general_global_forward(
426 | moe_inp,
427 | gate_top_k_idx,
428 | self.expert_fn,
429 | self.num_expert,
430 | self.world_size,
431 | experts=self.experts,
432 | )
433 |
434 | # recover deleted tensors
435 | if self.mask is not None and self.mask_dict is not None:
436 |
437 | def recover_func(tensor):
438 | # to: (BxL') x top_k x dim
439 | dim = tensor.shape[-1]
440 | tensor = tensor.view(-1, self.top_k, dim)
441 | # to: (BxL) x top_k x d_model
442 | x = torch.zeros(
443 | mask.shape[0],
444 | self.top_k,
445 | dim,
446 | device=tensor.device,
447 | dtype=tensor.dtype,
448 | )
449 | # recover
450 | x[mask == 0] = tensor
451 | for k, v in self.mask_dict.items():
452 | x[mask == k] = v
453 | return x
454 |
455 | moe_outp = tree.map_structure(recover_func, fwd)
456 | else:
457 |
458 | def view_func(tensor):
459 | dim = tensor.shape[-1]
460 | tensor = tensor.view(-1, self.top_k, dim)
461 | return tensor
462 |
463 | moe_outp = tree.map_structure(view_func, fwd)
464 |
465 | gate_score = gate_score.view(-1, 1, self.top_k)
466 |
467 | def bmm_func(tensor):
468 | dim = tensor.shape[-1]
469 | tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
470 | return tensor
471 |
472 | moe_outp = tree.map_structure(bmm_func, moe_outp)
473 |
474 | if self.slice_size > 1:
475 |
476 | def all_gather_func(tensor):
477 | return AllGather.apply(
478 | tensor, self.slice_rank, self.slice_size, self.slice_group
479 | )
480 |
481 | moe_outp = tree.map_structure(all_gather_func, moe_outp)
482 |
483 | moe_outp_batch_size = tree.flatten(
484 | tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
485 | )
486 | assert all(
487 | [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
488 | ), "MoE outputs must have the same batch size"
489 | return moe_outp
490 |
--------------------------------------------------------------------------------
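Note: a compact, self-contained sketch of the training-time routing optimization inside FMoEOpt.forward. With probability `freq`, every expert is run densely, the per-token norm of each expert's output defines a new top-k routing target, and an MSE or KL loss (scaled by `alpha`) pulls the gate's softmax toward that target before the tokens are routed with the norm-based indices. The helper below uses hypothetical names, omits the distributed scatter/gather and the balance term, and only illustrates the target selection and loss construction.

    import torch
    import torch.nn.functional as F

    def distill_routing_loss(gate_logits, expert_outputs, top_k, use_mse=True):
        # gate_logits: (tokens, num_expert); expert_outputs: (tokens, d_model, num_expert)
        fwd_norm = torch.norm(expert_outputs, dim=1)             # (tokens, num_expert)
        top_val, top_idx = torch.topk(fwd_norm, k=top_k, dim=-1)
        # keep the top-k norms, push every other expert toward zero probability
        target = torch.full_like(fwd_norm, -10e9)
        target = target.scatter(1, top_idx, top_val)
        target = F.softmax(target, dim=1)
        gate_prob = F.softmax(gate_logits, dim=-1)               # what the gate currently predicts
        if use_mse:
            loss = F.mse_loss(gate_prob, target)
        else:
            # simple functional KL instead of the torchmetrics wrapper used in the file
            loss = F.kl_div(gate_prob.log(), target, reduction="batchmean")
        return loss, top_idx

    # toy shapes: 8 tokens, d_model 16, 4 experts, top-2 routing
    gate_logits = torch.randn(8, 4)
    expert_outputs = torch.randn(8, 16, 4)
    loss, routing = distill_routing_loss(gate_logits, expert_outputs, top_k=2)

In the module itself, `top_idx` plays the role of `gate_top_k_idx_optim`, which then replaces `gate_top_k_idx` for the subsequent `_fmoe_general_global_forward` call, so the distillation loss and the routing change act together.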