├── .gitignore ├── CODE_OF_CONDUCT.md ├── HISTORY.rst ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── Transparency_FAQ.md ├── example ├── README.md ├── eval_moe.sh ├── moe │ ├── __init__.py │ ├── transformer_switch.py │ ├── transformer_switch_layer.py │ └── translation_switch.py ├── requirements.txt └── run_moe.sh ├── setup.py └── sparsemixer ├── __init__.py ├── sparsemixer.py ├── sparsemixer_v2.py └── switchgate.py /.gitignore: -------------------------------------------------------------------------------- 1 | output-* 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | History 2 | ======= 3 | 4 | 0.0.0 (2023/10/5) 5 | ------------------ 6 | * empty placeholder 7 | 8 | 0.1.0 (2023/10/5) 9 | ------------------ 10 | * implemented SparseMixer -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=flat&logo=PyTorch&logoColor=white) 2 | ![GitHub](https://img.shields.io/github/license/microsoft/sparsemixer) 3 | 4 |

# SparseMixer 5 | ### Sparse Backpropagation for Mixture-of-Expert Training 6 | 7 |
8 | Mixture-of-Expert • 9 | SparseMixer • 10 | How to Use? • 11 | Examples • 12 | Citation • 13 | License 14 |

15 | 16 | [SparseMixer](https://arxiv.org/abs/2310.00811) is a scalable gradient estimator that bridges the gap between backpropagation and sparse expert routing. 17 | 18 |

### What is Mixture-of-Expert

19 | The significant success of large-scale pre-training across various applications has underscored the imperative need for scalable models that are economically feasible. 20 | Recent advances in sparsely activated networks, prominently known as Mixture-of-Experts (MoE), have attracted widespread interest. 21 | Unlike traditional networks that densely activate all modules for every input, MoE selectively activates parts of the model for specific inputs through a process called *expert routing*, leading to notable efficiency gains. 22 | 23 | Numerous methods have emerged to bridge discrete variables and back-propagation, and most of them are based on Straight-Through (ST) estimators. 24 | Unfortunately, all existing ST estimators are incompatible with MoE, since they require activating all experts for gradient computation, thereby eliminating the efficiency improvements of MoE. 25 | Consequently, typical MoE training strategically neglects the gradient computation for routing, trading certain training signals for sparse computation. 26 | Despite the scalability brought by sparse computation, this trade-off may result in slow convergence and improperly trained models. 27 | 28 |

### Backpropagation Made Sparse

29 | 30 | We propose [SparseMixer](https://arxiv.org/abs/2310.00811), a scalable gradient estimator, bridges the gap between backpropagation and sparse expert routing. 31 | Grounded in a numerical ODE framework, SparseMixer harnesses the mid-point method, a second-order ODE solver, to deliver precise gradient approximations with negligible computational overhead. 32 | Applying SparseMixer to Switch Transformer on both pre-training and machine translation tasks, SparseMixer showcases considerable performance gain, accelerating training convergence up to 2 times 33 | 34 | ### How to use? 35 | 36 | `sparsemixer` can be installed via `pip` 37 | ``` 38 | pip install sparsemixer 39 | ``` 40 | 41 | ### Examples 42 | 43 | Please check the `example` folder for a working example. 44 | 45 | ### Citation 46 | Please cite the following papers if you found our model useful. Thanks! 47 | 48 | 49 | >Liyuan Liu, Jianfeng Gao, and Weizhu Chen (2023). Sparse Backpropagation for MoE Training. *ArXiv, abs/2310.00811*. 50 | ``` 51 | @inproceedings{liu2023bridging, 52 | title={Sparse Backpropagation for MoE Training}, 53 | author = {Liu, Liyuan and Gao, Jianfeng and Chen, Weizhu}, 54 | booktitle = {arXiv:2310.00811 [cs]}, 55 | year={2023} 56 | } 57 | ``` 58 | 59 | >Liyuan Liu, Chengyu Dong, Xiaodong Liu, Bin Yu, and Jianfeng Gao (2023). Bridging Discrete and Backpropagation: Straight-Through and Beyond. *ArXiv, abs/2304.08612*. 60 | ``` 61 | @inproceedings{liu2023bridging, 62 | title={Bridging Discrete and Backpropagation: Straight-Through and Beyond}, 63 | author = {Liu, Liyuan and Dong, Chengyu and Liu, Xiaodong and Yu, Bin and Gao, Jianfeng}, 64 | booktitle = {arXiv:2304.08612 [cs]}, 65 | year={2023} 66 | } 67 | ``` -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please contact [Liyuan Liu](https://liyuanlucasliu.github.io/). 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 14 | -------------------------------------------------------------------------------- /Transparency_FAQ.md: -------------------------------------------------------------------------------- 1 | # SparseMixer: Responsible AI Frequently Asked Questions 2 | 3 | ## What is SparseMixer? 4 | 5 | SparseMixer is a scalable gradient estimator for Mixture-of-Expert models. 6 | 7 | ## What can SparseMixer do? 8 | 9 | SparseMixer is able to provide reliable gradient approximation for expert routing, with only sparsely activated networks. 10 | 11 | ## What is SparseMixer’s intended use? 12 | 13 | SparseMixer aims to facilitate the training of Mixture-of-Expert models. 14 | 15 | ## How was SparseMixer evaluated? What metrics are used to measure performance? 16 | 17 | We conduct experiments on applying SparseMixer to Neural Machine Translation and Electra pre-training. 18 | SparseMixer consistently outperforms the baseline methods in all 8 settings. 19 | More details are elaborated in our paper. 20 | 21 | 22 | ## What are the limitations of SparseMixer? How can users minimize the impact of SparseMixer’s limitations when using the system? 23 | 24 | SparseMixer has first-order and second-order accuracy for gradient computation, and is only an approximation to the gradient. It excels as it facilitates a bias-variance tradeoff for the gradient estimation. 25 | 26 | ## How to use SparseMixer? 27 | 28 | `sparsemixer` can be installed via `pip` 29 | ``` 30 | pip install sparsemixer 31 | ``` 32 | Also, please check the `example` folder for a working example. 
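
For reference, here is a minimal sketch of how a router from this package can be wired into a feed-forward block, mirroring the usage in `example/moe/transformer_switch_layer.py`; the tensor sizes and variable names below are illustrative assumptions, not part of the package.

```python
import torch
import torch.nn as nn
from sparsemixer import get_router

# illustrative sizes (assumptions, not package defaults)
num_experts, embed_dim, num_tokens = 4, 512, 8

gating_network = nn.Linear(embed_dim, num_experts)
# get_router returns the SparseMixer or SwitchGate class; it is instantiated with
# (num_experts, embed_dim, compute_balance_loss, jitter_eps)
router = get_router('SparseMixer')(num_experts, embed_dim, compute_balance_loss=True, jitter_eps=0.1)

hidden = torch.randn(num_tokens, embed_dim)        # flattened token representations
logits = gating_network(hidden)                    # (num_tokens, num_experts) routing logits
sample, multiplier, balance_loss = router(logits)  # expert index, output scale, auxiliary loss

# `sample` picks one expert per token; that expert's output is rescaled by
# `multiplier`, which is the term carrying SparseMixer's routing gradient.
```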
-------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | ## Neural Machine Translation 2 | 3 | Please note that, this working example is build upon the the [THOR](https://github.com/microsoft/Stochastic-Mixture-of-Experts/tree/main) repo. 4 | 5 | ### Pre-processing 6 | 7 | Please refer to the [Transformer-Clinic](https://github.com/LiyuanLucasLiu/Transformer-Clinic/blob/master/pre-process/wmt14en-de.sh) repo for data preparation. 8 | 9 | 10 | ### Environment 11 | 12 | We recommend to use the docker image `nvcr.io/nvidia/pytorch:22.02-py3` for this example. 13 | 14 | ### Training and EVALUATION 15 | 16 | ``` 17 | # for model training, the resulting model will be saved to `output-${NUM_OF_EXPERTS}`. 18 | bash run_moe.sh ${NUM_OF_EXPERTS} ${ROUTER} ${PATH_TO_DATA} 19 | 20 | # for model inference, the script will load model weights from `output-${NUM_OF_EXPERTS}/checkpoint_best.pt`. 21 | bash eval_moe.sh output-${NUM_OF_EXPERTS}/checkpoint_best.pt ${GPU_ID} ${PATH_TO_DATA} 22 | ``` 23 | -------------------------------------------------------------------------------- /example/eval_moe.sh: -------------------------------------------------------------------------------- 1 | MODELDIR=${1} 2 | DEVICE=${2} 3 | DATA_FOLDER=${3} 4 | 5 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 6 | 7 | CUDA_VISIBLE_DEVICES=${DEVICE} fairseq-generate $DATA_FOLDER \ 8 | --path $MODELDIR \ 9 | --batch-size 128 --beam 4 --lenpen 0.6 --remove-bpe \ 10 | --quiet --fp16 --user-dir moe 11 | -------------------------------------------------------------------------------- /example/moe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import translation_switch 2 | from . 
import transformer_switch 3 | -------------------------------------------------------------------------------- /example/moe/transformer_switch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import Tensor 4 | from typing import Dict, List, Optional 5 | 6 | from .transformer_switch_layer import SwitchTransformerDecoderLayer, SwitchTransformerEncoderLayer 7 | 8 | from fairseq.distributed import fsdp_wrap 9 | from fairseq.models import register_model, register_model_architecture 10 | from fairseq.models.transformer import ( 11 | base_architecture, 12 | transformer_iwslt_de_en, 13 | transformer_vaswani_wmt_en_de_big, 14 | transformer_vaswani_wmt_en_fr_big, 15 | transformer_wmt_en_de_big, 16 | TransformerDecoder, 17 | TransformerEncoder, 18 | TransformerModel 19 | ) 20 | from fairseq.modules.checkpoint_activations import checkpoint_wrapper 21 | 22 | 23 | DEFAULT_MAX_SOURCE_POSITIONS = 1024 24 | DEFAULT_MAX_TARGET_POSITIONS = 1024 25 | 26 | DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) 27 | 28 | 29 | @register_model("switch_transformer") 30 | class SwitchTransformerModel(TransformerModel): 31 | def __init__(self, args, encoder, decoder): 32 | super().__init__(args, encoder, decoder) 33 | self.args = args 34 | 35 | @classmethod 36 | def build_model(cls, args, task): 37 | """Build a new model instance.""" 38 | 39 | # make sure all arguments are present in older models 40 | base_architecture_switch(args) 41 | 42 | if args.encoder_layers_to_keep: 43 | args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) 44 | if args.decoder_layers_to_keep: 45 | args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) 46 | 47 | if getattr(args, "max_source_positions", None) is None: 48 | args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS 49 | if getattr(args, "max_target_positions", None) is None: 50 | args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS 51 | 52 | src_dict, tgt_dict = task.source_dictionary, task.target_dictionary 53 | 54 | if args.share_all_embeddings: 55 | if src_dict != tgt_dict: 56 | raise ValueError("--share-all-embeddings requires a joined dictionary") 57 | if args.encoder_embed_dim != args.decoder_embed_dim: 58 | raise ValueError( 59 | "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" 60 | ) 61 | if args.decoder_embed_path and ( 62 | args.decoder_embed_path != args.encoder_embed_path 63 | ): 64 | raise ValueError( 65 | "--share-all-embeddings not compatible with --decoder-embed-path" 66 | ) 67 | encoder_embed_tokens = cls.build_embedding( 68 | args, src_dict, args.encoder_embed_dim, args.encoder_embed_path 69 | ) 70 | decoder_embed_tokens = encoder_embed_tokens 71 | args.share_decoder_input_output_embed = True 72 | else: 73 | encoder_embed_tokens = cls.build_embedding( 74 | args, src_dict, args.encoder_embed_dim, args.encoder_embed_path 75 | ) 76 | decoder_embed_tokens = cls.build_embedding( 77 | args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path 78 | ) 79 | if getattr(args, "offload_activations", False): 80 | args.checkpoint_activations = True # offloading implies checkpointing 81 | 82 | encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) 83 | decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) 84 | 85 | if not args.share_all_embeddings: 86 | min_params_to_wrap = getattr( 87 | args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP 88 | ) 89 | # fsdp_wrap is a no-op when --ddp-backend != fully_sharded 90 | encoder = 
fsdp_wrap(encoder, min_num_params=min_params_to_wrap) 91 | decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) 92 | return cls(args, encoder, decoder) 93 | 94 | @classmethod 95 | def build_encoder(cls, args, src_dict, embed_tokens): 96 | return SwitchTransformerEncoder(args, src_dict, embed_tokens) 97 | 98 | @classmethod 99 | def build_decoder(cls, args, tgt_dict, embed_tokens): 100 | return SwitchTransformerDecoder( 101 | args, 102 | tgt_dict, 103 | embed_tokens, 104 | no_encoder_attn=getattr(args, "no_cross_attention", False), 105 | ) 106 | 107 | 108 | class SwitchTransformerEncoder(TransformerEncoder): 109 | """ 110 | Transformer encoder consisting of *args.encoder_layers* layers. Each layer 111 | is a :class:`SwitchTransformerEncoderLayer`. 112 | 113 | Args: 114 | args (argparse.Namespace): parsed command-line arguments 115 | dictionary (~fairseq.data.Dictionary): encoding dictionary 116 | embed_tokens (torch.nn.Embedding): input embedding 117 | """ 118 | 119 | def __init__(self, args, dictionary, embed_tokens): 120 | self.layer_idx = 0 # for building encoder layers 121 | super().__init__(args, dictionary, embed_tokens) 122 | assert self.layer_idx == args.encoder_layers 123 | 124 | def build_encoder_layer(self, args): 125 | layer = SwitchTransformerEncoderLayer(args, layer_idx=self.layer_idx) 126 | checkpoint = getattr(args, "checkpoint_activations", False) 127 | if checkpoint: 128 | offload_to_cpu = getattr(args, "offload_activations", False) 129 | layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) 130 | # if we are checkpointing, enforce that FSDP always wraps the 131 | # checkpointed layer, regardless of layer size 132 | min_params_to_wrap = ( 133 | getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) 134 | if not checkpoint 135 | else 0 136 | ) 137 | layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) 138 | self.layer_idx += 1 139 | return layer 140 | 141 | def forward( 142 | self, 143 | src_tokens, 144 | src_lengths: Optional[torch.Tensor] = None, 145 | return_all_hiddens: bool = False, 146 | token_embeddings: Optional[torch.Tensor] = None, 147 | ): 148 | """ 149 | Args: 150 | src_tokens (LongTensor): tokens in the source language of shape 151 | `(batch, src_len)` 152 | src_lengths (torch.LongTensor): lengths of each source sentence of 153 | shape `(batch)` 154 | return_all_hiddens (bool, optional): also return all of the 155 | intermediate hidden states (default: False). 156 | token_embeddings (torch.Tensor, optional): precomputed embeddings 157 | default `None` will recompute embeddings 158 | 159 | Returns: 160 | dict: 161 | - **encoder_out** (Tensor): the last encoder layer's output of 162 | shape `(src_len, batch, embed_dim)` 163 | - **encoder_padding_mask** (ByteTensor): the positions of 164 | padding elements of shape `(batch, src_len)` 165 | - **encoder_embedding** (Tensor): the (scaled) embedding lookup 166 | of shape `(batch, src_len, embed_dim)` 167 | - **encoder_states** (List[Tensor]): all intermediate 168 | hidden states of shape `(src_len, batch, embed_dim)`. 169 | Only populated if *return_all_hiddens* is True. 170 | """ 171 | return self.forward_scriptable( 172 | src_tokens, src_lengths, return_all_hiddens, token_embeddings 173 | ) 174 | 175 | # TorchScript doesn't support super() method so that the scriptable Subclass 176 | # can't access the base class model in Torchscript. 177 | # Current workaround is to add a helper function with different name and 178 | # call the helper function from scriptable Subclass. 
179 | def forward_scriptable( 180 | self, 181 | src_tokens, 182 | src_lengths: Optional[torch.Tensor] = None, 183 | return_all_hiddens: bool = False, 184 | token_embeddings: Optional[torch.Tensor] = None, 185 | ): 186 | """ 187 | Args: 188 | src_tokens (LongTensor): tokens in the source language of shape 189 | `(batch, src_len)` 190 | src_lengths (torch.LongTensor): lengths of each source sentence of 191 | shape `(batch)` 192 | return_all_hiddens (bool, optional): also return all of the 193 | intermediate hidden states (default: False). 194 | token_embeddings (torch.Tensor, optional): precomputed embeddings 195 | default `None` will recompute embeddings 196 | 197 | Returns: 198 | dict: 199 | - **encoder_out** (Tensor): the last encoder layer's output of 200 | shape `(src_len, batch, embed_dim)` 201 | - **encoder_padding_mask** (ByteTensor): the positions of 202 | padding elements of shape `(batch, src_len)` 203 | - **encoder_embedding** (Tensor): the (scaled) embedding lookup 204 | of shape `(batch, src_len, embed_dim)` 205 | - **encoder_states** (List[Tensor]): all intermediate 206 | hidden states of shape `(src_len, batch, embed_dim)`. 207 | Only populated if *return_all_hiddens* is True. 208 | """ 209 | # compute padding mask 210 | encoder_padding_mask = src_tokens.eq(self.padding_idx) 211 | has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() 212 | 213 | x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) 214 | 215 | # account for padding while computing the representation 216 | if has_pads: 217 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) 218 | 219 | # B x T x C -> T x B x C 220 | x = x.transpose(0, 1) 221 | 222 | encoder_states = [] 223 | 224 | if return_all_hiddens: 225 | encoder_states.append(x) 226 | 227 | # encoder layers 228 | load = [] 229 | balance_loss = 0 230 | for layer in self.layers: 231 | x, layer_load, layer_balance = layer( 232 | x, 233 | encoder_padding_mask=encoder_padding_mask if has_pads else None, 234 | ) 235 | if return_all_hiddens: 236 | assert encoder_states is not None 237 | encoder_states.append(x) 238 | if layer_load is not None: 239 | load.append(layer_load) 240 | if layer_balance is not None: 241 | balance_loss = balance_loss + layer_balance 242 | 243 | if len(load) > 0: 244 | load = torch.vstack(load) 245 | load = load / load.sum(1, keepdim=True) 246 | else: 247 | load = None 248 | balance_loss = None 249 | 250 | if self.layer_norm is not None: 251 | x = self.layer_norm(x) 252 | 253 | # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in 254 | # `forward` so we use a dictionary instead. 255 | # TorchScript does not support mixed values so the values are all lists. 256 | # The empty list is equivalent to None. 257 | return { 258 | "encoder_out": [x], # T x B x C 259 | "encoder_padding_mask": [encoder_padding_mask], # B x T 260 | "encoder_embedding": [encoder_embedding], # B x T x C 261 | "encoder_states": encoder_states, # List[T x B x C] 262 | "src_tokens": [], 263 | "src_lengths": [], 264 | "load": load, 265 | "balance_loss": balance_loss 266 | } 267 | 268 | 269 | class SwitchTransformerDecoder(TransformerDecoder): 270 | """ 271 | Transformer decoder consisting of *args.decoder_layers* layers. Each layer 272 | is a :class:`SwitchTransformerDecoderLayer`. 
273 | 274 | Args: 275 | args (argparse.Namespace): parsed command-line arguments 276 | dictionary (~fairseq.data.Dictionary): decoding dictionary 277 | embed_tokens (torch.nn.Embedding): output embedding 278 | no_encoder_attn (bool, optional): whether to attend to encoder outputs 279 | (default: False). 280 | """ 281 | 282 | def __init__( 283 | self, 284 | args, 285 | dictionary, 286 | embed_tokens, 287 | no_encoder_attn=False, 288 | output_projection=None, 289 | ): 290 | self.layer_idx = 0 # for building decoder layers 291 | super().__init__(args, dictionary, embed_tokens, no_encoder_attn, output_projection) 292 | assert self.layer_idx == args.decoder_layers 293 | 294 | def build_decoder_layer(self, args, no_encoder_attn=False): 295 | layer = SwitchTransformerDecoderLayer(args, no_encoder_attn, layer_idx=self.layer_idx) 296 | checkpoint = getattr(args, "checkpoint_activations", False) 297 | if checkpoint: 298 | offload_to_cpu = getattr(args, "offload_activations", False) 299 | layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) 300 | # if we are checkpointing, enforce that FSDP always wraps the 301 | # checkpointed layer, regardless of layer size 302 | min_params_to_wrap = ( 303 | getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) 304 | if not checkpoint 305 | else 0 306 | ) 307 | layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) 308 | self.layer_idx += 1 309 | return layer 310 | 311 | def extract_features_scriptable( 312 | self, 313 | prev_output_tokens, 314 | encoder_out: Optional[Dict[str, List[Tensor]]], 315 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 316 | full_context_alignment: bool = False, 317 | alignment_layer: Optional[int] = None, 318 | alignment_heads: Optional[int] = None, 319 | ): 320 | """ 321 | Similar to *forward* but only return features. 322 | 323 | Includes several features from "Jointly Learning to Align and 324 | Translate with Transformer Models" (Garg et al., EMNLP 2019). 325 | 326 | Args: 327 | full_context_alignment (bool, optional): don't apply 328 | auto-regressive mask to self-attention (default: False). 329 | alignment_layer (int, optional): return mean alignment over 330 | heads at this layer (default: last layer). 331 | alignment_heads (int, optional): only average alignment over 332 | this many heads (default: all heads). 
333 | 334 | Returns: 335 | tuple: 336 | - the decoder's features of shape `(batch, tgt_len, embed_dim)` 337 | - a dictionary with any model-specific outputs 338 | """ 339 | bs, slen = prev_output_tokens.size() 340 | if alignment_layer is None: 341 | alignment_layer = self.num_layers - 1 342 | 343 | enc: Optional[Tensor] = None 344 | padding_mask: Optional[Tensor] = None 345 | if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: 346 | enc = encoder_out["encoder_out"][0] 347 | assert ( 348 | enc.size()[1] == bs 349 | ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" 350 | if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: 351 | padding_mask = encoder_out["encoder_padding_mask"][0] 352 | 353 | # embed positions 354 | positions = None 355 | if self.embed_positions is not None: 356 | positions = self.embed_positions( 357 | prev_output_tokens, incremental_state=incremental_state 358 | ) 359 | 360 | if incremental_state is not None: 361 | prev_output_tokens = prev_output_tokens[:, -1:] 362 | if positions is not None: 363 | positions = positions[:, -1:] 364 | 365 | # embed tokens and positions 366 | x = self.embed_scale * self.embed_tokens(prev_output_tokens) 367 | 368 | if self.quant_noise is not None: 369 | x = self.quant_noise(x) 370 | 371 | if self.project_in_dim is not None: 372 | x = self.project_in_dim(x) 373 | 374 | if positions is not None: 375 | x += positions 376 | 377 | if self.layernorm_embedding is not None: 378 | x = self.layernorm_embedding(x) 379 | 380 | x = self.dropout_module(x) 381 | 382 | # B x T x C -> T x B x C 383 | x = x.transpose(0, 1) 384 | 385 | self_attn_padding_mask: Optional[Tensor] = None 386 | if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): 387 | self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) 388 | 389 | # decoder layers 390 | attn: Optional[Tensor] = None 391 | inner_states: List[Optional[Tensor]] = [x] 392 | load = [] 393 | balance_loss = 0 394 | for idx, layer in enumerate(self.layers): 395 | if incremental_state is None and not full_context_alignment: 396 | self_attn_mask = self.buffered_future_mask(x) 397 | else: 398 | self_attn_mask = None 399 | 400 | x, layer_attn, _, layer_load, layer_balance = layer( 401 | x, 402 | enc, 403 | padding_mask, 404 | incremental_state, 405 | self_attn_mask=self_attn_mask, 406 | self_attn_padding_mask=self_attn_padding_mask, 407 | need_attn=bool((idx == alignment_layer)), 408 | need_head_weights=bool((idx == alignment_layer)), 409 | ) 410 | inner_states.append(x) 411 | if layer_load is not None: 412 | load.append(layer_load) 413 | if layer_balance is not None: 414 | balance_loss = balance_loss + layer_balance 415 | if layer_attn is not None and idx == alignment_layer: 416 | attn = layer_attn.float().to(x) 417 | 418 | load = torch.vstack(load) 419 | load = load / load.sum(1, keepdim=True) 420 | 421 | if attn is not None: 422 | if alignment_heads is not None: 423 | attn = attn[:alignment_heads] 424 | 425 | # average probabilities over heads 426 | attn = attn.mean(dim=0) 427 | 428 | if self.layer_norm is not None: 429 | x = self.layer_norm(x) 430 | 431 | # T x B x C -> B x T x C 432 | x = x.transpose(0, 1) 433 | 434 | if self.project_out_dim is not None: 435 | x = self.project_out_dim(x) 436 | 437 | return x, {"attn": [attn], "inner_states": inner_states, "load": load, "balance_loss": balance_loss} 438 | 439 | 440 | @register_model_architecture('switch_transformer', 'switch_transformer') 441 | def base_architecture_switch(args): 442 | 
base_architecture(args) 443 | 444 | 445 | @register_model_architecture("switch_transformer", "switch_transformer_iwslt_de_en") 446 | def switch_transformer_iwslt_de_en(args): 447 | transformer_iwslt_de_en(args) 448 | 449 | 450 | @register_model_architecture('switch_transformer', 'switch_transformer_wmt_en_de_big') 451 | def switch_transformer_wmt_en_de_big(args): 452 | transformer_wmt_en_de_big(args) 453 | 454 | 455 | @register_model_architecture('switch_transformer', 'switch_transformer_vaswani_wmt_en_de_big') 456 | def switch_transformer_vaswani_wmt_en_de_big(args): 457 | transformer_vaswani_wmt_en_de_big(args) 458 | 459 | 460 | @register_model_architecture("switch_transformer", "switch_transformer_vaswani_wmt_en_fr_big") 461 | def switch_transformer_vaswani_wmt_en_fr_big(args): 462 | transformer_vaswani_wmt_en_fr_big(args) 463 | -------------------------------------------------------------------------------- /example/moe/transformer_switch_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch import Tensor 6 | from typing import Dict, List, Optional 7 | 8 | from sparsemixer import get_router 9 | 10 | from fairseq.modules.quant_noise import quant_noise 11 | from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer 12 | 13 | 14 | def use_switch(layer_idx): 15 | return layer_idx % 2 == 0 16 | 17 | 18 | class SwitchTransformerEncoderLayer(TransformerEncoderLayer): 19 | """Encoder layer block. 20 | 21 | In the original paper each operation (multi-head attention or FFN) is 22 | postprocessed with: `dropout -> add residual -> layernorm`. In the 23 | tensor2tensor code they suggest that learning is more robust when 24 | preprocessing each layer with layernorm and postprocessing with: 25 | `dropout -> add residual`. We default to the approach in the paper, but the 26 | tensor2tensor approach can be enabled by setting 27 | *args.encoder_normalize_before* to ``True``. 
28 | 29 | Args: 30 | args (argparse.Namespace): parsed command-line arguments 31 | """ 32 | 33 | def __init__(self, args, layer_idx=-1): 34 | self.num_experts = args.num_experts 35 | self.load_balancing = args.load_balancing 36 | self.use_switch = use_switch(layer_idx) 37 | super().__init__(args) 38 | self.gating_network = nn.Linear(args.encoder_embed_dim, args.num_experts) 39 | if self.use_switch: 40 | self.router = get_router(args.router)(args.num_experts, args.encoder_embed_dim, args.load_balancing, args.jitter_eps) 41 | else: 42 | self.router = None 43 | 44 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): 45 | if self.use_switch: 46 | return nn.ModuleList( 47 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 48 | for _ in range(self.num_experts)] 49 | ) 50 | else: 51 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 52 | 53 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): 54 | if self.use_switch: 55 | return nn.ModuleList( 56 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 57 | for _ in range(self.num_experts)] 58 | ) 59 | else: 60 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 61 | 62 | def forward( 63 | self, 64 | x, 65 | encoder_padding_mask: Optional[Tensor], 66 | attn_mask: Optional[Tensor] = None, 67 | ): 68 | """ 69 | Args: 70 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 71 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape 72 | `(batch, seq_len)` where padding elements are indicated by ``1``. 73 | attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, 74 | where `tgt_len` is the length of output and `src_len` is the 75 | length of input, though here both are equal to `seq_len`. 76 | `attn_mask[tgt_i, src_j] = 1` means that when calculating the 77 | embedding for `tgt_i`, we exclude (mask out) `src_j`. This is 78 | useful for strided self-attention. 
79 | 80 | Returns: 81 | encoded output of shape `(seq_len, batch, embed_dim)` 82 | """ 83 | # anything in original attn_mask = 1, becomes -1e8 84 | # anything in original attn_mask = 0, becomes 0 85 | # Note that we cannot use -inf here, because at some edge cases, 86 | # the attention weight (before softmax) for some padded element in query 87 | # will become -inf, which results in NaN in model parameters 88 | if attn_mask is not None: 89 | attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) 90 | 91 | residual = x 92 | if self.normalize_before: 93 | x = self.self_attn_layer_norm(x) 94 | x, _ = self.self_attn( 95 | query=x, 96 | key=x, 97 | value=x, 98 | key_padding_mask=encoder_padding_mask, 99 | need_weights=False, 100 | attn_mask=attn_mask, 101 | ) 102 | x = self.dropout_module(x) 103 | x = self.residual_connection(x, residual) 104 | if not self.normalize_before: 105 | x = self.self_attn_layer_norm(x) 106 | 107 | residual = x 108 | if self.normalize_before: 109 | x = self.final_layer_norm(x) 110 | 111 | num_tokens = None 112 | balance_loss = 0.0 113 | if self.use_switch: 114 | seq_len, bsz, dim = x.shape 115 | x = x.view(-1, dim) 116 | logits = self.gating_network(x) 117 | sample, multiplier, balance_loss = self.router(logits) 118 | 119 | order = sample.argsort(0).squeeze(-1) 120 | num_tokens = F.one_hot(sample.squeeze(), self.num_experts).gt(0).sum(0) 121 | x = x[order] # reorder according to expert number 122 | x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts 123 | 124 | def forward_fc(input_x, expert_idx): 125 | if input_x.numel() > 0: 126 | input_x = self.activation_fn(self.fc1[expert_idx](input_x)) 127 | input_x = self.activation_dropout_module(input_x) 128 | input_x = self.fc2[expert_idx](input_x) 129 | return input_x 130 | x = torch.vstack( 131 | [forward_fc(x[i], i) for i in range(self.num_experts)] 132 | ) 133 | x = x[order.argsort()] * multiplier 134 | 135 | x = x.view(seq_len, bsz, dim) 136 | else: 137 | x = self.activation_fn(self.fc1(x)) 138 | x = self.activation_dropout_module(x) 139 | x = self.fc2(x) 140 | 141 | x = self.dropout_module(x) 142 | x = self.residual_connection(x, residual) 143 | 144 | if not self.normalize_before: 145 | x = self.final_layer_norm(x) 146 | return x, num_tokens, balance_loss 147 | 148 | 149 | class SwitchTransformerDecoderLayer(TransformerDecoderLayer): 150 | """Decoder layer block. 151 | 152 | In the original paper each operation (multi-head attention, encoder 153 | attention or FFN) is postprocessed with: `dropout -> add residual -> 154 | layernorm`. In the tensor2tensor code they suggest that learning is more 155 | robust when preprocessing each layer with layernorm and postprocessing with: 156 | `dropout -> add residual`. We default to the approach in the paper, but the 157 | tensor2tensor approach can be enabled by setting 158 | *args.decoder_normalize_before* to ``True``. 159 | 160 | Args: 161 | args (argparse.Namespace): parsed command-line arguments 162 | no_encoder_attn (bool, optional): whether to attend to encoder outputs 163 | (default: False). 
164 | """ 165 | 166 | def __init__( 167 | self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, layer_idx=-1, 168 | ): 169 | self.num_experts = args.num_experts 170 | self.load_balancing = args.load_balancing 171 | self.use_switch = use_switch(layer_idx) 172 | super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn) 173 | self.gating_network = nn.Linear(args.decoder_embed_dim, args.num_experts) 174 | if self.use_switch: 175 | self.router = get_router(args.router)(args.num_experts, args.decoder_embed_dim, args.load_balancing, args.jitter_eps) 176 | else: 177 | self.router = None 178 | 179 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): 180 | if self.use_switch: 181 | return nn.ModuleList( 182 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 183 | for _ in range(self.num_experts)] 184 | ) 185 | else: 186 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 187 | 188 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): 189 | if self.use_switch: 190 | return nn.ModuleList( 191 | [quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 192 | for _ in range(self.num_experts)] 193 | ) 194 | else: 195 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 196 | 197 | def forward( 198 | self, 199 | x, 200 | encoder_out: Optional[torch.Tensor] = None, 201 | encoder_padding_mask: Optional[torch.Tensor] = None, 202 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 203 | prev_self_attn_state: Optional[List[torch.Tensor]] = None, 204 | prev_attn_state: Optional[List[torch.Tensor]] = None, 205 | self_attn_mask: Optional[torch.Tensor] = None, 206 | self_attn_padding_mask: Optional[torch.Tensor] = None, 207 | need_attn: bool = False, 208 | need_head_weights: bool = False, 209 | ): 210 | """ 211 | Args: 212 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 213 | encoder_padding_mask (ByteTensor, optional): binary 214 | ByteTensor of shape `(batch, src_len)` where padding 215 | elements are indicated by ``1``. 216 | need_attn (bool, optional): return attention weights 217 | need_head_weights (bool, optional): return attention weights 218 | for each head (default: return average over heads). 
219 | 220 | Returns: 221 | encoded output of shape `(seq_len, batch, embed_dim)` 222 | """ 223 | if need_head_weights: 224 | need_attn = True 225 | 226 | residual = x 227 | if self.normalize_before: 228 | x = self.self_attn_layer_norm(x) 229 | if prev_self_attn_state is not None: 230 | prev_key, prev_value = prev_self_attn_state[:2] 231 | saved_state: Dict[str, Optional[Tensor]] = { 232 | "prev_key": prev_key, 233 | "prev_value": prev_value, 234 | } 235 | if len(prev_self_attn_state) >= 3: 236 | saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] 237 | assert incremental_state is not None 238 | self.self_attn._set_input_buffer(incremental_state, saved_state) 239 | _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) 240 | if self.cross_self_attention and not ( 241 | incremental_state is not None 242 | and _self_attn_input_buffer is not None 243 | and "prev_key" in _self_attn_input_buffer 244 | ): 245 | if self_attn_mask is not None: 246 | assert encoder_out is not None 247 | self_attn_mask = torch.cat( 248 | (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 249 | ) 250 | if self_attn_padding_mask is not None: 251 | if encoder_padding_mask is None: 252 | assert encoder_out is not None 253 | encoder_padding_mask = self_attn_padding_mask.new_zeros( 254 | encoder_out.size(1), encoder_out.size(0) 255 | ) 256 | self_attn_padding_mask = torch.cat( 257 | (encoder_padding_mask, self_attn_padding_mask), dim=1 258 | ) 259 | assert encoder_out is not None 260 | y = torch.cat((encoder_out, x), dim=0) 261 | else: 262 | y = x 263 | 264 | x, attn = self.self_attn( 265 | query=x, 266 | key=y, 267 | value=y, 268 | key_padding_mask=self_attn_padding_mask, 269 | incremental_state=incremental_state, 270 | need_weights=False, 271 | attn_mask=self_attn_mask, 272 | ) 273 | x = self.dropout_module(x) 274 | x = self.residual_connection(x, residual) 275 | if not self.normalize_before: 276 | x = self.self_attn_layer_norm(x) 277 | 278 | if self.encoder_attn is not None and encoder_out is not None: 279 | residual = x 280 | if self.normalize_before: 281 | x = self.encoder_attn_layer_norm(x) 282 | if prev_attn_state is not None: 283 | prev_key, prev_value = prev_attn_state[:2] 284 | saved_state: Dict[str, Optional[Tensor]] = { 285 | "prev_key": prev_key, 286 | "prev_value": prev_value, 287 | } 288 | if len(prev_attn_state) >= 3: 289 | saved_state["prev_key_padding_mask"] = prev_attn_state[2] 290 | assert incremental_state is not None 291 | self.encoder_attn._set_input_buffer(incremental_state, saved_state) 292 | 293 | x, attn = self.encoder_attn( 294 | query=x, 295 | key=encoder_out, 296 | value=encoder_out, 297 | key_padding_mask=encoder_padding_mask, 298 | incremental_state=incremental_state, 299 | static_kv=True, 300 | need_weights=need_attn or (not self.training and self.need_attn), 301 | need_head_weights=need_head_weights, 302 | ) 303 | x = self.dropout_module(x) 304 | x = self.residual_connection(x, residual) 305 | if not self.normalize_before: 306 | x = self.encoder_attn_layer_norm(x) 307 | 308 | residual = x 309 | if self.normalize_before: 310 | x = self.final_layer_norm(x) 311 | 312 | num_tokens = None 313 | balance_loss = 0.0 314 | if self.use_switch: 315 | seq_len, bsz, dim = x.shape 316 | x = x.view(-1, dim) 317 | logits = self.gating_network(x) 318 | sample, multiplier, balance_loss = self.router(logits) 319 | 320 | order = sample.argsort(0).squeeze(-1) 321 | num_tokens = F.one_hot(sample.squeeze(), self.num_experts).gt(0).sum(0) 322 | x = x[order] # 
reorder according to expert number 323 | x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts 324 | 325 | def forward_fc(input_x, expert_idx): 326 | if input_x.numel() > 0: 327 | input_x = self.activation_fn(self.fc1[expert_idx](input_x)) 328 | input_x = self.activation_dropout_module(input_x) 329 | input_x = self.fc2[expert_idx](input_x) 330 | return input_x 331 | x = torch.vstack( 332 | [forward_fc(x[i], i) for i in range(self.num_experts)] 333 | ) 334 | x = x[order.argsort()] * multiplier 335 | 336 | x = x.view(seq_len, bsz, dim) 337 | else: 338 | x = self.activation_fn(self.fc1(x)) 339 | x = self.activation_dropout_module(x) 340 | x = self.fc2(x) 341 | 342 | x = self.dropout_module(x) 343 | x = self.residual_connection(x, residual) 344 | if not self.normalize_before: 345 | x = self.final_layer_norm(x) 346 | if self.onnx_trace and incremental_state is not None: 347 | saved_state = self.self_attn._get_input_buffer(incremental_state) 348 | assert saved_state is not None 349 | if self_attn_padding_mask is not None: 350 | self_attn_state = [ 351 | saved_state["prev_key"], 352 | saved_state["prev_value"], 353 | saved_state["prev_key_padding_mask"], 354 | ] 355 | else: 356 | self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] 357 | return x, attn, self_attn_state 358 | return x, attn, None, num_tokens, balance_loss 359 | -------------------------------------------------------------------------------- /example/moe/translation_switch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import torch 4 | 5 | from argparse import Namespace 6 | from dataclasses import dataclass, field 7 | from omegaconf import II 8 | 9 | from fairseq import metrics, models 10 | from fairseq.data import encoders 11 | from fairseq.dataclass import ChoiceEnum 12 | from fairseq.optim.amp_optimizer import AMPOptimizer 13 | from fairseq.tasks import register_task 14 | from fairseq.tasks.translation import TranslationConfig, TranslationTask 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | @dataclass 21 | class TranslationSwitchConfig(TranslationConfig): 22 | num_experts: int = field( 23 | default=4, 24 | metadata={"help": "number of experts"}, 25 | ) 26 | router: ChoiceEnum(['SparseMixer', 'SwitchGate']) = field( 27 | default='SparseMixer', 28 | metadata={"help": "choice of router"}, 29 | ) 30 | load_balancing: bool = field( 31 | default=False, 32 | metadata={"help": "whether to use load balancing"}, 33 | ) 34 | gumbel: bool = field( 35 | default=False, 36 | metadata={"help": "use gumbel logits for computing balancing loss"}, 37 | ) 38 | jitter_eps: float = field( 39 | default=0.1, 40 | metadata={"help": "jitter eps"}, 41 | ) 42 | load_balancing_alpha: float = field( 43 | default=0.01, 44 | metadata={"help": "weight of load balancing loss"}, 45 | ) 46 | sentence_avg: bool = II("optimization.sentence_avg") 47 | 48 | 49 | @register_task("translation_switch", dataclass=TranslationSwitchConfig) 50 | class TranslationSwitchTask(TranslationTask): 51 | """ 52 | Translation task for Switch Transformer models. 53 | 54 | Args: 55 | src_dict (~fairseq.data.Dictionary): dictionary for the source language 56 | tgt_dict (~fairseq.data.Dictionary): dictionary for the target language 57 | 58 | .. note:: 59 | 60 | The translation task is compatible with :mod:`fairseq-train`, 61 | :mod:`fairseq-generate` and :mod:`fairseq-interactive`. 
62 | 63 | The translation task provides the following additional command-line 64 | arguments: 65 | 66 | .. argparse:: 67 | :ref: fairseq.tasks.translation_parser 68 | :prog: 69 | """ 70 | 71 | cfg: TranslationSwitchConfig 72 | 73 | def __init__(self, cfg: TranslationSwitchConfig, src_dict, tgt_dict): 74 | super().__init__(cfg, src_dict, tgt_dict) 75 | 76 | def build_model(self, cfg): 77 | model = models.build_model(cfg, self) 78 | 79 | if self.cfg.eval_bleu: 80 | detok_args = json.loads(self.cfg.eval_bleu_detok_args) 81 | self.tokenizer = encoders.build_tokenizer( 82 | Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) 83 | ) 84 | 85 | gen_args = json.loads(self.cfg.eval_bleu_args) 86 | self.sequence_generator = self.build_generator( 87 | [model], Namespace(**gen_args) 88 | ) 89 | 90 | return model 91 | 92 | def _get_loss(self, sample, model, criterion): 93 | assert hasattr( 94 | criterion, "compute_loss" 95 | ), "translation_switch task requires the criterion to implement the compute_loss() method" 96 | 97 | encoder_out = model.encoder( 98 | src_tokens=sample["net_input"]["src_tokens"], 99 | src_lengths=sample["net_input"]["src_lengths"], 100 | ) 101 | net_output = model.decoder( 102 | prev_output_tokens=sample["net_input"]["prev_output_tokens"], 103 | encoder_out=encoder_out, 104 | src_lengths=sample["net_input"]["src_lengths"], 105 | ) 106 | loss, nll_loss = criterion.compute_loss(model, net_output, sample, reduce=True) 107 | 108 | balance_loss = None 109 | if self.cfg.load_balancing: 110 | balance_loss = net_output[1]["balance_loss"] 111 | if encoder_out["balance_loss"] is not None: 112 | balance_loss = balance_loss + encoder_out["balance_loss"] 113 | loss = loss + balance_loss * self.cfg.load_balancing_alpha 114 | 115 | if 'load' in net_output[1]: 116 | load = net_output[1]["load"] 117 | if encoder_out["load"] is not None: 118 | load = torch.cat((encoder_out["load"], load), dim=0) 119 | else: 120 | load = torch.Tensor([1.]) 121 | 122 | sample_size = ( 123 | sample["target"].size(0) if criterion.sentence_avg else sample["ntokens"] 124 | ) 125 | logging_output = { 126 | "loss": loss.data, 127 | "nll_loss": nll_loss.data, 128 | "ntokens": sample["ntokens"], 129 | "nsentences": sample["target"].size(0), 130 | "sample_size": sample_size, 131 | "load": (load * 100).long(), 132 | "balance_loss": balance_loss.data if balance_loss is not None else 0.0, 133 | } 134 | return loss, sample_size, logging_output 135 | 136 | def train_step( 137 | self, sample, model, criterion, optimizer, update_num, ignore_grad=False 138 | ): 139 | model.train() 140 | model.set_num_updates(update_num) 141 | with torch.autograd.profiler.record_function("forward"): 142 | with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): 143 | loss, sample_size, logging_output = self._get_loss(sample, model, criterion) 144 | if ignore_grad: 145 | loss *= 0 146 | with torch.autograd.profiler.record_function("backward"): 147 | optimizer.backward(loss) 148 | return loss, sample_size, logging_output 149 | 150 | def reduce_metrics(self, logging_outputs, criterion): 151 | super().reduce_metrics(logging_outputs, criterion) 152 | 153 | metrics.log_scalar( 154 | "load", 155 | sum(log["load"] for log in logging_outputs if "load" in log) / torch.cuda.device_count(), 156 | round=3, 157 | ) 158 | 159 | temp = [log["balance_loss"] for log in logging_outputs if "balance_loss" in log] 160 | metrics.log_scalar("balance_loss", sum(temp), round=3) 161 | 
-------------------------------------------------------------------------------- /example/requirements.txt: -------------------------------------------------------------------------------- 1 | cffi 2 | cython 3 | hydra-core>=1.0.7,<1.1 4 | omegaconf<2.1 5 | numpy>=1.22 6 | regex 7 | sacrebleu>=1.4.12 8 | tqdm 9 | bitarray 10 | tensorboardX -------------------------------------------------------------------------------- /example/run_moe.sh: -------------------------------------------------------------------------------- 1 | EXPERTS=${1:-"4"} 2 | ROUTER=${2:-"SparseMixer"} 3 | DATA_FOLDER=${3:-"/mnt/azstorage/wmt14_en_de_joined_dict/"} 4 | 5 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 6 | 7 | pip install --upgrade numpy scipy 8 | pip install -r requirements.txt 9 | pip install --no-deps fairseq==0.12.1 10 | 11 | ARGS="--num-experts ${EXPERTS} --jitter-eps 0.1 --load-balancing-alpha 0.01 --router ${ROUTER}" 12 | echo $ARGS 13 | 14 | OUTPUT_FOLDER=./output-${EXPERTS} 15 | 16 | mkdir -p $OUTPUT_FOLDER 17 | 18 | touch $OUTPUT_FOLDER/train.log 19 | 20 | cp $OUTPUT_FOLDER/train.log ./train.log 21 | 22 | GPUCT=$(nvidia-smi --list-gpus | wc -l) 23 | UPDATE_FREQ=$((16/${GPUCT})) 24 | echo $UPDATE_FREQ 25 | 26 | unset RANK 27 | fairseq-train ${DATA_FOLDER} --num-workers 8 --ddp-backend=no_c10d \ 28 | --user-dir moe --task translation_switch --load-balancing \ 29 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 30 | --arch switch_transformer --share-all-embeddings \ 31 | --source-lang en --target-lang de \ 32 | --optimizer adam --adam-betas '(0.9,0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \ 33 | --warmup-init-lr 1e-07 --warmup-updates 8000 --max-update 400000 \ 34 | --lr 7e-4 --max-tokens 8192 --update-freq ${UPDATE_FREQ} \ 35 | --weight-decay 0 --dropout 0.1 --activation-dropout 0.1 --attention-dropout 0.1 \ 36 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ 37 | --log-format simple --log-interval 100 \ 38 | --skip-invalid-size-inputs-valid-test \ 39 | --best-checkpoint-metric loss --save-interval 10 \ 40 | --encoder-layers 6 --decoder-layers 6 \ 41 | --save-dir $OUTPUT_FOLDER $ARGS 2>&1 | tee -a ./train.log 42 | 43 | cp ./train.log $OUTPUT_FOLDER/train.log 44 | 45 | sleep 15m 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from setuptools import setup, find_packages 5 | 6 | def read_readme(): 7 | with open('README.md') as f: 8 | return f.read() 9 | 10 | with open('HISTORY.rst') as history_file: 11 | history = history_file.read() 12 | 13 | requirements = [ 14 | 'torch' 15 | ] 16 | 17 | setup( 18 | name='sparsemixer', 19 | version='0.1.0', 20 | description='SparseMixer Algorithm', 21 | long_description= read_readme(), 22 | long_description_content_type="text/markdown", 23 | author='Lucas Liu', 24 | author_email='llychinalz@gmail.com', 25 | url='https://github.com/microsoft/SparseMixer', 26 | packages=find_packages(exclude=['docs']), 27 | include_package_data=True, 28 | install_requires=requirements, 29 | license='MIT', 30 | zip_safe=False, 31 | classifiers=[ 32 | 'Development Status :: 2 - Pre-Alpha', 33 | 'Intended Audience :: Developers', 34 | 'Natural Language :: English', 35 | 'Programming Language :: Python :: 3.7', 36 | 'Programming Language :: Python :: 3.8', 37 | 'Programming Language :: Python :: 3.9', 38 | ] 39 | ) 40 | 41 | # python setup.py sdist bdist_wheel 
42 | # twine upload dist/*
--------------------------------------------------------------------------------
/sparsemixer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | __author__ = "Liyuan Liu"
5 |
6 | __maintainer__ = "Liyuan Liu"
7 | __email__ = "llychinalz@gmail.com"
8 |
9 | from .sparsemixer import SparseMixer
10 | from .switchgate import SwitchGate
11 |
12 | router_map = {
13 |     'sparsemixer': SparseMixer,
14 |     'switchgate': SwitchGate
15 | }
16 |
17 | def get_router(name):
18 |     name = name.lower()
19 |     assert name in router_map, f'Currently only supports SparseMixer and SwitchGate. {name} is not supported!'
20 |
21 |     return router_map[name]
22 |
--------------------------------------------------------------------------------
/sparsemixer/sparsemixer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class SparseMixerCore(torch.autograd.Function):
6 |     """
7 |     `torch.autograd.Function` implementation of the balancing strategy used in SparseMixer.
8 |     """
9 |
10 |     @staticmethod
11 |     def forward(
12 |         ctx,
13 |         multiplier: torch.Tensor,
14 |         firstorder_mask: torch.Tensor,
15 |     ):
16 |         firstorder_mask = torch.add(0.5, firstorder_mask, alpha=0.5).type_as(multiplier)
17 |         return multiplier * firstorder_mask  # maps the {0, 1} mask to {0.5, 1}
18 |
19 |     @staticmethod
20 |     def backward(
21 |         ctx,
22 |         grad_at_multiplier: torch.Tensor,
23 |     ):
24 |         return grad_at_multiplier * 2, None
25 |
26 | class SparseMixer(nn.Module):
27 |     def __init__(self, num_experts, embed_dim, compute_balance_loss=False, jitter_eps=0.1):
28 |         super(SparseMixer, self).__init__()
29 |         self.num_experts = num_experts
30 |         self.compute_balance_loss = compute_balance_loss
31 |         self.jitter_eps = jitter_eps
32 |         self.embed_dim = embed_dim
33 |         self.register_parameter('omega', torch.nn.Parameter(torch.ones(embed_dim)))
34 |
35 |     def forward(self, logits):
36 |
37 |         # mask out experts that would never be sampled under the jitter perturbation
38 |         with torch.no_grad():
39 |             mask_logits_threshold, max_ind = logits.max(dim=-1, keepdim=True)
40 |             factor = logits.abs().clamp(min=mask_logits_threshold)
41 |             mask_logits_threshold = (
42 |                 (mask_logits_threshold - logits) / factor
43 |             ) > (2 * self.jitter_eps)
44 |             logits = logits.masked_fill_(mask_logits_threshold, float('-inf'))
45 |
46 |         p = logits.softmax(dim=-1)
47 |         if self.training:
48 |             sample = (
49 |                 logits - torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
50 |             ).max(dim=-1)[1].unsqueeze(-1)  # Gumbel sampling, more robust than the multinomial method
51 |
52 |             multiplier = p.gather(dim=1, index=sample)
53 |
54 |             mask_for_firstorder = torch.logical_or(
55 |                 sample == max_ind,
56 |                 torch.rand_like(multiplier) > 0.5
57 |             )  # compute the mask for applying the first-order method
58 |             multiplier = SparseMixerCore.apply(multiplier, mask_for_firstorder)  # balance the mid-point and Euler methods
59 |         else:
60 |             sample = max_ind
61 |             multiplier = p.gather(dim=1, index=sample)
62 |
63 |         multiplier = multiplier * self.omega
64 |         balance_loss = 0.0
65 |         if self.compute_balance_loss:
66 |             num_tokens = F.one_hot(sample.squeeze(-1), self.num_experts).gt(0).sum(0)
67 |             f = num_tokens / (num_tokens.sum(0, keepdim=True) + 1e-6)
68 |             pmean = p.view(-1, self.num_experts).mean(0)
69 |             balance_loss = self.num_experts * torch.sum(pmean * f)
70 |
71 |         return sample, multiplier, balance_loss
72 |         # return sample, [multiplier], balance_loss
73 |
--------------------------------------------------------------------------------
/sparsemixer/sparsemixer_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Callable, Dict
5 | uniform_map: Dict[torch.device, Callable] = {}
6 | def multiplicative_jitter(input, epsilon, training):
7 |     from torch.distributions.uniform import Uniform  # local import; only this helper needs it
8 |     if epsilon == 0 or not training:
9 |         return input
10 |
11 |     uniform = uniform_map.get(input.device)
12 |
13 |     if uniform is None:
14 |         uniform = Uniform(low=torch.tensor(1.0 - epsilon, device=input.device, dtype=input.dtype),
15 |                           high=torch.tensor(1.0 + epsilon, device=input.device, dtype=input.dtype)
16 |                           ).rsample
17 |         uniform_map[input.device] = uniform
18 |
19 |     return input * uniform(input.shape)
20 |
21 | class v2core(torch.autograd.Function):
22 |     @staticmethod
23 |     def forward(
24 |         ctx,
25 |         scores: torch.Tensor,
26 |         multiplier: torch.Tensor,
27 |         selected_experts: torch.Tensor,
28 |         masked_gates: torch.Tensor,
29 |         mask_for_one: torch.Tensor,
30 |     ):
31 |         ctx.save_for_backward(multiplier, selected_experts, masked_gates)
32 |         return multiplier * mask_for_one
33 |
34 |     @staticmethod
35 |     def backward(
36 |         ctx,
37 |         grad_at_output: torch.Tensor,
38 |     ):
39 |         multiplier, selected_experts, masked_gates = ctx.saved_tensors
40 |
41 |         grad_at_output = grad_at_output * multiplier
42 |
43 |         grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
44 |         grad_at_scores_expanded.scatter_add_(
45 |             dim=-1,
46 |             index=selected_experts,
47 |             src=grad_at_output,
48 |         )
49 |
50 |         return (
51 |             grad_at_scores_expanded,
52 |             None,
53 |             None,
54 |             None,
55 |             None,
56 |         )
57 |
58 | def sparsemixerv2_routing(scores, top_k, jitter_eps, training):
59 |     assert top_k in [1, 2], "only top-1/2 gating has been tested!"
60 |
61 |     original_gates = torch.softmax(scores, dim=-1)
62 |     ################ first expert ################
63 |
64 |     with torch.no_grad():
65 |         # compute mask for sparsity
66 |         mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
67 |         factor = scores.abs().clamp(min=mask_logits_threshold)
68 |         mask_logits_threshold = (
69 |             (mask_logits_threshold - scores) / factor
70 |         ) > (2 * jitter_eps)
71 |
72 |     # apply mask
73 |     masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf'))
74 |     if training:
75 |         selected_experts = (
76 |             masked_gates - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
77 |         ).max(dim=-1)[1].unsqueeze(-1)  # Gumbel sampling, more robust than the multinomial method
78 |     else:
79 |         selected_experts = max_ind
80 |
81 |     # compute scores for gradients
82 |     masked_gates = torch.softmax(masked_gates, dim=-1)
83 |
84 |     # compute midpoint mask
85 |     max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
86 |     mask_for_one = torch.logical_or(
87 |         selected_experts == max_ind,
88 |         torch.rand_like(max_scores) > 0.75  # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
89 |     )
90 |     # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
91 |     mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
92 |
93 |     multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
94 |     multiplier = v2core.apply(
95 |         scores,
96 |         multiplier_o,
97 |         selected_experts,
98 |         masked_gates,
99 |         mask_for_one,
100 |     )
101 |
102 |     ################ second expert ################
103 |     if top_k > 1:
104 |         # mask out the first expert
105 |         masked_scores = torch.scatter(
106 |             scores,
107 |             -1,
108 |             selected_experts,
109 |             float('-inf'),
110 |         )
111 |         with torch.no_grad():
112 |             # compute mask for sparsity
113 |             mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
114 |             factor = scores.abs().clamp(min=mask_logits_threshold)
115 |             mask_logits_threshold = (
116 |                 (mask_logits_threshold - scores) / factor
117 |             ) > (2 * jitter_eps)
118 |
119 |         # apply mask
120 |         masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float('-inf'))
121 |         if training:
122 |             selected_experts_top2 = (
123 |                 masked_gates_top2 - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format).exponential_().log()
124 |             ).max(dim=-1)[1].unsqueeze(-1)  # Gumbel sampling, more robust than the multinomial method
125 |         else:
126 |             selected_experts_top2 = max_ind
127 |         # compute scores for gradients
128 |         masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
129 |
130 |         # compute midpoint mask
131 |         max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
132 |         mask_for_one_top2 = torch.logical_or(
133 |             selected_experts_top2 == max_ind,
134 |             torch.rand_like(max_scores) > 0.75  # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
135 |         )
136 |         # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
137 |         mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
138 |
139 |         multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
140 |         multiplier_top2 = v2core.apply(
141 |             scores,
142 |             multiplier_top2_o,
143 |             selected_experts_top2,
144 |             masked_gates_top2,
145 |             mask_for_one_top2,
146 |         )
147 |
148 |         multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
149 |         selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
150 |
151 |     return (
152 |         multiplier,
153 |         original_gates,
154 |         selected_experts,
155 |     )
156 |
157 | class SparseMixerV2(nn.Module):
158 |     def __init__(self, num_experts, embed_dim, compute_balance_loss=False, jitter_eps=0.1):
159 |         super(SparseMixerV2, self).__init__()
160 |         self.num_experts = num_experts
161 |         self.compute_balance_loss = compute_balance_loss
162 |         self.jitter_eps = jitter_eps
163 |
164 |     def forward(self, logits):
165 |
166 |         multiplier, original_gates, sample = sparsemixerv2_routing(logits, 1, self.jitter_eps, self.training)
167 |
168 |         balance_loss = 0.0
169 |         if self.compute_balance_loss:
170 |             num_tokens = F.one_hot(sample.squeeze(-1), self.num_experts).gt(0).sum(0)
171 |             f = num_tokens / (num_tokens.sum(0, keepdim=True) + 1e-6)
172 |             pmean = original_gates.view(-1, self.num_experts).mean(0)
173 |             balance_loss = self.num_experts * torch.sum(pmean * f)
174 |
175 |         return sample, multiplier, balance_loss
176 |
--------------------------------------------------------------------------------
/sparsemixer/switchgate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SwitchGate(nn.Module):
7 |     def __init__(self, num_experts, embed_dim, compute_balance_loss=False, jitter_eps=0.1):
8 |         super(SwitchGate, self).__init__()
9 |         self.num_experts = num_experts
10 |         self.compute_balance_loss = compute_balance_loss
11 |         self.jitter_eps = jitter_eps
12 |
13 |     def forward(self, logits):
14 |         if self.training:
15 |             noise = torch.rand_like(logits)
16 |             noise = noise * 2 * self.jitter_eps + 1.0 - self.jitter_eps
17 |             logits = logits * noise
18 |
19 |         p = logits.softmax(dim=-1)
20 |         sample = torch.argmax(p, dim=-1)
21 |
22 |         balance_loss = 0.0
23 |         if self.compute_balance_loss:
24 |             num_tokens = F.one_hot(sample, self.num_experts).gt(0).sum(0)
25 |             f = num_tokens / (num_tokens.sum(0, keepdim=True) + 1e-6)
26 |             pmean = p.view(-1, self.num_experts).mean(0)
27 |             balance_loss = self.num_experts * torch.sum(pmean * f)
28 |
29 |         multiplier = p.gather(dim=-1, index=sample.unsqueeze(1))
30 |         return sample, multiplier, balance_loss
--------------------------------------------------------------------------------
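
Usage sketch (illustrative, not a file in this repository): the snippet below shows one way the routers defined above could be dropped into a token-level top-1 MoE feed-forward block. Only get_router, SparseMixer, and SwitchGate come from the sparsemixer package; the ToyMoELayer wrapper, its gate_proj projection, the expert MLPs, and all tensor shapes are assumptions made for this example, not an API defined by the package.

import torch
import torch.nn as nn

from sparsemixer import get_router


class ToyMoELayer(nn.Module):
    """Illustrative top-1 MoE feed-forward block wired around a sparsemixer router (assumed wiring)."""

    def __init__(self, embed_dim=16, num_experts=4, router='SparseMixer'):
        super().__init__()
        self.num_experts = num_experts
        # hypothetical gate projection; the routers only consume the resulting logits
        self.gate_proj = nn.Linear(embed_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim))
            for _ in range(num_experts)
        )
        # get_router returns a class; both routers take (num_experts, embed_dim, ...)
        self.router = get_router(router)(num_experts, embed_dim, compute_balance_loss=True)

    def forward(self, x):
        # x: (num_tokens, embed_dim)
        logits = self.gate_proj(x)                                  # (num_tokens, num_experts)
        sample, multiplier, balance_loss = self.router(logits)
        expert_ids = sample.view(-1)                                # (num_tokens,)

        out = torch.zeros_like(x)
        for eid in range(self.num_experts):
            token_mask = expert_ids == eid
            if token_mask.any():
                out[token_mask] = self.experts[eid](x[token_mask])  # dispatch each token to its expert

        # scale every token's expert output by the router multiplier (broadcasts over embed_dim)
        return out * multiplier, balance_loss


if __name__ == '__main__':
    torch.manual_seed(0)
    layer = ToyMoELayer(router='SparseMixer')    # or router='SwitchGate'
    tokens = torch.randn(8, 16)
    output, aux_loss = layer(tokens)
    (output.sum() + 0.01 * aux_loss).backward()  # fold the balance loss in with a small weight
    print(output.shape, float(aux_loss))

Note the shape difference between the two routers: SparseMixer returns a multiplier scaled per embedding dimension by its omega parameter, while SwitchGate returns a single scalar per token; multiplying the dispatched output by the multiplier, as above, handles both through broadcasting.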