├── .gitignore
├── CODE_OF_CONDUCT.md
├── HISTORY.rst
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── Transparency_FAQ.md
├── example
│   ├── README.md
│   ├── eval_moe.sh
│   ├── moe
│   │   ├── __init__.py
│   │   ├── transformer_switch.py
│   │   ├── transformer_switch_layer.py
│   │   └── translation_switch.py
│   ├── requirements.txt
│   └── run_moe.sh
├── setup.py
└── sparsemixer
    ├── __init__.py
    ├── sparsemixer.py
    ├── sparsemixer_v2.py
    └── switchgate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output-*
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/HISTORY.rst:
--------------------------------------------------------------------------------
1 | History
2 | =======
3 |
4 | 0.0.0 (2023/10/5)
5 | ------------------
6 | * empty placeholder
7 |
8 | 0.1.0 (2023/10/5)
9 | ------------------
10 | * implemented SparseMixer
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 |
 4 | # SparseMixer
5 | Sparse Backpropagation for Mixture-of-Expert Training
6 |
7 |
8 | Mixture-of-Expert •
9 | SparseMixer •
10 | How to Use? •
11 | Examples •
12 | Citation •
13 | License
14 |
15 |
16 | [SparseMixer](https://arxiv.org/abs/2310.00811), a scalable gradient estimator, bridges the gap between backpropagation and sparse expert routing.
17 |
18 | ### What is Mixture-of-Expert
19 | The significant success of large-scale pre-training across various applications has underscored the imperative need for scalable models that are economically feasible.
20 | Recent advances in sparsely activated networks, prominently known as Mixture-of-Experts (MoE), have attracted widespread interest.
21 | Unlike traditional networks that densely activate all modules for every input, MoE selectively activates subsets of its modules for specific inputs through a process called *expert routing*, leading to notable efficiency enhancements.
22 |
23 | Numerous methods have emerged to bridge discrete variables and backpropagation, and most of them are based on Straight-Through (ST).
24 | Unfortunately, all existing ST estimators are incompatible with MoE, since they require activating all experts for gradient computation, thereby eliminating all the efficiency improvements of MoE.
25 | Consequently, typical MoE training strategically neglects the gradient computation for routing, trading certain training signals for sparse computation.
26 | Despite the scalability brought by sparse computation, this trade-off may result in slow convergence and improperly trained models.
27 |
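For readers new to MoE, the toy sketch below (ours, not part of this repository) illustrates top-1 expert routing: a gating network scores the experts for each token, only the selected expert is executed, and its output is scaled by the gate probability.

```python
import torch
import torch.nn as nn

class ToyTop1MoE(nn.Module):
    """Toy top-1 Mixture-of-Experts FFN, for illustration only."""
    def __init__(self, dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)  # gating network
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x):                         # x: (num_tokens, dim)
        probs = self.gate(x).softmax(dim=-1)      # (num_tokens, num_experts)
        expert_idx = probs.argmax(dim=-1)         # top-1 expert per token
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():                        # only routed tokens reach expert i
                out[mask] = expert(x[mask]) * probs[mask, i:i + 1]
        return out

moe = ToyTop1MoE(dim=16, num_experts=4)
print(moe(torch.randn(8, 16)).shape)              # torch.Size([8, 16])
```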
28 | ### Backpropagation Made Sparse
29 |
30 | We propose [SparseMixer](https://arxiv.org/abs/2310.00811), a scalable gradient estimator that bridges the gap between backpropagation and sparse expert routing.
31 | Grounded in a numerical ODE framework, SparseMixer harnesses the mid-point method, a second-order ODE solver, to deliver precise gradient approximations with negligible computational overhead.
32 | Applied to Switch Transformer on both pre-training and machine translation tasks, SparseMixer delivers considerable performance gains, accelerating training convergence by up to two times.
33 |
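Roughly speaking (our simplification; see the papers below for the precise treatment), Straight-Through-style estimators behave like a first-order Euler step, whereas the mid-point method spends one extra function evaluation to gain an order of accuracy:

$$
\text{Euler: } y_{n+1} = y_n + h\,f(y_n), \qquad
\text{mid-point: } y_{n+1} = y_n + h\,f\!\left(y_n + \tfrac{h}{2}\,f(y_n)\right).
$$

The local error drops from $\mathcal{O}(h^2)$ to $\mathcal{O}(h^3)$, which is the sense in which SparseMixer's routing-gradient approximation is second-order accurate.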
34 | ### How to use?
35 |
36 | `sparsemixer` can be installed via `pip`
37 | ```
38 | pip install sparsemixer
39 | ```
40 |
41 | ### Examples
42 |
43 | Please check the `example` folder for a working example.
44 |
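As a quick orientation, here is a minimal sketch of the router interface, assuming the call pattern used in `example/moe/transformer_switch_layer.py` (the gating layer and shapes below are illustrative, not part of the package):

```python
import torch
import torch.nn as nn
from sparsemixer import get_router

num_experts, embed_dim = 4, 512

# get_router looks the router class up by name; 'SparseMixer' and 'SwitchGate' are supported.
router = get_router('SparseMixer')(num_experts, embed_dim, False, jitter_eps=0.1)

gate = nn.Linear(embed_dim, num_experts)   # gating network (illustrative)
tokens = torch.randn(10, embed_dim)        # 10 flattened token representations
logits = gate(tokens)

# sample: the expert chosen for each token; multiplier: the scaling applied to the
# chosen expert's output so that routing still receives a gradient; balance_loss:
# load-balancing term (0.0 here, since the balance-loss flag is disabled above).
sample, multiplier, balance_loss = router(logits)
```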
45 | ### Citation
46 | Please cite the following papers if you find our model useful. Thanks!
47 |
48 |
49 | >Liyuan Liu, Jianfeng Gao, and Weizhu Chen (2023). Sparse Backpropagation for MoE Training. *ArXiv, abs/2310.00811*.
50 | ```
51 | @inproceedings{liu2023sparse,
52 | title={Sparse Backpropagation for MoE Training},
53 | author = {Liu, Liyuan and Gao, Jianfeng and Chen, Weizhu},
54 | booktitle = {arXiv:2310.00811 [cs]},
55 | year={2023}
56 | }
57 | ```
58 |
59 | >Liyuan Liu, Chengyu Dong, Xiaodong Liu, Bin Yu, and Jianfeng Gao (2023). Bridging Discrete and Backpropagation: Straight-Through and Beyond. *ArXiv, abs/2304.08612*.
60 | ```
61 | @inproceedings{liu2023bridging,
62 | title={Bridging Discrete and Backpropagation: Straight-Through and Beyond},
63 | author = {Liu, Liyuan and Dong, Chengyu and Liu, Xiaodong and Yu, Bin and Gao, Jianfeng},
64 | booktitle = {arXiv:2304.08612 [cs]},
65 | year={2023}
66 | }
67 | ```
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | ## How to file issues and get help
4 |
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
7 | feature request as a new Issue.
8 |
9 | For help and questions about using this project, please contact [Liyuan Liu](https://liyuanlucasliu.github.io/).
10 |
11 | ## Microsoft Support Policy
12 |
13 | Support for this project is limited to the resources listed above.
14 |
--------------------------------------------------------------------------------
/Transparency_FAQ.md:
--------------------------------------------------------------------------------
1 | # SparseMixer: Responsible AI Frequently Asked Questions
2 |
3 | ## What is SparseMixer?
4 |
5 | SparseMixer is a scalable gradient estimator for Mixture-of-Expert models.
6 |
7 | ## What can SparseMixer do?
8 |
 9 | SparseMixer provides reliable gradient approximations for expert routing while keeping the network sparsely activated.
10 |
11 | ## What is SparseMixer’s intended use?
12 |
13 | SparseMixer aims to facilitate the training of Mixture-of-Expert models.
14 |
15 | ## How was SparseMixer evaluated? What metrics are used to measure performance?
16 |
17 | We conduct experiments applying SparseMixer to neural machine translation and ELECTRA pre-training.
18 | SparseMixer consistently outperforms the baseline methods in all 8 settings.
19 | More details are elaborated in our paper.
20 |
21 |
22 | ## What are the limitations of SparseMixer? How can users minimize the impact of SparseMixer’s limitations when using the system?
23 |
24 | SparseMixer's gradient computation has only first- or second-order accuracy, i.e., it approximates the true gradient rather than computing it exactly. It excels because it facilitates a bias-variance tradeoff in gradient estimation.
25 |
26 | ## How to use SparseMixer?
27 |
28 | `sparsemixer` can be installed via `pip`
29 | ```
30 | pip install sparsemixer
31 | ```
32 | Also, please check the `example` folder for a working example.
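For instance (a minimal sketch based on `sparsemixer/__init__.py`), the router class is looked up by name, case-insensitively:
```
from sparsemixer import get_router

RouterCls = get_router('sparsemixer')   # returns the SparseMixer class; 'SwitchGate' is also supported
router = RouterCls(num_experts=4, embed_dim=512)
```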
--------------------------------------------------------------------------------
/example/README.md:
--------------------------------------------------------------------------------
1 | ## Neural Machine Translation
2 |
 3 | Please note that this working example is built upon the [THOR](https://github.com/microsoft/Stochastic-Mixture-of-Experts/tree/main) repo.
4 |
5 | ### Pre-processing
6 |
7 | Please refer to the [Transformer-Clinic](https://github.com/LiyuanLucasLiu/Transformer-Clinic/blob/master/pre-process/wmt14en-de.sh) repo for data preparation.
8 |
9 |
10 | ### Environment
11 |
12 | We recommend using the Docker image `nvcr.io/nvidia/pytorch:22.02-py3` for this example.
13 |
14 | ### Training and Evaluation
15 |
16 | ```
17 | # for model training, the resulting model will be saved to `output-${NUM_OF_EXPERTS}`.
18 | bash run_moe.sh ${NUM_OF_EXPERTS} ${ROUTER} ${PATH_TO_DATA}
19 |
20 | # for model inference, the script will load model weights from `output-${NUM_OF_EXPERTS}/checkpoint_best.pt`.
21 | bash eval_moe.sh output-${NUM_OF_EXPERTS}/checkpoint_best.pt ${GPU_ID} ${PATH_TO_DATA}
22 | ```
23 |
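For instance, a 4-expert run with the SparseMixer router might look as follows (the data path is illustrative; point it at your own preprocessed WMT'14 En-De folder):

```
# train with 4 experts and the SparseMixer router
bash run_moe.sh 4 SparseMixer /path/to/wmt14_en_de_joined_dict

# evaluate the best checkpoint on GPU 0
bash eval_moe.sh output-4/checkpoint_best.pt 0 /path/to/wmt14_en_de_joined_dict
```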
--------------------------------------------------------------------------------
/example/eval_moe.sh:
--------------------------------------------------------------------------------
1 | MODELDIR=${1}
2 | DEVICE=${2}
3 | DATA_FOLDER=${3}
4 |
5 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
6 |
7 | CUDA_VISIBLE_DEVICES=${DEVICE} fairseq-generate $DATA_FOLDER \
8 | --path $MODELDIR \
9 | --batch-size 128 --beam 4 --lenpen 0.6 --remove-bpe \
10 | --quiet --fp16 --user-dir moe
11 |
--------------------------------------------------------------------------------
/example/moe/__init__.py:
--------------------------------------------------------------------------------
1 | from . import translation_switch
2 | from . import transformer_switch
3 |
--------------------------------------------------------------------------------
/example/moe/transformer_switch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import Tensor
4 | from typing import Dict, List, Optional
5 |
6 | from .transformer_switch_layer import SwitchTransformerDecoderLayer, SwitchTransformerEncoderLayer
7 |
8 | from fairseq.distributed import fsdp_wrap
9 | from fairseq.models import register_model, register_model_architecture
10 | from fairseq.models.transformer import (
11 | base_architecture,
12 | transformer_iwslt_de_en,
13 | transformer_vaswani_wmt_en_de_big,
14 | transformer_vaswani_wmt_en_fr_big,
15 | transformer_wmt_en_de_big,
16 | TransformerDecoder,
17 | TransformerEncoder,
18 | TransformerModel
19 | )
20 | from fairseq.modules.checkpoint_activations import checkpoint_wrapper
21 |
22 |
23 | DEFAULT_MAX_SOURCE_POSITIONS = 1024
24 | DEFAULT_MAX_TARGET_POSITIONS = 1024
25 |
26 | DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
27 |
28 |
29 | @register_model("switch_transformer")
30 | class SwitchTransformerModel(TransformerModel):
31 | def __init__(self, args, encoder, decoder):
32 | super().__init__(args, encoder, decoder)
33 | self.args = args
34 |
35 | @classmethod
36 | def build_model(cls, args, task):
37 | """Build a new model instance."""
38 |
39 | # make sure all arguments are present in older models
40 | base_architecture_switch(args)
41 |
42 | if args.encoder_layers_to_keep:
43 | args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
44 | if args.decoder_layers_to_keep:
45 | args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
46 |
47 | if getattr(args, "max_source_positions", None) is None:
48 | args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
49 | if getattr(args, "max_target_positions", None) is None:
50 | args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
51 |
52 | src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
53 |
54 | if args.share_all_embeddings:
55 | if src_dict != tgt_dict:
56 | raise ValueError("--share-all-embeddings requires a joined dictionary")
57 | if args.encoder_embed_dim != args.decoder_embed_dim:
58 | raise ValueError(
59 | "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
60 | )
61 | if args.decoder_embed_path and (
62 | args.decoder_embed_path != args.encoder_embed_path
63 | ):
64 | raise ValueError(
65 | "--share-all-embeddings not compatible with --decoder-embed-path"
66 | )
67 | encoder_embed_tokens = cls.build_embedding(
68 | args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
69 | )
70 | decoder_embed_tokens = encoder_embed_tokens
71 | args.share_decoder_input_output_embed = True
72 | else:
73 | encoder_embed_tokens = cls.build_embedding(
74 | args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
75 | )
76 | decoder_embed_tokens = cls.build_embedding(
77 | args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
78 | )
79 | if getattr(args, "offload_activations", False):
80 | args.checkpoint_activations = True # offloading implies checkpointing
81 |
82 | encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
83 | decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
84 |
85 | if not args.share_all_embeddings:
86 | min_params_to_wrap = getattr(
87 | args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
88 | )
89 | # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
90 | encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
91 | decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
92 | return cls(args, encoder, decoder)
93 |
94 | @classmethod
95 | def build_encoder(cls, args, src_dict, embed_tokens):
96 | return SwitchTransformerEncoder(args, src_dict, embed_tokens)
97 |
98 | @classmethod
99 | def build_decoder(cls, args, tgt_dict, embed_tokens):
100 | return SwitchTransformerDecoder(
101 | args,
102 | tgt_dict,
103 | embed_tokens,
104 | no_encoder_attn=getattr(args, "no_cross_attention", False),
105 | )
106 |
107 |
108 | class SwitchTransformerEncoder(TransformerEncoder):
109 | """
110 | Transformer encoder consisting of *args.encoder_layers* layers. Each layer
111 | is a :class:`SwitchTransformerEncoderLayer`.
112 |
113 | Args:
114 | args (argparse.Namespace): parsed command-line arguments
115 | dictionary (~fairseq.data.Dictionary): encoding dictionary
116 | embed_tokens (torch.nn.Embedding): input embedding
117 | """
118 |
119 | def __init__(self, args, dictionary, embed_tokens):
120 | self.layer_idx = 0 # for building encoder layers
121 | super().__init__(args, dictionary, embed_tokens)
122 | assert self.layer_idx == args.encoder_layers
123 |
124 | def build_encoder_layer(self, args):
125 | layer = SwitchTransformerEncoderLayer(args, layer_idx=self.layer_idx)
126 | checkpoint = getattr(args, "checkpoint_activations", False)
127 | if checkpoint:
128 | offload_to_cpu = getattr(args, "offload_activations", False)
129 | layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
130 | # if we are checkpointing, enforce that FSDP always wraps the
131 | # checkpointed layer, regardless of layer size
132 | min_params_to_wrap = (
133 | getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
134 | if not checkpoint
135 | else 0
136 | )
137 | layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
138 | self.layer_idx += 1
139 | return layer
140 |
141 | def forward(
142 | self,
143 | src_tokens,
144 | src_lengths: Optional[torch.Tensor] = None,
145 | return_all_hiddens: bool = False,
146 | token_embeddings: Optional[torch.Tensor] = None,
147 | ):
148 | """
149 | Args:
150 | src_tokens (LongTensor): tokens in the source language of shape
151 | `(batch, src_len)`
152 | src_lengths (torch.LongTensor): lengths of each source sentence of
153 | shape `(batch)`
154 | return_all_hiddens (bool, optional): also return all of the
155 | intermediate hidden states (default: False).
156 | token_embeddings (torch.Tensor, optional): precomputed embeddings
157 | default `None` will recompute embeddings
158 |
159 | Returns:
160 | dict:
161 | - **encoder_out** (Tensor): the last encoder layer's output of
162 | shape `(src_len, batch, embed_dim)`
163 | - **encoder_padding_mask** (ByteTensor): the positions of
164 | padding elements of shape `(batch, src_len)`
165 | - **encoder_embedding** (Tensor): the (scaled) embedding lookup
166 | of shape `(batch, src_len, embed_dim)`
167 | - **encoder_states** (List[Tensor]): all intermediate
168 | hidden states of shape `(src_len, batch, embed_dim)`.
169 | Only populated if *return_all_hiddens* is True.
170 | """
171 | return self.forward_scriptable(
172 | src_tokens, src_lengths, return_all_hiddens, token_embeddings
173 | )
174 |
175 | # TorchScript doesn't support super() method so that the scriptable Subclass
176 | # can't access the base class model in Torchscript.
177 | # Current workaround is to add a helper function with different name and
178 | # call the helper function from scriptable Subclass.
179 | def forward_scriptable(
180 | self,
181 | src_tokens,
182 | src_lengths: Optional[torch.Tensor] = None,
183 | return_all_hiddens: bool = False,
184 | token_embeddings: Optional[torch.Tensor] = None,
185 | ):
186 | """
187 | Args:
188 | src_tokens (LongTensor): tokens in the source language of shape
189 | `(batch, src_len)`
190 | src_lengths (torch.LongTensor): lengths of each source sentence of
191 | shape `(batch)`
192 | return_all_hiddens (bool, optional): also return all of the
193 | intermediate hidden states (default: False).
194 | token_embeddings (torch.Tensor, optional): precomputed embeddings
195 | default `None` will recompute embeddings
196 |
197 | Returns:
198 | dict:
199 | - **encoder_out** (Tensor): the last encoder layer's output of
200 | shape `(src_len, batch, embed_dim)`
201 | - **encoder_padding_mask** (ByteTensor): the positions of
202 | padding elements of shape `(batch, src_len)`
203 | - **encoder_embedding** (Tensor): the (scaled) embedding lookup
204 | of shape `(batch, src_len, embed_dim)`
205 | - **encoder_states** (List[Tensor]): all intermediate
206 | hidden states of shape `(src_len, batch, embed_dim)`.
207 | Only populated if *return_all_hiddens* is True.
208 | """
209 | # compute padding mask
210 | encoder_padding_mask = src_tokens.eq(self.padding_idx)
211 | has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
212 |
213 | x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
214 |
215 | # account for padding while computing the representation
216 | if has_pads:
217 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
218 |
219 | # B x T x C -> T x B x C
220 | x = x.transpose(0, 1)
221 |
222 | encoder_states = []
223 |
224 | if return_all_hiddens:
225 | encoder_states.append(x)
226 |
227 | # encoder layers
228 | load = []
229 | balance_loss = 0
230 | for layer in self.layers:
231 | x, layer_load, layer_balance = layer(
232 | x,
233 | encoder_padding_mask=encoder_padding_mask if has_pads else None,
234 | )
235 | if return_all_hiddens:
236 | assert encoder_states is not None
237 | encoder_states.append(x)
238 | if layer_load is not None:
239 | load.append(layer_load)
240 | if layer_balance is not None:
241 | balance_loss = balance_loss + layer_balance
242 |
243 | if len(load) > 0:
244 | load = torch.vstack(load)
245 | load = load / load.sum(1, keepdim=True)
246 | else:
247 | load = None
248 | balance_loss = None
249 |
250 | if self.layer_norm is not None:
251 | x = self.layer_norm(x)
252 |
253 | # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
254 | # `forward` so we use a dictionary instead.
255 | # TorchScript does not support mixed values so the values are all lists.
256 | # The empty list is equivalent to None.
257 | return {
258 | "encoder_out": [x], # T x B x C
259 | "encoder_padding_mask": [encoder_padding_mask], # B x T
260 | "encoder_embedding": [encoder_embedding], # B x T x C
261 | "encoder_states": encoder_states, # List[T x B x C]
262 | "src_tokens": [],
263 | "src_lengths": [],
264 | "load": load,
265 | "balance_loss": balance_loss
266 | }
267 |
268 |
269 | class SwitchTransformerDecoder(TransformerDecoder):
270 | """
271 | Transformer decoder consisting of *args.decoder_layers* layers. Each layer
272 | is a :class:`SwitchTransformerDecoderLayer`.
273 |
274 | Args:
275 | args (argparse.Namespace): parsed command-line arguments
276 | dictionary (~fairseq.data.Dictionary): decoding dictionary
277 | embed_tokens (torch.nn.Embedding): output embedding
278 | no_encoder_attn (bool, optional): whether to attend to encoder outputs
279 | (default: False).
280 | """
281 |
282 | def __init__(
283 | self,
284 | args,
285 | dictionary,
286 | embed_tokens,
287 | no_encoder_attn=False,
288 | output_projection=None,
289 | ):
290 | self.layer_idx = 0 # for building decoder layers
291 | super().__init__(args, dictionary, embed_tokens, no_encoder_attn, output_projection)
292 | assert self.layer_idx == args.decoder_layers
293 |
294 | def build_decoder_layer(self, args, no_encoder_attn=False):
295 | layer = SwitchTransformerDecoderLayer(args, no_encoder_attn, layer_idx=self.layer_idx)
296 | checkpoint = getattr(args, "checkpoint_activations", False)
297 | if checkpoint:
298 | offload_to_cpu = getattr(args, "offload_activations", False)
299 | layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
300 | # if we are checkpointing, enforce that FSDP always wraps the
301 | # checkpointed layer, regardless of layer size
302 | min_params_to_wrap = (
303 | getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
304 | if not checkpoint
305 | else 0
306 | )
307 | layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
308 | self.layer_idx += 1
309 | return layer
310 |
311 | def extract_features_scriptable(
312 | self,
313 | prev_output_tokens,
314 | encoder_out: Optional[Dict[str, List[Tensor]]],
315 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
316 | full_context_alignment: bool = False,
317 | alignment_layer: Optional[int] = None,
318 | alignment_heads: Optional[int] = None,
319 | ):
320 | """
321 | Similar to *forward* but only return features.
322 |
323 | Includes several features from "Jointly Learning to Align and
324 | Translate with Transformer Models" (Garg et al., EMNLP 2019).
325 |
326 | Args:
327 | full_context_alignment (bool, optional): don't apply
328 | auto-regressive mask to self-attention (default: False).
329 | alignment_layer (int, optional): return mean alignment over
330 | heads at this layer (default: last layer).
331 | alignment_heads (int, optional): only average alignment over
332 | this many heads (default: all heads).
333 |
334 | Returns:
335 | tuple:
336 | - the decoder's features of shape `(batch, tgt_len, embed_dim)`
337 | - a dictionary with any model-specific outputs
338 | """
339 | bs, slen = prev_output_tokens.size()
340 | if alignment_layer is None:
341 | alignment_layer = self.num_layers - 1
342 |
343 | enc: Optional[Tensor] = None
344 | padding_mask: Optional[Tensor] = None
345 | if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
346 | enc = encoder_out["encoder_out"][0]
347 | assert (
348 | enc.size()[1] == bs
349 | ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
350 | if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
351 | padding_mask = encoder_out["encoder_padding_mask"][0]
352 |
353 | # embed positions
354 | positions = None
355 | if self.embed_positions is not None:
356 | positions = self.embed_positions(
357 | prev_output_tokens, incremental_state=incremental_state
358 | )
359 |
360 | if incremental_state is not None:
361 | prev_output_tokens = prev_output_tokens[:, -1:]
362 | if positions is not None:
363 | positions = positions[:, -1:]
364 |
365 | # embed tokens and positions
366 | x = self.embed_scale * self.embed_tokens(prev_output_tokens)
367 |
368 | if self.quant_noise is not None:
369 | x = self.quant_noise(x)
370 |
371 | if self.project_in_dim is not None:
372 | x = self.project_in_dim(x)
373 |
374 | if positions is not None:
375 | x += positions
376 |
377 | if self.layernorm_embedding is not None:
378 | x = self.layernorm_embedding(x)
379 |
380 | x = self.dropout_module(x)
381 |
382 | # B x T x C -> T x B x C
383 | x = x.transpose(0, 1)
384 |
385 | self_attn_padding_mask: Optional[Tensor] = None
386 | if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
387 | self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
388 |
389 | # decoder layers
390 | attn: Optional[Tensor] = None
391 | inner_states: List[Optional[Tensor]] = [x]
392 | load = []
393 | balance_loss = 0
394 | for idx, layer in enumerate(self.layers):
395 | if incremental_state is None and not full_context_alignment:
396 | self_attn_mask = self.buffered_future_mask(x)
397 | else:
398 | self_attn_mask = None
399 |
400 | x, layer_attn, _, layer_load, layer_balance = layer(
401 | x,
402 | enc,
403 | padding_mask,
404 | incremental_state,
405 | self_attn_mask=self_attn_mask,
406 | self_attn_padding_mask=self_attn_padding_mask,
407 | need_attn=bool((idx == alignment_layer)),
408 | need_head_weights=bool((idx == alignment_layer)),
409 | )
410 | inner_states.append(x)
411 | if layer_load is not None:
412 | load.append(layer_load)
413 | if layer_balance is not None:
414 | balance_loss = balance_loss + layer_balance
415 | if layer_attn is not None and idx == alignment_layer:
416 | attn = layer_attn.float().to(x)
417 |
418 | load = torch.vstack(load)
419 | load = load / load.sum(1, keepdim=True)
420 |
421 | if attn is not None:
422 | if alignment_heads is not None:
423 | attn = attn[:alignment_heads]
424 |
425 | # average probabilities over heads
426 | attn = attn.mean(dim=0)
427 |
428 | if self.layer_norm is not None:
429 | x = self.layer_norm(x)
430 |
431 | # T x B x C -> B x T x C
432 | x = x.transpose(0, 1)
433 |
434 | if self.project_out_dim is not None:
435 | x = self.project_out_dim(x)
436 |
437 | return x, {"attn": [attn], "inner_states": inner_states, "load": load, "balance_loss": balance_loss}
438 |
439 |
440 | @register_model_architecture('switch_transformer', 'switch_transformer')
441 | def base_architecture_switch(args):
442 | base_architecture(args)
443 |
444 |
445 | @register_model_architecture("switch_transformer", "switch_transformer_iwslt_de_en")
446 | def switch_transformer_iwslt_de_en(args):
447 | transformer_iwslt_de_en(args)
448 |
449 |
450 | @register_model_architecture('switch_transformer', 'switch_transformer_wmt_en_de_big')
451 | def switch_transformer_wmt_en_de_big(args):
452 | transformer_wmt_en_de_big(args)
453 |
454 |
455 | @register_model_architecture('switch_transformer', 'switch_transformer_vaswani_wmt_en_de_big')
456 | def switch_transformer_vaswani_wmt_en_de_big(args):
457 | transformer_vaswani_wmt_en_de_big(args)
458 |
459 |
460 | @register_model_architecture("switch_transformer", "switch_transformer_vaswani_wmt_en_fr_big")
461 | def switch_transformer_vaswani_wmt_en_fr_big(args):
462 | transformer_vaswani_wmt_en_fr_big(args)
463 |
--------------------------------------------------------------------------------
/example/moe/transformer_switch_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch import Tensor
6 | from typing import Dict, List, Optional
7 |
8 | from sparsemixer import get_router
9 |
10 | from fairseq.modules.quant_noise import quant_noise
11 | from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
12 |
13 |
14 | def use_switch(layer_idx):
15 | return layer_idx % 2 == 0
16 |
17 |
18 | class SwitchTransformerEncoderLayer(TransformerEncoderLayer):
19 | """Encoder layer block.
20 |
21 | In the original paper each operation (multi-head attention or FFN) is
22 | postprocessed with: `dropout -> add residual -> layernorm`. In the
23 | tensor2tensor code they suggest that learning is more robust when
24 | preprocessing each layer with layernorm and postprocessing with:
25 | `dropout -> add residual`. We default to the approach in the paper, but the
26 | tensor2tensor approach can be enabled by setting
27 | *args.encoder_normalize_before* to ``True``.
28 |
29 | Args:
30 | args (argparse.Namespace): parsed command-line arguments
31 | """
32 |
33 | def __init__(self, args, layer_idx=-1):
34 | self.num_experts = args.num_experts
35 | self.load_balancing = args.load_balancing
36 | self.use_switch = use_switch(layer_idx)
37 | super().__init__(args)
38 | self.gating_network = nn.Linear(args.encoder_embed_dim, args.num_experts)
39 | if self.use_switch:
40 | self.router = get_router(args.router)(args.num_experts, args.encoder_embed_dim, args.load_balancing, args.jitter_eps)
41 | else:
42 | self.router = None
43 |
44 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
45 | if self.use_switch:
46 | return nn.ModuleList(
47 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
48 | for _ in range(self.num_experts)]
49 | )
50 | else:
51 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
52 |
53 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
54 | if self.use_switch:
55 | return nn.ModuleList(
56 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
57 | for _ in range(self.num_experts)]
58 | )
59 | else:
60 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
61 |
62 | def forward(
63 | self,
64 | x,
65 | encoder_padding_mask: Optional[Tensor],
66 | attn_mask: Optional[Tensor] = None,
67 | ):
68 | """
69 | Args:
70 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
71 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape
72 | `(batch, seq_len)` where padding elements are indicated by ``1``.
73 | attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
74 | where `tgt_len` is the length of output and `src_len` is the
75 | length of input, though here both are equal to `seq_len`.
76 | `attn_mask[tgt_i, src_j] = 1` means that when calculating the
77 | embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
78 | useful for strided self-attention.
79 |
80 | Returns:
81 | encoded output of shape `(seq_len, batch, embed_dim)`
82 | """
83 | # anything in original attn_mask = 1, becomes -1e8
84 | # anything in original attn_mask = 0, becomes 0
85 | # Note that we cannot use -inf here, because at some edge cases,
86 | # the attention weight (before softmax) for some padded element in query
87 | # will become -inf, which results in NaN in model parameters
88 | if attn_mask is not None:
89 | attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
90 |
91 | residual = x
92 | if self.normalize_before:
93 | x = self.self_attn_layer_norm(x)
94 | x, _ = self.self_attn(
95 | query=x,
96 | key=x,
97 | value=x,
98 | key_padding_mask=encoder_padding_mask,
99 | need_weights=False,
100 | attn_mask=attn_mask,
101 | )
102 | x = self.dropout_module(x)
103 | x = self.residual_connection(x, residual)
104 | if not self.normalize_before:
105 | x = self.self_attn_layer_norm(x)
106 |
107 | residual = x
108 | if self.normalize_before:
109 | x = self.final_layer_norm(x)
110 |
111 | num_tokens = None
112 | balance_loss = 0.0
113 | if self.use_switch:
114 | seq_len, bsz, dim = x.shape
115 | x = x.view(-1, dim)
116 | logits = self.gating_network(x)
117 | sample, multiplier, balance_loss = self.router(logits)
118 |
119 | order = sample.argsort(0).squeeze(-1)
120 | num_tokens = F.one_hot(sample.squeeze(), self.num_experts).gt(0).sum(0)
121 | x = x[order] # reorder according to expert number
122 | x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
123 |
124 | def forward_fc(input_x, expert_idx):
125 | if input_x.numel() > 0:
126 | input_x = self.activation_fn(self.fc1[expert_idx](input_x))
127 | input_x = self.activation_dropout_module(input_x)
128 | input_x = self.fc2[expert_idx](input_x)
129 | return input_x
130 | x = torch.vstack(
131 | [forward_fc(x[i], i) for i in range(self.num_experts)]
132 | )
133 | x = x[order.argsort()] * multiplier
134 |
135 | x = x.view(seq_len, bsz, dim)
136 | else:
137 | x = self.activation_fn(self.fc1(x))
138 | x = self.activation_dropout_module(x)
139 | x = self.fc2(x)
140 |
141 | x = self.dropout_module(x)
142 | x = self.residual_connection(x, residual)
143 |
144 | if not self.normalize_before:
145 | x = self.final_layer_norm(x)
146 | return x, num_tokens, balance_loss
147 |
148 |
149 | class SwitchTransformerDecoderLayer(TransformerDecoderLayer):
150 | """Decoder layer block.
151 |
152 | In the original paper each operation (multi-head attention, encoder
153 | attention or FFN) is postprocessed with: `dropout -> add residual ->
154 | layernorm`. In the tensor2tensor code they suggest that learning is more
155 | robust when preprocessing each layer with layernorm and postprocessing with:
156 | `dropout -> add residual`. We default to the approach in the paper, but the
157 | tensor2tensor approach can be enabled by setting
158 | *args.decoder_normalize_before* to ``True``.
159 |
160 | Args:
161 | args (argparse.Namespace): parsed command-line arguments
162 | no_encoder_attn (bool, optional): whether to attend to encoder outputs
163 | (default: False).
164 | """
165 |
166 | def __init__(
167 | self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, layer_idx=-1,
168 | ):
169 | self.num_experts = args.num_experts
170 | self.load_balancing = args.load_balancing
171 | self.use_switch = use_switch(layer_idx)
172 | super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
173 | self.gating_network = nn.Linear(args.decoder_embed_dim, args.num_experts)
174 | if self.use_switch:
175 | self.router = get_router(args.router)(args.num_experts, args.decoder_embed_dim, args.load_balancing, args.jitter_eps)
176 | else:
177 | self.router = None
178 |
179 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
180 | if self.use_switch:
181 | return nn.ModuleList(
182 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
183 | for _ in range(self.num_experts)]
184 | )
185 | else:
186 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
187 |
188 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
189 | if self.use_switch:
190 | return nn.ModuleList(
191 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
192 | for _ in range(self.num_experts)]
193 | )
194 | else:
195 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
196 |
197 | def forward(
198 | self,
199 | x,
200 | encoder_out: Optional[torch.Tensor] = None,
201 | encoder_padding_mask: Optional[torch.Tensor] = None,
202 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
203 | prev_self_attn_state: Optional[List[torch.Tensor]] = None,
204 | prev_attn_state: Optional[List[torch.Tensor]] = None,
205 | self_attn_mask: Optional[torch.Tensor] = None,
206 | self_attn_padding_mask: Optional[torch.Tensor] = None,
207 | need_attn: bool = False,
208 | need_head_weights: bool = False,
209 | ):
210 | """
211 | Args:
212 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
213 | encoder_padding_mask (ByteTensor, optional): binary
214 | ByteTensor of shape `(batch, src_len)` where padding
215 | elements are indicated by ``1``.
216 | need_attn (bool, optional): return attention weights
217 | need_head_weights (bool, optional): return attention weights
218 | for each head (default: return average over heads).
219 |
220 | Returns:
221 | encoded output of shape `(seq_len, batch, embed_dim)`
222 | """
223 | if need_head_weights:
224 | need_attn = True
225 |
226 | residual = x
227 | if self.normalize_before:
228 | x = self.self_attn_layer_norm(x)
229 | if prev_self_attn_state is not None:
230 | prev_key, prev_value = prev_self_attn_state[:2]
231 | saved_state: Dict[str, Optional[Tensor]] = {
232 | "prev_key": prev_key,
233 | "prev_value": prev_value,
234 | }
235 | if len(prev_self_attn_state) >= 3:
236 | saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
237 | assert incremental_state is not None
238 | self.self_attn._set_input_buffer(incremental_state, saved_state)
239 | _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
240 | if self.cross_self_attention and not (
241 | incremental_state is not None
242 | and _self_attn_input_buffer is not None
243 | and "prev_key" in _self_attn_input_buffer
244 | ):
245 | if self_attn_mask is not None:
246 | assert encoder_out is not None
247 | self_attn_mask = torch.cat(
248 | (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
249 | )
250 | if self_attn_padding_mask is not None:
251 | if encoder_padding_mask is None:
252 | assert encoder_out is not None
253 | encoder_padding_mask = self_attn_padding_mask.new_zeros(
254 | encoder_out.size(1), encoder_out.size(0)
255 | )
256 | self_attn_padding_mask = torch.cat(
257 | (encoder_padding_mask, self_attn_padding_mask), dim=1
258 | )
259 | assert encoder_out is not None
260 | y = torch.cat((encoder_out, x), dim=0)
261 | else:
262 | y = x
263 |
264 | x, attn = self.self_attn(
265 | query=x,
266 | key=y,
267 | value=y,
268 | key_padding_mask=self_attn_padding_mask,
269 | incremental_state=incremental_state,
270 | need_weights=False,
271 | attn_mask=self_attn_mask,
272 | )
273 | x = self.dropout_module(x)
274 | x = self.residual_connection(x, residual)
275 | if not self.normalize_before:
276 | x = self.self_attn_layer_norm(x)
277 |
278 | if self.encoder_attn is not None and encoder_out is not None:
279 | residual = x
280 | if self.normalize_before:
281 | x = self.encoder_attn_layer_norm(x)
282 | if prev_attn_state is not None:
283 | prev_key, prev_value = prev_attn_state[:2]
284 | saved_state: Dict[str, Optional[Tensor]] = {
285 | "prev_key": prev_key,
286 | "prev_value": prev_value,
287 | }
288 | if len(prev_attn_state) >= 3:
289 | saved_state["prev_key_padding_mask"] = prev_attn_state[2]
290 | assert incremental_state is not None
291 | self.encoder_attn._set_input_buffer(incremental_state, saved_state)
292 |
293 | x, attn = self.encoder_attn(
294 | query=x,
295 | key=encoder_out,
296 | value=encoder_out,
297 | key_padding_mask=encoder_padding_mask,
298 | incremental_state=incremental_state,
299 | static_kv=True,
300 | need_weights=need_attn or (not self.training and self.need_attn),
301 | need_head_weights=need_head_weights,
302 | )
303 | x = self.dropout_module(x)
304 | x = self.residual_connection(x, residual)
305 | if not self.normalize_before:
306 | x = self.encoder_attn_layer_norm(x)
307 |
308 | residual = x
309 | if self.normalize_before:
310 | x = self.final_layer_norm(x)
311 |
312 | num_tokens = None
313 | balance_loss = 0.0
314 | if self.use_switch:
315 | seq_len, bsz, dim = x.shape
316 | x = x.view(-1, dim)
317 | logits = self.gating_network(x)
318 | sample, multiplier, balance_loss = self.router(logits)
319 |
320 | order = sample.argsort(0).squeeze(-1)
321 | num_tokens = F.one_hot(sample.squeeze(), self.num_experts).gt(0).sum(0)
322 | x = x[order] # reorder according to expert number
323 | x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
324 |
325 | def forward_fc(input_x, expert_idx):
326 | if input_x.numel() > 0:
327 | input_x = self.activation_fn(self.fc1[expert_idx](input_x))
328 | input_x = self.activation_dropout_module(input_x)
329 | input_x = self.fc2[expert_idx](input_x)
330 | return input_x
331 | x = torch.vstack(
332 | [forward_fc(x[i], i) for i in range(self.num_experts)]
333 | )
334 | x = x[order.argsort()] * multiplier
335 |
336 | x = x.view(seq_len, bsz, dim)
337 | else:
338 | x = self.activation_fn(self.fc1(x))
339 | x = self.activation_dropout_module(x)
340 | x = self.fc2(x)
341 |
342 | x = self.dropout_module(x)
343 | x = self.residual_connection(x, residual)
344 | if not self.normalize_before:
345 | x = self.final_layer_norm(x)
346 | if self.onnx_trace and incremental_state is not None:
347 | saved_state = self.self_attn._get_input_buffer(incremental_state)
348 | assert saved_state is not None
349 | if self_attn_padding_mask is not None:
350 | self_attn_state = [
351 | saved_state["prev_key"],
352 | saved_state["prev_value"],
353 | saved_state["prev_key_padding_mask"],
354 | ]
355 | else:
356 | self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
357 | return x, attn, self_attn_state
358 | return x, attn, None, num_tokens, balance_loss
359 |
--------------------------------------------------------------------------------
/example/moe/translation_switch.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import json
3 | import torch
4 |
5 | from argparse import Namespace
6 | from dataclasses import dataclass, field
7 | from omegaconf import II
8 |
9 | from fairseq import metrics, models
10 | from fairseq.data import encoders
11 | from fairseq.dataclass import ChoiceEnum
12 | from fairseq.optim.amp_optimizer import AMPOptimizer
13 | from fairseq.tasks import register_task
14 | from fairseq.tasks.translation import TranslationConfig, TranslationTask
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | @dataclass
21 | class TranslationSwitchConfig(TranslationConfig):
22 | num_experts: int = field(
23 | default=4,
24 | metadata={"help": "number of experts"},
25 | )
26 | router: ChoiceEnum(['SparseMixer', 'SwitchGate']) = field(
27 | default='SparseMixer',
28 | metadata={"help": "choice of router"},
29 | )
30 | load_balancing: bool = field(
31 | default=False,
32 | metadata={"help": "whether to use load balancing"},
33 | )
34 | gumbel: bool = field(
35 | default=False,
36 | metadata={"help": "use gumbel logits for computing balancing loss"},
37 | )
38 | jitter_eps: float = field(
39 | default=0.1,
40 | metadata={"help": "jitter eps"},
41 | )
42 | load_balancing_alpha: float = field(
43 | default=0.01,
44 | metadata={"help": "weight of load balancing loss"},
45 | )
46 | sentence_avg: bool = II("optimization.sentence_avg")
47 |
48 |
49 | @register_task("translation_switch", dataclass=TranslationSwitchConfig)
50 | class TranslationSwitchTask(TranslationTask):
51 | """
52 | Translation task for Switch Transformer models.
53 |
54 | Args:
55 | src_dict (~fairseq.data.Dictionary): dictionary for the source language
56 | tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
57 |
58 | .. note::
59 |
60 | The translation task is compatible with :mod:`fairseq-train`,
61 | :mod:`fairseq-generate` and :mod:`fairseq-interactive`.
62 |
63 | The translation task provides the following additional command-line
64 | arguments:
65 |
66 | .. argparse::
67 | :ref: fairseq.tasks.translation_parser
68 | :prog:
69 | """
70 |
71 | cfg: TranslationSwitchConfig
72 |
73 | def __init__(self, cfg: TranslationSwitchConfig, src_dict, tgt_dict):
74 | super().__init__(cfg, src_dict, tgt_dict)
75 |
76 | def build_model(self, cfg):
77 | model = models.build_model(cfg, self)
78 |
79 | if self.cfg.eval_bleu:
80 | detok_args = json.loads(self.cfg.eval_bleu_detok_args)
81 | self.tokenizer = encoders.build_tokenizer(
82 | Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
83 | )
84 |
85 | gen_args = json.loads(self.cfg.eval_bleu_args)
86 | self.sequence_generator = self.build_generator(
87 | [model], Namespace(**gen_args)
88 | )
89 |
90 | return model
91 |
92 | def _get_loss(self, sample, model, criterion):
93 | assert hasattr(
94 | criterion, "compute_loss"
95 | ), "translation_switch task requires the criterion to implement the compute_loss() method"
96 |
97 | encoder_out = model.encoder(
98 | src_tokens=sample["net_input"]["src_tokens"],
99 | src_lengths=sample["net_input"]["src_lengths"],
100 | )
101 | net_output = model.decoder(
102 | prev_output_tokens=sample["net_input"]["prev_output_tokens"],
103 | encoder_out=encoder_out,
104 | src_lengths=sample["net_input"]["src_lengths"],
105 | )
106 | loss, nll_loss = criterion.compute_loss(model, net_output, sample, reduce=True)
107 |
108 | balance_loss = None
109 | if self.cfg.load_balancing:
110 | balance_loss = net_output[1]["balance_loss"]
111 | if encoder_out["balance_loss"] is not None:
112 | balance_loss = balance_loss + encoder_out["balance_loss"]
113 | loss = loss + balance_loss * self.cfg.load_balancing_alpha
114 |
115 | if 'load' in net_output[1]:
116 | load = net_output[1]["load"]
117 | if encoder_out["load"] is not None:
118 | load = torch.cat((encoder_out["load"], load), dim=0)
119 | else:
120 | load = torch.Tensor([1.])
121 |
122 | sample_size = (
123 | sample["target"].size(0) if criterion.sentence_avg else sample["ntokens"]
124 | )
125 | logging_output = {
126 | "loss": loss.data,
127 | "nll_loss": nll_loss.data,
128 | "ntokens": sample["ntokens"],
129 | "nsentences": sample["target"].size(0),
130 | "sample_size": sample_size,
131 | "load": (load * 100).long(),
132 | "balance_loss": balance_loss.data if balance_loss is not None else 0.0,
133 | }
134 | return loss, sample_size, logging_output
135 |
136 | def train_step(
137 | self, sample, model, criterion, optimizer, update_num, ignore_grad=False
138 | ):
139 | model.train()
140 | model.set_num_updates(update_num)
141 | with torch.autograd.profiler.record_function("forward"):
142 | with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
143 | loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
144 | if ignore_grad:
145 | loss *= 0
146 | with torch.autograd.profiler.record_function("backward"):
147 | optimizer.backward(loss)
148 | return loss, sample_size, logging_output
149 |
150 | def reduce_metrics(self, logging_outputs, criterion):
151 | super().reduce_metrics(logging_outputs, criterion)
152 |
153 | metrics.log_scalar(
154 | "load",
155 | sum(log["load"] for log in logging_outputs if "load" in log) / torch.cuda.device_count(),
156 | round=3,
157 | )
158 |
159 | temp = [log["balance_loss"] for log in logging_outputs if "balance_loss" in log]
160 | metrics.log_scalar("balance_loss", sum(temp), round=3)
161 |
--------------------------------------------------------------------------------
/example/requirements.txt:
--------------------------------------------------------------------------------
1 | cffi
2 | cython
3 | hydra-core>=1.0.7,<1.1
4 | omegaconf<2.1
5 | numpy>=1.22
6 | regex
7 | sacrebleu>=1.4.12
8 | tqdm
9 | bitarray
10 | tensorboardX
--------------------------------------------------------------------------------
/example/run_moe.sh:
--------------------------------------------------------------------------------
1 | EXPERTS=${1:-"4"}
2 | ROUTER=${2:-"SparseMixer"}
3 | DATA_FOLDER=${3:-"/mnt/azstorage/wmt14_en_de_joined_dict/"}
4 |
5 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
6 |
7 | pip install --upgrade numpy scipy
8 | pip install -r requirements.txt
9 | pip install --no-deps fairseq==0.12.1
10 |
11 | ARGS="--num-experts ${EXPERTS} --jitter-eps 0.1 --load-balancing-alpha 0.01 --router ${ROUTER}"
12 | echo $ARGS
13 |
14 | OUTPUT_FOLDER=./output-${EXPERTS}
15 |
16 | mkdir -p $OUTPUT_FOLDER
17 |
18 | touch $OUTPUT_FOLDER/train.log
19 |
20 | cp $OUTPUT_FOLDER/train.log ./train.log
21 |
22 | GPUCT=$(nvidia-smi --list-gpus | wc -l)
23 | UPDATE_FREQ=$((16/${GPUCT}))
24 | echo $UPDATE_FREQ
25 |
26 | unset RANK
27 | fairseq-train ${DATA_FOLDER} --num-workers 8 --ddp-backend=no_c10d \
28 | --user-dir moe --task translation_switch --load-balancing \
29 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
30 | --arch switch_transformer --share-all-embeddings \
31 | --source-lang en --target-lang de \
32 | --optimizer adam --adam-betas '(0.9,0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
33 | --warmup-init-lr 1e-07 --warmup-updates 8000 --max-update 400000 \
34 | --lr 7e-4 --max-tokens 8192 --update-freq ${UPDATE_FREQ} \
35 | --weight-decay 0 --dropout 0.1 --activation-dropout 0.1 --attention-dropout 0.1 \
36 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
37 | --log-format simple --log-interval 100 \
38 | --skip-invalid-size-inputs-valid-test \
39 | --best-checkpoint-metric loss --save-interval 10 \
40 | --encoder-layers 6 --decoder-layers 6 \
41 | --save-dir $OUTPUT_FOLDER $ARGS 2>&1 | tee -a ./train.log
42 |
43 | cp ./train.log $OUTPUT_FOLDER/train.log
44 |
45 | sleep 15m
46 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from setuptools import setup, find_packages
5 |
6 | def read_readme():
7 | with open('README.md') as f:
8 | return f.read()
9 |
10 | with open('HISTORY.rst') as history_file:
11 | history = history_file.read()
12 |
13 | requirements = [
14 | 'torch'
15 | ]
16 |
17 | setup(
18 | name='sparsemixer',
19 | version='0.1.0',
20 | description='SparseMixer Algorithm',
21 | long_description= read_readme(),
22 | long_description_content_type="text/markdown",
23 | author='Lucas Liu',
24 | author_email='llychinalz@gmail.com',
25 | url='https://github.com/microsoft/SparseMixer',
26 | packages=find_packages(exclude=['docs']),
27 | include_package_data=True,
28 | install_requires=requirements,
29 | license='MIT',
30 | zip_safe=False,
31 | classifiers=[
32 | 'Development Status :: 2 - Pre-Alpha',
33 | 'Intended Audience :: Developers',
34 | 'Natural Language :: English',
35 | 'Programming Language :: Python :: 3.7',
36 | 'Programming Language :: Python :: 3.8',
37 | 'Programming Language :: Python :: 3.9',
38 | ]
39 | )
40 |
41 | # python setup.py sdist bdist_wheel --universal
42 | # twine upload dist/*
--------------------------------------------------------------------------------
/sparsemixer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | __author__ = "Liyuan Liu"
5 |
6 | __maintainer__ = "Liyuan Liu"
7 | __email__ = "llychinalz@gmail.com"
8 |
9 | from .sparsemixer import SparseMixer
10 | from .switchgate import SwitchGate
11 |
12 | router_map = {
13 | 'sparsemixer': SparseMixer,
14 | 'switchgate': SwitchGate
15 | }
16 |
17 | def get_router(name):
18 |     name = name.lower()
19 | assert name in router_map, f'Currently only supports SparseMixer and SwitchGate. {name} is not supported!'
20 |
21 | return router_map[name]
22 |
--------------------------------------------------------------------------------
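A quick sanity-check sketch of the router lookup above (a hypothetical snippet, not a file in this repository):

    from sparsemixer import get_router, SparseMixer, SwitchGate

    assert get_router('SparseMixer') is SparseMixer   # lookup is case-insensitive
    assert get_router('switchgate') is SwitchGate
    # get_router('topk') would raise an AssertionError
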
/sparsemixer/sparsemixer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class SparseMixerCore(torch.autograd.Function):
6 | """
7 | `torch.autograd.Function` implementation of the balancing strategy used in SparseMixer.
8 | """
9 |
10 | @staticmethod
11 | def forward(
12 | ctx,
13 | multiplier: torch.Tensor,
14 | firstorder_mask: torch.Tensor,
15 | ):
16 | firstorder_mask = torch.add(0.5, firstorder_mask, alpha=0.5).type_as(multiplier)
17 | return multiplier * firstorder_mask # turns [0,1] into [0.5, 1]
18 |
19 | @staticmethod
20 | def backward(
21 | ctx,
22 | grad_at_multiplier: torch.Tensor,
23 | ):
24 | return grad_at_multiplier * 2, None
25 |
26 | class SparseMixer(nn.Module):
27 | def __init__(self, num_experts, embed_dim, compute_balance_loss=False, jitter_eps=0.1):
28 | super(SparseMixer, self).__init__()
29 | self.num_experts = num_experts
30 | self.compute_balance_loss = compute_balance_loss
31 | self.jitter_eps = jitter_eps
32 | self.embed_dim = embed_dim
33 | self.register_parameter('omega', torch.nn.Parameter(torch.ones(embed_dim)))
34 |
35 | def forward(self, logits):
36 |
37 | # masking out experts that are never sampled by jittering
38 | with torch.no_grad():
39 | mask_logits_threshold, max_ind = logits.max(dim=-1, keepdim=True)
40 | factor = logits.abs().clamp(min=mask_logits_threshold)
41 | mask_logits_threshold = (
42 | (mask_logits_threshold - logits) / factor
43 | ) > (2 * self.jitter_eps)
44 | logits = logits.masked_fill_(mask_logits_threshold, float('-inf'))
45 |
46 | p = logits.softmax(dim=-1)
47 | if self.training:
48 | sample = (
49 | logits - torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
50 |             ).max(dim=-1)[1].unsqueeze(-1) # Gumbel sampling, more robust than the multinomial method
51 |
52 | multiplier = p.gather(dim=1, index=sample)
53 |
54 | mask_for_firstorder = torch.logical_or(
55 | sample == max_ind,
56 | torch.rand_like(multiplier) > 0.5
57 | ) # compute the mask for applying the first-order method
58 | multiplier = SparseMixerCore.apply(multiplier, mask_for_firstorder) # balance mid-point and euler
59 | else:
60 | sample = max_ind
61 | multiplier = p.gather(dim=1, index=sample)
62 |
63 | multiplier = multiplier * self.omega
64 | balance_loss = 0.0
65 | if self.compute_balance_loss:
66 | num_tokens = F.one_hot(sample.squeeze(-1), self.num_experts).gt(0).sum(0)
67 | f = num_tokens / (num_tokens.sum(0, keepdim=True) + 1e-6)
68 | pmean = p.view(-1, self.num_experts).mean(0)
69 | balance_loss = self.num_experts * torch.sum(pmean * f)
70 |
71 | return sample, multiplier, balance_loss
72 | # return sample, [multiplier], balance_loss
73 |
--------------------------------------------------------------------------------
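A minimal usage sketch for the router above (hypothetical; the gate projection and shapes are assumptions for illustration, and the actual expert dispatch presumably lives in example/moe/transformer_switch_layer.py):

    import torch
    from sparsemixer import SparseMixer

    router = SparseMixer(num_experts=4, embed_dim=512, compute_balance_loss=True)
    gate = torch.nn.Linear(512, 4, bias=False)      # assumed per-token gate projection

    tokens = torch.randn(16, 512)                   # [num_tokens, embed_dim]
    sample, multiplier, balance_loss = router(gate(tokens))

    # sample:       [16, 1] expert index per token (Gumbel sample in training, argmax at eval)
    # multiplier:   [16, 512] gate probability, scaled by the mid-point/Euler mask during
    #               training and broadcast against the learned per-dimension omega; it is
    #               presumably multiplied into the selected expert's output
    # balance_loss: scalar switch-style load-balancing term (run_moe.sh exposes
    #               --load-balancing-alpha to weight it)
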
/sparsemixer/sparsemixer_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Callable, Dict
5 | uniform_map: Dict[torch.device, Callable] = {}
6 | def multiplicative_jitter(input, epsilon, training):
7 |
8 | if epsilon == 0 or not training:
9 | return input
10 |
11 | uniform = uniform_map.get(input.device)
12 |
13 | if uniform is None:
14 |         uniform = torch.distributions.Uniform(
15 |             low=torch.tensor(1.0 - epsilon, device=input.device, dtype=input.dtype),
16 |             high=torch.tensor(1.0 + epsilon, device=input.device, dtype=input.dtype)).rsample
17 | uniform_map[input.device] = uniform
18 |
19 | return input * uniform(input.shape)
20 |
21 | class v2core(torch.autograd.Function):
22 | @staticmethod
23 | def forward(
24 | ctx,
25 | scores: torch.Tensor,
26 | multiplier: torch.Tensor,
27 | selected_experts: torch.Tensor,
28 | masked_gates: torch.Tensor,
29 | mask_for_one: torch.Tensor,
30 | ):
31 | ctx.save_for_backward(multiplier, selected_experts, masked_gates)
32 | return multiplier * mask_for_one
33 |
34 | @staticmethod
35 | def backward(
36 | ctx,
37 | grad_at_output: torch.Tensor,
38 | ):
39 | multiplier, selected_experts, masked_gates = ctx.saved_tensors
40 |
41 | grad_at_output = grad_at_output * multiplier
42 |
43 |         grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
44 |         grad_at_scores_expanded.scatter_add_(
45 |             dim=-1,
46 |             index=selected_experts,
47 |             src=grad_at_output,
48 |         )
49 |
50 |         return (
51 |             grad_at_scores_expanded,
52 | None,
53 | None,
54 | None,
55 | None,
56 | )
57 |
58 | def sparsemixerv2_routing(scores, top_k, jitter_eps, training):
59 | assert top_k in [1, 2], "only top-1/2 gating has been tested!"
60 |
61 | original_gates = torch.softmax(scores, dim=-1)
62 | ################ first expert ################
63 |
64 | with torch.no_grad():
65 | # compute mask for sparsity
66 | mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
67 | factor = scores.abs().clamp(min=mask_logits_threshold)
68 | mask_logits_threshold = (
69 | (mask_logits_threshold - scores) / factor
70 | ) > (2 * jitter_eps)
71 |
72 | # apply mask
73 | masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf'))
74 | if training:
75 | selected_experts = (
76 | masked_gates - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
77 |         ).max(dim=-1)[1].unsqueeze(-1) # Gumbel sampling, more robust than the multinomial method
78 | else:
79 | selected_experts = max_ind
80 |
81 | # compute scores for gradients
82 | masked_gates = torch.softmax(masked_gates, dim=-1)
83 |
84 | # compute midpoint mask
85 | max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
86 | mask_for_one = torch.logical_or(
87 | selected_experts == max_ind,
88 | torch.rand_like(max_scores) > 0.75 # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
89 | )
90 | # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
91 | mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
92 |
93 | multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
94 | multiplier = v2core.apply(
95 | scores,
96 | multiplier_o,
97 | selected_experts,
98 | masked_gates,
99 | mask_for_one,
100 | )
101 |
102 | ################ second expert ################
103 | if top_k > 1:
104 | # masked out first expert
105 | masked_scores = torch.scatter(
106 | scores,
107 | -1,
108 | selected_experts,
109 | float('-inf'),
110 | )
111 | with torch.no_grad():
112 | # compute mask for sparsity
113 | mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
114 | factor = scores.abs().clamp(min=mask_logits_threshold)
115 | mask_logits_threshold = (
116 | (mask_logits_threshold - scores) / factor
117 | ) > (2 * jitter_eps)
118 |
119 | # apply mask
120 | masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float('-inf'))
121 | if training:
122 | selected_experts_top2 = (
123 | masked_gates_top2 - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format).exponential_().log()
124 |             ).max(dim=-1)[1].unsqueeze(-1) # Gumbel sampling, more robust than the multinomial method
125 | else:
126 | selected_experts_top2 = max_ind
127 | # compute scores for gradients
128 | masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
129 |
130 | # compute midpoint mask
131 | max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
132 | mask_for_one_top2 = torch.logical_or(
133 | selected_experts_top2 == max_ind,
134 |             torch.rand_like(max_scores) > 0.75 # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
135 | )
136 | # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
137 | mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
138 |
139 | multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
140 | multiplier_top2 = v2core.apply(
141 | scores,
142 | multiplier_top2_o,
143 | selected_experts_top2,
144 | masked_gates_top2,
145 | mask_for_one_top2,
146 | )
147 |
148 | multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
149 | selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
150 |
151 | return (
152 | multiplier,
153 | original_gates,
154 | selected_experts,
155 | )
156 |
157 | class SparseMixerV2(nn.Module):
158 | def __init__(self, num_experts, embed_dim, compute_balance_loss=False, jitter_eps=0.1):
159 |         super(SparseMixerV2, self).__init__()
160 | self.num_experts = num_experts
161 | self.compute_balance_loss = compute_balance_loss
162 | self.jitter_eps = jitter_eps
163 |
164 | def forward(self, logits):
165 |
166 | multiplier, original_gates, sample = sparsemixerv2_routing(logits, 1, self.jitter_eps, self.training)
167 |
168 | balance_loss = 0.0
169 | if self.compute_balance_loss:
170 | num_tokens = F.one_hot(sample.squeeze(-1), self.num_experts).gt(0).sum(0)
171 | f = num_tokens / (num_tokens.sum(0, keepdim=True) + 1e-6)
172 |             pmean = original_gates.view(-1, self.num_experts).mean(0)
173 | balance_loss = self.num_experts * torch.sum(pmean * f)
174 |
175 | return sample, multiplier, balance_loss
176 |
--------------------------------------------------------------------------------
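A minimal sketch of calling the v2 routing function directly (hypothetical shapes; note that gradients reach the routing scores only through v2core's custom backward):

    import torch
    from sparsemixer.sparsemixer_v2 import sparsemixerv2_routing

    scores = torch.randn(16, 8, requires_grad=True)   # [num_tokens, num_experts]
    multiplier, gates, experts = sparsemixerv2_routing(
        scores, top_k=2, jitter_eps=0.1, training=True)

    # experts:    [16, 2] indices of the two sampled experts per token
    # multiplier: [16, 2] gate values of the selected experts (scaled by the
    #             Heun-correction mask during training)
    # gates:      [16, 8] plain softmax of the original scores, e.g. for a balance loss
    multiplier.sum().backward()
    print(scores.grad.shape)                          # torch.Size([16, 8])
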
/sparsemixer/switchgate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SwitchGate(nn.Module):
7 | def __init__(self, num_experts, embed_dim, compute_balance_loss=False, jitter_eps=0.1):
8 | super(SwitchGate, self).__init__()
9 | self.num_experts = num_experts
10 | self.compute_balance_loss = compute_balance_loss
11 | self.jitter_eps = jitter_eps
12 |
13 | def forward(self, logits):
14 | if self.training:
15 | noise = torch.rand_like(logits)
16 | noise = noise * 2 * self.jitter_eps + 1.0 - self.jitter_eps
17 | logits = logits * noise
18 |
19 | p = logits.softmax(dim=-1)
20 | sample = torch.argmax(p, dim=-1)
21 |
22 | balance_loss = 0.0
23 | if self.compute_balance_loss:
24 | num_tokens = F.one_hot(sample, self.num_experts).gt(0).sum(0)
25 | f = num_tokens / (num_tokens.sum(0, keepdim=True) + 1e-6)
26 | pmean = p.view(-1, self.num_experts).mean(0)
27 | balance_loss = self.num_experts * torch.sum(pmean * f)
28 |
29 | multiplier = p.gather(dim=-1, index=sample.unsqueeze(1))
30 | return sample, multiplier, balance_loss
--------------------------------------------------------------------------------
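For comparison, the same kind of sketch for the SwitchGate baseline (hypothetical shapes):

    import torch
    from sparsemixer import SwitchGate

    gate = SwitchGate(num_experts=4, embed_dim=512, compute_balance_loss=True)
    logits = torch.randn(16, 4)                   # [num_tokens, num_experts]
    sample, multiplier, balance_loss = gate(logits)

    # sample:       [16] argmax expert per token (logits are jittered multiplicatively in training)
    # multiplier:   [16, 1] probability of the chosen expert
    # balance_loss: scalar Switch-Transformer load-balancing term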