├── Figs └── framework.jpg ├── LICENSE ├── README.md ├── custom_gate.py ├── custom_layers.py ├── custom_transformer.py ├── custom_transformer2.py ├── custom_utils.py ├── data_utils.py ├── fastermoe ├── __init__.py ├── config.py ├── expert_utils.py ├── schedule.py └── shadow_policy.py ├── functions.py ├── gates ├── __init__.py ├── base_gate.py ├── faster_gate.py ├── gshard_gate.py ├── naive_gate.py ├── noisy_gate.py ├── swipe_gate.py ├── switch_gate.py ├── utils.py └── zero_gate.py ├── linear.py ├── mem_transformer.py ├── mem_transformer_sst2.py ├── new_utils.py ├── script ├── figure5 │ ├── 12layers_smoe_dropout.sh │ └── 8layers_smoe_dropout.sh ├── table1 │ └── transformer_xl │ │ ├── directly_dense_training.sh │ │ └── smoe_dropout.sh └── table2 │ └── sst2 │ ├── dense_model.sh │ └── smoe_dropout.sh ├── train.py ├── train_sst2.py └── utils ├── adaptive_softmax.py ├── data_parallel.py ├── exp_utils.py ├── log_uniform_sampler.py ├── proj_adaptive_softmax.py └── vocabulary.py /Figs/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/Random-MoE-as-Dropout/0272cead5067d40108b4209ba87d512949dd7580/Figs/framework.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse MoE as the New Dropout: Scaling Dense and Self-Slimmable Transformers 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | 5 | Code for this paper [Sparse MoE as the New Dropout: Scaling Dense and Self-Slimmable Transformers](https://openreview.net/forum?id=w1hwFUb_81) 6 | 7 | Tianlong Chen\*, Zhenyu Zhang\*, Ajay Jaiswal, Shiwei Liu, Zhangyang Wang 8 | 9 | Our implementation is based on [fastmoe repo](https://github.com/laekov/fastmoe) and [huggingface repo](https://github.com/huggingface/transformers). More training script and pre-trained models are coming soon. 10 | 11 | 12 | 13 | ## Overview 14 | 15 | Despite their remarkable achievement, gigantic transformers encounter significant drawbacks, including exorbitant computational and memory footprints during training, as well as severe collapse evidenced by a high degree of parameter redundancy. Sparsely-activated Mixture-of-Experts (SMoEs) have shown promise to mitigate the issue of training efficiency, yet they are prone to (1) redundant experts, due to representational collapse; and (2) poor expert scalability for inference and downstream fine-tuning, primarily due to overfitting of the learned routing policy to the number of activated experts during training. As recent research efforts are predominantly focused on improving routing policies to encourage expert specializations, this work focuses on exploring the overlooked scalability bottleneck of SMoEs and leveraging it to effectively scale dense transformers. To this end, we propose a new plug-and-play training framework, SMoE-Dropout, to enable scaling transformers to better accuracy in their full capacity without collapse. Specifically, SMoE-Dropout consists of a randomly initialized and fixed router network to activate experts and gradually increases the activated expert number as training progresses over time. Transformers trained by SMoE-Dropout naturally exhibit a self-slimmable property subject to resource availability, offering smooth and consistent performance boosts with an increase in activated experts during inference or fine-tuning. The framework of our SMoE-Dropout is demonstrated in the following figure. 
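In code, the recipe boils down to two pieces: a router whose scores are random and never trained, and a schedule that grows the number of activated experts as training progresses. The sketch below is illustrative only (the names `RandomRouter` and `top_k_schedule` are ours and do not mirror this repository's classes; see `custom_gate.py` for the actual gates):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class RandomRouter(nn.Module):
    # Frozen random router: expert selection ignores any learned gate weights.
    def __init__(self, num_expert):
        super().__init__()
        self.num_expert = num_expert

    def forward(self, x, top_k):
        # Uniform random scores per token -> top-k expert indices and softmax weights.
        scores = torch.rand(x.size(0), self.num_expert, device=x.device)
        top_val, top_idx = torch.topk(scores, k=top_k, dim=-1)
        return top_idx, F.softmax(top_val, dim=-1)

def top_k_schedule(step, total_steps, k_min=2, k_max=16):
    # Gradually increase the number of activated experts over training.
    return int(k_min + (k_max - k_min) * step / max(1, total_steps))
```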
16 | 17 | ![](Figs/framework.jpg) 18 | 19 | 20 | ## Prerequisite 21 | 22 | - pytorch 23 | - fastmoe: https://github.com/laekov/fastmoe 24 | - transformer: https://github.com/huggingface/transformers 25 | 26 | ## Usage 27 | 28 | ##### Pretraining Transformer-XL on enwik8: 29 | 30 | ``` # Table 1: 31 | bash script/table1/smoe_dropout.sh 32 | bash script/table1/directly_dense_training.sh 33 | ``` 34 | 35 | ##### Transfor pretrained model on SST-2: 36 | 37 | ``` 38 | bash script/table2/sst2/dense_model.sh [pretrained-checkpoint] 39 | bash script/table2/sst2/smoe_dropout.sh [pretrained-checkpoint] 40 | ``` 41 | 42 | ##### Ablation: 43 | 44 | ``` 45 | bash script/figure5/8layers_smoe_dropout.sh 46 | bash script/figure5/12layers_smoe_dropout.sh 47 | ``` 48 | 49 | ## Citation 50 | 51 | ``` 52 | @inproceedings{ 53 | chen2023sparse, 54 | title={Sparse MoE as the New Dropout: Scaling Dense and Self-Slimmable Transformers}, 55 | author={Tianlong Chen and Zhenyu Zhang and AJAY KUMAR JAISWAL and Shiwei Liu and Zhangyang Wang}, 56 | booktitle={The Eleventh International Conference on Learning Representations }, 57 | year={2023}, 58 | url={https://openreview.net/forum?id=w1hwFUb_81} 59 | } 60 | ``` 61 | 62 | -------------------------------------------------------------------------------- /custom_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Custom Gate 3 | """ 4 | from fmoe.gates.base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import pdb 11 | import numpy as np 12 | 13 | __all__ = ['CustomNaiveGate', 'CustomDropGate', 'CustomRandomGate', 'CustomRandomGate_Dense', 14 | 'CustomDTSGate', 'CustomDTSRandomGate', 'CustomDTSGate_softmax', 'CustomDTSRandomGate_softmax', 15 | 'CustomDenseGate', 'CustomHashGate', 'CustomNaiveGate_Balance', 'CustomNaiveGate_Attn'] 16 | 17 | 18 | class CustomNaiveGate(BaseGate): 19 | r""" 20 | Naive Gate 21 | """ 22 | 23 | def __init__(self, d_model, num_expert, world_size, top_k=2): 24 | super().__init__(num_expert, world_size) 25 | self.gate = nn.Linear(d_model, self.tot_expert) 26 | self.top_k = top_k 27 | self.dense_moe_flag = False 28 | 29 | def forward(self, inp, return_all_scores=False): 30 | 31 | gate = self.gate(inp) 32 | 33 | if self.dense_moe_flag: 34 | gate = torch.ones_like(gate) # average the importance of all experts 35 | gate_top_k_val, gate_top_k_idx = torch.topk( 36 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 37 | ) 38 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 39 | else: 40 | gate_top_k_val, gate_top_k_idx = torch.topk( 41 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 42 | ) # [.. 
x top_k] 43 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 44 | # (BxL) x 1 x top_k 45 | 46 | gate_score = F.softmax(gate_top_k_val, dim=-1) 47 | 48 | if return_all_scores: 49 | return gate_top_k_idx, gate_score, gate 50 | return gate_top_k_idx, gate_score 51 | 52 | 53 | class CustomNaiveGate_Attn(BaseGate): 54 | r""" 55 | Naive Gate 56 | """ 57 | 58 | def __init__(self, d_model, num_expert, world_size, top_k=2): 59 | super().__init__(num_expert, world_size) 60 | self.gate = nn.Linear(d_model, self.tot_expert) 61 | self.top_k = top_k 62 | self.dense_moe_flag = False 63 | 64 | def forward(self, inp, return_all_scores=False): 65 | 66 | gate = self.gate(inp) 67 | 68 | if self.dense_moe_flag: 69 | gate = torch.ones_like(gate) # average the importance of all experts 70 | gate_top_k_val, gate_top_k_idx = torch.topk( 71 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 72 | ) 73 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 74 | else: 75 | gate_top_k_val, gate_top_k_idx = torch.topk( 76 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 77 | ) # [.. x top_k] 78 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 79 | # (BxL) x 1 x top_k 80 | 81 | gate_score = F.softmax(gate_top_k_val, dim=-1) 82 | 83 | if return_all_scores: 84 | return gate_top_k_idx, gate_score, gate 85 | return gate_top_k_idx, gate_score 86 | 87 | 88 | class CustomNaiveGate_Balance(BaseGate): 89 | r""" 90 | Naive Gate with Balance loss 91 | """ 92 | 93 | def __init__(self, d_model, num_expert, world_size, top_k=2): 94 | super().__init__(num_expert, world_size) 95 | self.gate = nn.Linear(d_model, self.tot_expert) 96 | self.top_k = top_k 97 | self.dense_moe_flag = False 98 | self.loss = None 99 | 100 | def set_load_balance(self, gate, gate_top_k_idx): 101 | # gate_top_k_idx (tokens_number, top-k) 102 | # gate_top_k_val (tokens_number, top-k) 103 | 104 | score = F.softmax(gate, dim=-1) 105 | valid_idx = gate_top_k_idx[gate_top_k_idx > -1] 106 | fraction_expert = torch.scatter_add( 107 | torch.zeros(self.tot_expert, device=valid_idx.device), 108 | 0, 109 | valid_idx, 110 | torch.ones_like(valid_idx, dtype=torch.float), 111 | ) / valid_idx.numel() 112 | prob_expert = score.sum(dim=0) / valid_idx.numel() 113 | 114 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 115 | self.loss = loss 116 | 117 | def forward(self, inp, return_all_scores=False): 118 | 119 | gate = self.gate(inp) 120 | 121 | if self.dense_moe_flag: 122 | gate = torch.ones_like(gate) # average the importance of all experts 123 | gate_top_k_val, gate_top_k_idx = torch.topk( 124 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 125 | ) 126 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 127 | else: 128 | gate_top_k_val, gate_top_k_idx = torch.topk( 129 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 130 | ) # [.. 
x top_k] 131 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 132 | # (BxL) x 1 x top_k 133 | 134 | gate_score = F.softmax(gate_top_k_val, dim=-1) 135 | 136 | self.set_load_balance(gate, gate_top_k_idx) 137 | 138 | if return_all_scores: 139 | return gate_top_k_idx, gate_score, gate 140 | return gate_top_k_idx, gate_score 141 | 142 | 143 | class CustomHashGate(BaseGate): 144 | 145 | def __init__(self, d_model, num_expert, world_size, top_k=2): 146 | super().__init__(num_expert, world_size) 147 | self.gate = nn.Linear(d_model, self.tot_expert) 148 | self.top_k = top_k 149 | 150 | def forward(self, inp, return_all_scores=False): 151 | 152 | if not hasattr(self, 'hash_gate'): 153 | # generate hash gate 154 | print('Generate Hash Mapping') 155 | token_num = inp.shape[0] 156 | self.register_buffer('hash_gate', torch.rand(token_num, self.tot_expert).to(inp.device)) 157 | print(self.hash_gate.shape) 158 | else: 159 | if self.hash_gate.shape[0] != inp.shape[0]: 160 | if not hasattr(self, 'hash_gate_v2'): 161 | print('Generate New Hash Mapping v2') 162 | token_num = inp.shape[0] 163 | self.register_buffer('hash_gate_v2', torch.rand(token_num, self.tot_expert).to(inp.device)) 164 | print(self.hash_gate_v2.shape) 165 | else: 166 | if self.hash_gate_v2.shape[0] != inp.shape[0]: 167 | if not hasattr(self, 'hash_gate_v3'): 168 | print('Generate New Hash Mapping v3') 169 | token_num = inp.shape[0] 170 | self.register_buffer('hash_gate_v3', torch.rand(token_num, self.tot_expert).to(inp.device)) 171 | print(self.hash_gate_v3.shape) 172 | else: 173 | if self.hash_gate_v3.shape[0] != inp.shape[0]: 174 | if not hasattr(self, 'hash_gate_v4'): 175 | print('Generate New Hash Mapping v4') 176 | token_num = inp.shape[0] 177 | self.register_buffer('hash_gate_v4', torch.rand(token_num, self.tot_expert).to(inp.device)) 178 | print(self.hash_gate_v4.shape) 179 | else: 180 | if self.hash_gate_v4.shape[0] != inp.shape[0]: 181 | print('Generate New Hash Mapping v5') 182 | token_num = inp.shape[0] 183 | self.register_buffer('hash_gate_v5', torch.rand(token_num, self.tot_expert).to(inp.device)) 184 | print(self.hash_gate_v5.shape) 185 | 186 | if inp.shape[0] == self.hash_gate.shape[0]: 187 | gate = self.hash_gate 188 | elif inp.shape[0] == self.hash_gate_v2.shape[0]: 189 | gate = self.hash_gate_v2 190 | elif inp.shape[0] == self.hash_gate_v3.shape[0]: 191 | gate = self.hash_gate_v3 192 | elif inp.shape[0] == self.hash_gate_v4.shape[0]: 193 | gate = self.hash_gate_v4 194 | elif inp.shape[0] == self.hash_gate_v5.shape[0]: 195 | gate = self.hash_gate_v5 196 | else: 197 | assert False 198 | 199 | gate_top_k_val, gate_top_k_idx = torch.topk( 200 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 201 | ) # [.. 
x top_k] 202 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 203 | # (BxL) x 1 x top_k 204 | 205 | gate_top_k_val = torch.ones_like(gate_top_k_val) 206 | gate_score = F.softmax(gate_top_k_val, dim=-1) 207 | 208 | if return_all_scores: 209 | return gate_top_k_idx, gate_score, gate 210 | return gate_top_k_idx, gate_score 211 | 212 | 213 | 214 | 215 | class CustomDropGate(BaseGate): 216 | r""" 217 | Dropout Gate 218 | """ 219 | 220 | def __init__(self, d_model, num_expert, world_size, top_k=2): 221 | super().__init__(num_expert, world_size) 222 | self.gate = nn.Linear(d_model, self.tot_expert) 223 | self.top_k = top_k 224 | self.dense_moe_flag = False 225 | self.dropout = nn.Dropout(p=0.5) 226 | 227 | def forward(self, inp, return_all_scores=False): 228 | 229 | gate = self.gate(inp) 230 | 231 | if self.training: 232 | gate = self.dropout(gate) 233 | 234 | if self.dense_moe_flag: 235 | gate = torch.ones_like(gate) # average the importance of all experts 236 | gate_top_k_val, gate_top_k_idx = torch.topk( 237 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 238 | ) 239 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 240 | else: 241 | gate_top_k_val, gate_top_k_idx = torch.topk( 242 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 243 | ) # [.. x top_k] 244 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 245 | # (BxL) x 1 x top_k 246 | 247 | gate_score = F.softmax(gate_top_k_val, dim=-1) 248 | 249 | if return_all_scores: 250 | return gate_top_k_idx, gate_score, gate 251 | return gate_top_k_idx, gate_score 252 | 253 | class CustomRandomGate(BaseGate): 254 | r""" 255 | Random Assign Gate 256 | """ 257 | 258 | def __init__(self, d_model, num_expert, world_size, top_k=2): 259 | super().__init__(num_expert, world_size) 260 | self.gate = nn.Linear(d_model, self.tot_expert) 261 | self.top_k = top_k 262 | self.dense_moe_flag = False 263 | 264 | def forward(self, inp, return_all_scores=False): 265 | 266 | gate = self.gate(inp) 267 | 268 | # random gate uniform distribution 269 | gate = torch.rand_like(gate) 270 | 271 | if self.dense_moe_flag: 272 | gate = torch.ones_like(gate) 273 | gate_top_k_val, gate_top_k_idx = torch.topk( 274 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 275 | ) 276 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 277 | else: 278 | gate_top_k_val, gate_top_k_idx = torch.topk( 279 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 280 | ) # [.. 
x top_k] 281 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 282 | # (BxL) x 1 x top_k 283 | 284 | gate_score = F.softmax(gate_top_k_val, dim=-1) 285 | 286 | if return_all_scores: 287 | return gate_top_k_idx, gate_score, gate 288 | return gate_top_k_idx, gate_score 289 | 290 | class CustomRandomGate_Dense(BaseGate): 291 | r""" 292 | Random Assign Gate 293 | """ 294 | 295 | def __init__(self, d_model, num_expert, world_size, top_k=2): 296 | super().__init__(num_expert, world_size) 297 | self.gate = nn.Linear(d_model, self.tot_expert) 298 | self.top_k = top_k 299 | self.dense_moe_flag = False 300 | 301 | def forward(self, inp, return_all_scores=False): 302 | 303 | gate = self.gate(inp) 304 | 305 | # random gate uniform distribution 306 | gate = torch.ones_like(gate) 307 | 308 | if self.dense_moe_flag: 309 | gate = torch.ones_like(gate) 310 | gate_top_k_val, gate_top_k_idx = torch.topk( 311 | gate, k=self.tot_expert, dim=-1, largest=True, sorted=False 312 | ) 313 | gate_top_k_val = gate_top_k_val.view(-1, self.tot_expert) 314 | else: 315 | gate_top_k_val, gate_top_k_idx = torch.topk( 316 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 317 | ) # [.. x top_k] 318 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 319 | # (BxL) x 1 x top_k 320 | 321 | gate_score = F.softmax(gate_top_k_val, dim=-1) 322 | 323 | if return_all_scores: 324 | return gate_top_k_idx, gate_score, gate 325 | return gate_top_k_idx, gate_score 326 | 327 | 328 | # Dense to Sparse 329 | class CustomDTSGate(BaseGate): 330 | r""" 331 | Dense to Sparse Gate 332 | """ 333 | 334 | def __init__(self, d_model, num_expert, world_size, top_k=2): 335 | super().__init__(num_expert, world_size) 336 | self.gate = nn.Linear(d_model, self.tot_expert) 337 | self.top_k = top_k 338 | self.dense_moe_flag = False 339 | 340 | self.temperature = 1 341 | self.threshold = 0.001 342 | self.sum_top_k = 0 343 | self.forward_n = 0 344 | self.dynamic_top_k = top_k 345 | 346 | def _sample_gumbel(self, tensor, eps=1e-10): 347 | U = torch.rand_like(tensor).uniform_() 348 | return - torch.log(eps - torch.log(U + eps)) 349 | 350 | def forward(self, inp, return_all_scores=False): 351 | 352 | gate = self.gate(inp) 353 | 354 | if self.training: 355 | # dts 356 | gumber_noise = self._sample_gumbel(gate) 357 | gate_noise = (gate + gumber_noise) / self.temperature 358 | gate_noise = F.softmax(gate_noise, dim=-1) 359 | 360 | # calculate top-k number 361 | enable_gate_number = gate_noise.gt(self.threshold).sum(dim=-1) 362 | dynamic_top_k = enable_gate_number.float().mean().int().item() 363 | self.dynamic_top_k = max(self.top_k, dynamic_top_k) 364 | 365 | self.forward_n += 1 366 | self.sum_top_k += self.dynamic_top_k 367 | 368 | gate_top_k_val, gate_top_k_idx = torch.topk( 369 | gate_noise, k=self.dynamic_top_k, dim=-1, largest=True, sorted=False 370 | ) # [.. x top_k] 371 | gate_score = gate_top_k_val.view(-1, self.dynamic_top_k) 372 | 373 | else: 374 | self.dynamic_top_k = self.top_k 375 | gate_top_k_val, gate_top_k_idx = torch.topk( 376 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 377 | ) # [.. 
x top_k] 378 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 379 | gate_score = F.softmax(gate_top_k_val, dim=-1) 380 | 381 | if return_all_scores: 382 | return gate_top_k_idx, gate_score, gate 383 | return gate_top_k_idx, gate_score 384 | 385 | class CustomDTSRandomGate(BaseGate): 386 | r""" 387 | Dense to Sparse Gate Random Assign 388 | """ 389 | 390 | def __init__(self, d_model, num_expert, world_size, top_k=2): 391 | super().__init__(num_expert, world_size) 392 | self.gate = nn.Linear(d_model, self.tot_expert) 393 | self.top_k = top_k 394 | self.dense_moe_flag = False 395 | 396 | self.temperature = 1 397 | self.threshold = 0.001 398 | self.sum_top_k = 0 399 | self.forward_n = 0 400 | self.dynamic_top_k = top_k 401 | 402 | def _sample_gumbel(self, tensor, eps=1e-10): 403 | U = torch.rand_like(tensor).uniform_() 404 | return - torch.log(eps - torch.log(U + eps)) 405 | 406 | def forward(self, inp, return_all_scores=False): 407 | 408 | gate = self.gate(inp) 409 | gate = torch.rand_like(gate) 410 | 411 | if self.training: 412 | # dts 413 | gumber_noise = self._sample_gumbel(gate) 414 | gate_noise = (gate + gumber_noise) / self.temperature 415 | gate_noise = F.softmax(gate_noise, dim=-1) 416 | 417 | # calculate top-k number 418 | enable_gate_number = gate_noise.gt(self.threshold).sum(dim=-1) 419 | dynamic_top_k = enable_gate_number.float().mean().int().item() 420 | self.dynamic_top_k = max(self.top_k, dynamic_top_k) 421 | 422 | self.forward_n += 1 423 | self.sum_top_k += self.dynamic_top_k 424 | 425 | gate_top_k_val, gate_top_k_idx = torch.topk( 426 | gate_noise, k=self.dynamic_top_k, dim=-1, largest=True, sorted=False 427 | ) # [.. x top_k] 428 | gate_score = gate_top_k_val.view(-1, self.dynamic_top_k) 429 | 430 | else: 431 | self.dynamic_top_k = self.top_k 432 | gate_top_k_val, gate_top_k_idx = torch.topk( 433 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 434 | ) # [.. x top_k] 435 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 436 | gate_score = F.softmax(gate_top_k_val, dim=-1) 437 | 438 | if return_all_scores: 439 | return gate_top_k_idx, gate_score, gate 440 | return gate_top_k_idx, gate_score 441 | 442 | class CustomDTSGate_softmax(BaseGate): 443 | r""" 444 | Dense to Sparse Gate 445 | """ 446 | 447 | def __init__(self, d_model, num_expert, world_size, top_k=2): 448 | super().__init__(num_expert, world_size) 449 | self.gate = nn.Linear(d_model, self.tot_expert) 450 | self.top_k = top_k 451 | self.dense_moe_flag = False 452 | 453 | self.temperature = 1 454 | self.threshold = 0.001 455 | self.sum_top_k = 0 456 | self.forward_n = 0 457 | self.dynamic_top_k = top_k 458 | 459 | def _sample_gumbel(self, tensor, eps=1e-10): 460 | U = torch.rand_like(tensor).uniform_() 461 | return - torch.log(eps - torch.log(U + eps)) 462 | 463 | def forward(self, inp, return_all_scores=False): 464 | 465 | gate = self.gate(inp) 466 | 467 | if self.training: 468 | # dts 469 | gumber_noise = self._sample_gumbel(gate) 470 | gate_noise = (gate + gumber_noise) / self.temperature 471 | gate_noise = F.softmax(gate_noise, dim=-1) 472 | 473 | # calculate top-k number 474 | enable_gate_number = gate_noise.gt(self.threshold).sum(dim=-1) 475 | dynamic_top_k = enable_gate_number.float().mean().int().item() 476 | self.dynamic_top_k = max(self.top_k, dynamic_top_k) 477 | 478 | self.forward_n += 1 479 | self.sum_top_k += self.dynamic_top_k 480 | 481 | gate_top_k_val, gate_top_k_idx = torch.topk( 482 | gate_noise, k=self.dynamic_top_k, dim=-1, largest=True, sorted=False 483 | ) # [.. 
x top_k] 484 | gate_score = gate_top_k_val.view(-1, self.dynamic_top_k) 485 | 486 | else: 487 | gate = F.softmax(gate, dim=-1) 488 | self.dynamic_top_k = self.top_k 489 | gate_top_k_val, gate_top_k_idx = torch.topk( 490 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 491 | ) # [.. x top_k] 492 | gate_score = gate_top_k_val.view(-1, self.top_k) 493 | 494 | if return_all_scores: 495 | return gate_top_k_idx, gate_score, gate 496 | return gate_top_k_idx, gate_score 497 | 498 | class CustomDTSRandomGate_softmax(BaseGate): 499 | r""" 500 | Dense to Sparse Gate Random Assign 501 | """ 502 | 503 | def __init__(self, d_model, num_expert, world_size, top_k=2): 504 | super().__init__(num_expert, world_size) 505 | self.gate = nn.Linear(d_model, self.tot_expert) 506 | self.top_k = top_k 507 | self.dense_moe_flag = False 508 | 509 | self.temperature = 1 510 | self.threshold = 0.001 511 | self.sum_top_k = 0 512 | self.forward_n = 0 513 | self.dynamic_top_k = top_k 514 | 515 | def _sample_gumbel(self, tensor, eps=1e-10): 516 | U = torch.rand_like(tensor).uniform_() 517 | return - torch.log(eps - torch.log(U + eps)) 518 | 519 | def forward(self, inp, return_all_scores=False): 520 | 521 | gate = self.gate(inp) 522 | gate = torch.rand_like(gate) 523 | 524 | if self.training: 525 | # dts 526 | gumber_noise = self._sample_gumbel(gate) 527 | gate_noise = (gate + gumber_noise) / self.temperature 528 | gate_noise = F.softmax(gate_noise, dim=-1) 529 | 530 | # calculate top-k number 531 | enable_gate_number = gate_noise.gt(self.threshold).sum(dim=-1) 532 | dynamic_top_k = enable_gate_number.float().mean().int().item() 533 | self.dynamic_top_k = max(self.top_k, dynamic_top_k) 534 | 535 | self.forward_n += 1 536 | self.sum_top_k += self.dynamic_top_k 537 | 538 | gate_top_k_val, gate_top_k_idx = torch.topk( 539 | gate_noise, k=self.dynamic_top_k, dim=-1, largest=True, sorted=False 540 | ) # [.. x top_k] 541 | gate_score = gate_top_k_val.view(-1, self.dynamic_top_k) 542 | 543 | else: 544 | gate = F.softmax(gate, dim=-1) 545 | self.dynamic_top_k = self.top_k 546 | gate_top_k_val, gate_top_k_idx = torch.topk( 547 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 548 | ) # [.. 
x top_k] 549 | gate_score = gate_top_k_val.view(-1, self.top_k) 550 | 551 | if return_all_scores: 552 | return gate_top_k_idx, gate_score, gate 553 | return gate_top_k_idx, gate_score 554 | 555 | class CustomDenseGate(BaseGate): 556 | r""" 557 | Dense Gate 558 | """ 559 | 560 | def __init__(self, d_model, num_expert, world_size, top_k=2): 561 | super().__init__(num_expert, world_size) 562 | self.gate = nn.Linear(d_model, self.tot_expert) 563 | self.top_k = top_k 564 | self.dense_moe_flag = False 565 | 566 | def forward(self, inp, return_all_scores=False): 567 | 568 | gate = self.gate(inp) 569 | repeat_shape = list(gate.shape[:-1]) 570 | repeat_shape.append(1) 571 | 572 | gate_top_k_idx = torch.arange(self.tot_expert).repeat(repeat_shape).to(gate.device) 573 | 574 | gate_top_k_val = gate.view(-1, self.tot_expert) 575 | gate_score = F.softmax(gate_top_k_val, dim=-1) 576 | 577 | if return_all_scores: 578 | return gate_top_k_idx, gate_score, gate 579 | return gate_top_k_idx, gate_score 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | -------------------------------------------------------------------------------- /custom_layers.py: -------------------------------------------------------------------------------- 1 | r""" 2 | FMoE core layer 3 | """ 4 | import tree 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | 9 | from functions import prepare_forward, ensure_comm 10 | from functions import MOEScatter, MOEGather 11 | from functions import AllGather, Slice 12 | from gates import NaiveGate 13 | 14 | from fastermoe.config import switch_from_env 15 | 16 | 17 | def mark_module_parallel_comm(module, comm): 18 | r""" 19 | Mark all parameters in `module` as doing data parallel in `comm`, where 20 | `comm` may be one of `'world', 'dp', 'none'`. 21 | """ 22 | for p in module.parameters(): 23 | setattr(p, "dp_comm", comm) 24 | 25 | 26 | def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, **kwargs): 27 | r""" 28 | A private function that performs the following steps to complete the MoE 29 | computation. 30 | * Count the number of tokens from each worker to each expert. 31 | * Send the features to their target position so that input features to each 32 | expert are contiguous in memory. 33 | * Perform the forward computation of the experts using `expert_fn` 34 | * Gather the output features of experts back, and reorder them as sentences. 35 | Intermediate results like expert counts are hidden from users by this 36 | function. 
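    Roughly speaking, if a worker holds `n` tokens and each token is routed to
    `top_k` experts, `pos` indexes the `n * top_k` (token, expert) assignments,
    and `expert_fn` receives `fwd_batch_size` rows grouped contiguously by expert.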
37 | """ 38 | ( 39 | pos, 40 | local_expert_count, 41 | global_expert_count, 42 | fwd_expert_count, 43 | fwd_batch_size, 44 | ) = prepare_forward(gate, num_expert, world_size) 45 | topk = 1 46 | if len(gate.shape) == 2: 47 | topk = gate.shape[1] 48 | 49 | def scatter_func(tensor): 50 | return MOEScatter.apply( 51 | tensor, 52 | torch.div(pos, topk, rounding_mode='floor'), 53 | local_expert_count, 54 | global_expert_count, 55 | fwd_batch_size, 56 | world_size, 57 | ) 58 | 59 | x = tree.map_structure(scatter_func, inp) 60 | 61 | x = expert_fn(x, fwd_expert_count) 62 | 63 | out_batch_size = tree.flatten(inp)[0].shape[0] 64 | if len(gate.shape) == 2: 65 | out_batch_size *= gate.shape[1] 66 | 67 | def gather_func(tensor): 68 | return MOEGather.apply( 69 | tensor, 70 | pos, 71 | local_expert_count, 72 | global_expert_count, 73 | out_batch_size, 74 | world_size, 75 | ) 76 | 77 | outp = tree.map_structure(gather_func, x) 78 | return outp 79 | 80 | 81 | fmoe_faster_schedule = False 82 | if switch_from_env('FMOE_FASTER_SCHEDULE_ENABLE', False): 83 | fmoe_faster_schedule = True 84 | from .fastermoe.schedule import _fmoe_general_global_forward 85 | 86 | 87 | class FMoE(nn.Module): 88 | r""" 89 | A general moe implementation that supports an arbitrary module as the 90 | expert. 91 | * `num_expert` stands for the number of experts on **each** worker. 92 | * `world_size` stands for the total number of workers that contains 93 | different experts. 94 | * `slice_group` can be a torch's communication group, indicating that 95 | specific model parallel is applied across the group, and workers in the 96 | group hold the same copy of input feature, and requires the same copy of 97 | the output. For each worker, FMoE only computes the output of a certain 98 | slice of the input batch, and will all-gather the outputs after 99 | computation. 100 | * `top_k` stands for the number of experts each token is going to. 101 | * `gate` is a gate class which can found in `fmoe.gates`. 102 | * `expert` can be specified as a module class, it is used to generate 103 | `num_expert` expert modules. 
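    A minimal construction sketch (the expert class and argument values below are
    illustrative only and are not part of this repository):

        class MyExpert(nn.Module):
            def __init__(self, d_model):
                super().__init__()
                self.fc = nn.Linear(d_model, d_model)
            def forward(self, x):
                return self.fc(x)

        moe = FMoE(num_expert=4, d_model=1024, top_k=2, expert=MyExpert)
        out = moe(torch.randn(8, 1024))  # output keeps shape (8, 1024)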
104 | """ 105 | 106 | def __init__( 107 | self, 108 | num_expert=32, 109 | d_model=1024, 110 | world_size=1, 111 | mp_group=None, # being deprecated 112 | slice_group=None, 113 | moe_group=None, 114 | top_k=2, 115 | gate=NaiveGate, 116 | expert=None, 117 | gate_hook=None, 118 | mask=None, 119 | mask_dict=None, 120 | ): 121 | super().__init__() 122 | self.num_expert = num_expert 123 | self.d_model = d_model 124 | self.world_size = world_size 125 | 126 | self.slice_group = slice_group 127 | if mp_group is not None: 128 | print("[Warning] mp_group is being deprecated") 129 | self.slice_group = mp_group 130 | if self.slice_group is None: 131 | self.slice_size = 1 132 | self.slice_rank = 0 133 | else: 134 | self.slice_size = self.slice_group.size() 135 | self.slice_rank = self.slice_group.rank() 136 | 137 | self.top_k = top_k 138 | if type(expert) is list: 139 | self.experts = nn.ModuleList([e(d_model) for e in expert]) 140 | self.experts_fused = False 141 | self.num_expert = num_expert = len(expert) 142 | elif expert is not None: 143 | self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)]) 144 | self.experts_fused = False 145 | else: 146 | self.experts_fused = True 147 | 148 | self.gate = gate(d_model, num_expert, world_size, top_k) 149 | self.gate_hook = gate_hook 150 | self.mask = mask 151 | self.mask_dict = mask_dict 152 | self.moe_group = moe_group 153 | 154 | def expert_fn(self, inp, fwd_expert_count): 155 | r""" 156 | The default expert function which either calls the experts as a whole 157 | or as separate experts. 158 | """ 159 | if self.experts_fused: 160 | return self.experts(inp, fwd_expert_count) 161 | if isinstance(fwd_expert_count, torch.Tensor): 162 | fwd_expert_count = fwd_expert_count.cpu().numpy() 163 | outputs = [] 164 | base_idx = 0 165 | for i in range(self.num_expert): 166 | batch_size = fwd_expert_count[i] 167 | inp_slice = inp[base_idx : base_idx + batch_size] 168 | outputs.append(self.experts[i](inp_slice)) 169 | base_idx += batch_size 170 | return torch.cat(outputs, dim=0) 171 | 172 | def mark_parallel_comm(self, expert_dp_comm="none"): 173 | r""" 174 | Automatically mark the data parallel comms of the parameters within the 175 | module. This can be typically called at the end of the __init__ function 176 | in child classes. 177 | """ 178 | if self.experts is not None: 179 | comm = expert_dp_comm 180 | if isinstance(self.experts, list): 181 | for e in self.experts: 182 | mark_module_parallel_comm(e, comm) 183 | else: 184 | mark_module_parallel_comm(self.experts, comm) 185 | mark_module_parallel_comm(self.gate, "gate") 186 | 187 | def forward(self, moe_inp): 188 | 189 | r""" 190 | The FMoE module first computes gate output, and then conduct MoE forward 191 | according to the gate. The score of the selected gate given by the 192 | expert is multiplied to the experts' output tensors as a weight. 
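        Concretely, for `N` routed tokens the expert outputs are viewed as
        `(N, top_k, dim)` and combined with the `(N, 1, top_k)` gate scores via a
        batched matrix product, yielding an `(N, dim)` result, where `dim` is the
        experts' output feature size.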
193 | """ 194 | 195 | moe_inp_batch_size = tree.flatten( 196 | tree.map_structure(lambda tensor: tensor.shape[0], moe_inp) 197 | ) 198 | assert all( 199 | [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size] 200 | ), "MoE inputs must have the same batch size" 201 | 202 | if self.world_size > 1: 203 | 204 | def ensure_comm_func(tensor): 205 | ensure_comm(tensor, self.moe_group) 206 | 207 | tree.map_structure(ensure_comm_func, moe_inp) 208 | if self.slice_size > 1: 209 | 210 | def slice_func(tensor): 211 | return Slice.apply( 212 | tensor, self.slice_rank, self.slice_size, self.slice_group 213 | ) 214 | 215 | moe_inp = tree.map_structure(slice_func, moe_inp) 216 | 217 | gate_top_k_idx, gate_score = self.gate(moe_inp) 218 | 219 | if hasattr(self.gate, 'dynamic_top_k'): 220 | self.top_k = self.gate.dynamic_top_k 221 | 222 | if self.gate_hook is not None: 223 | self.gate_hook(gate_top_k_idx, gate_score, None) 224 | 225 | # delete masked tensors 226 | if self.mask is not None and self.mask_dict is not None: 227 | # TODO: to fix 228 | def delete_mask_func(tensor): 229 | # to: (BxL') x d_model 230 | tensor = tensor[mask == 0, :] 231 | return tensor 232 | 233 | mask = self.mask.view(-1) 234 | moe_inp = tree.map_structure(delete_mask_func, moe_inp) 235 | gate_top_k_idx = gate_top_k_idx[mask == 0, :] 236 | 237 | fwd = _fmoe_general_global_forward( 238 | moe_inp, gate_top_k_idx, self.expert_fn, 239 | self.num_expert, self.world_size, 240 | experts=self.experts 241 | ) 242 | 243 | # recover deleted tensors 244 | if self.mask is not None and self.mask_dict is not None: 245 | 246 | def recover_func(tensor): 247 | # to: (BxL') x top_k x dim 248 | dim = tensor.shape[-1] 249 | tensor = tensor.view(-1, self.top_k, dim) 250 | # to: (BxL) x top_k x d_model 251 | x = torch.zeros( 252 | mask.shape[0], 253 | self.top_k, 254 | dim, 255 | device=tensor.device, 256 | dtype=tensor.dtype, 257 | ) 258 | # recover 259 | x[mask == 0] = tensor 260 | for k, v in self.mask_dict.items(): 261 | x[mask == k] = v 262 | return x 263 | 264 | moe_outp = tree.map_structure(recover_func, fwd) 265 | else: 266 | 267 | def view_func(tensor): 268 | dim = tensor.shape[-1] 269 | tensor = tensor.view(-1, self.top_k, dim) 270 | return tensor 271 | 272 | moe_outp = tree.map_structure(view_func, fwd) 273 | 274 | gate_score = gate_score.view(-1, 1, self.top_k) 275 | 276 | def bmm_func(tensor): 277 | dim = tensor.shape[-1] 278 | tensor = torch.bmm(gate_score, tensor).reshape(-1, dim) 279 | return tensor 280 | 281 | moe_outp = tree.map_structure(bmm_func, moe_outp) 282 | 283 | if self.slice_size > 1: 284 | 285 | def all_gather_func(tensor): 286 | return AllGather.apply( 287 | tensor, self.slice_rank, self.slice_size, self.slice_group 288 | ) 289 | 290 | moe_outp = tree.map_structure(all_gather_func, moe_outp) 291 | 292 | moe_outp_batch_size = tree.flatten( 293 | tree.map_structure(lambda tensor: tensor.shape[0], moe_outp) 294 | ) 295 | assert all( 296 | [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size] 297 | ), "MoE outputs must have the same batch size" 298 | return moe_outp 299 | -------------------------------------------------------------------------------- /custom_transformer.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Adaption to act as the MLP layer using an MoE MLP layer in transformer. 
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from custom_layers import FMoE 7 | from linear import FMoELinear 8 | 9 | 10 | class _Expert(nn.Module): 11 | r""" 12 | An expert using 2 FMoELinear modules to speed up the computation of experts 13 | within one worker. 14 | """ 15 | 16 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): 17 | super().__init__() 18 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank) 19 | self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank) 20 | self.activation = activation 21 | 22 | def forward(self, inp, fwd_expert_count): 23 | r""" 24 | First expand input to 4h (the hidden size is variable, but is called h4 25 | for convenience). Then perform activation. Finally shirink back to h. 26 | """ 27 | x = self.htoh4(inp, fwd_expert_count) 28 | x = self.activation(x) 29 | x = self.h4toh(x, fwd_expert_count) 30 | return x 31 | 32 | 33 | class FMoETransformerMLP(FMoE): 34 | r""" 35 | A complete MoE MLP module in a Transformer block. 36 | * `activation` is the activation function to be used in MLP in each expert. 37 | * `d_hidden` is the dimension of the MLP layer. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | num_expert=32, 43 | d_model=1024, 44 | d_hidden=4096, 45 | activation=torch.nn.GELU(), 46 | expert_dp_comm="none", 47 | expert_rank=0, 48 | **kwargs 49 | ): 50 | super().__init__(num_expert=num_expert, d_model=d_model, **kwargs) 51 | self.experts = _Expert( 52 | num_expert, d_model, d_hidden, activation, rank=expert_rank 53 | ) 54 | self.mark_parallel_comm(expert_dp_comm) 55 | 56 | def forward(self, inp: torch.Tensor): 57 | r""" 58 | This module wraps up the FMoE module with reshape, residual and layer 59 | normalization. 60 | """ 61 | original_shape = inp.shape 62 | inp = inp.reshape(-1, self.d_model) 63 | output = super().forward(inp) 64 | return output.reshape(original_shape) 65 | -------------------------------------------------------------------------------- /custom_transformer2.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Adaption to act as the MLP layer using an MoE MLP layer in transformer. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from custom_layers import FMoE 7 | from linear import FMoELinear 8 | 9 | 10 | class _Expert(nn.Module): 11 | r""" 12 | An expert using 2 FMoELinear modules to speed up the computation of experts 13 | within one worker. 14 | """ 15 | 16 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): 17 | super().__init__() 18 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank) 19 | 20 | def forward(self, inp, fwd_expert_count): 21 | r""" 22 | First expand input to 4h (the hidden size is variable, but is called h4 23 | for convenience). Then perform activation. Finally shirink back to h. 24 | """ 25 | x = self.htoh4(inp, fwd_expert_count) 26 | return x 27 | 28 | 29 | class FMoETransformerMLP(FMoE): 30 | r""" 31 | A complete MoE MLP module in a Transformer block. 32 | * `activation` is the activation function to be used in MLP in each expert. 33 | * `d_hidden` is the dimension of the MLP layer. 
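    Note that, unlike the sibling module in `custom_transformer.py`, the `_Expert`
    used here applies only the expanding `htoh4` projection (no activation and no
    `h4toh`), so the output feature dimension of this MLP is `d_hidden` rather
    than `d_model`.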
34 | """ 35 | 36 | def __init__( 37 | self, 38 | num_expert=32, 39 | d_model=1024, 40 | d_hidden=4096, 41 | activation=torch.nn.GELU(), 42 | expert_dp_comm="none", 43 | expert_rank=0, 44 | **kwargs 45 | ): 46 | super().__init__(num_expert=num_expert, d_model=d_model, **kwargs) 47 | self.experts = _Expert( 48 | num_expert, d_model, d_hidden, activation, rank=expert_rank 49 | ) 50 | self.mark_parallel_comm(expert_dp_comm) 51 | 52 | def forward(self, inp: torch.Tensor): 53 | r""" 54 | This module wraps up the FMoE module with reshape, residual and layer 55 | normalization. 56 | """ 57 | original_shape = inp.shape 58 | inp = inp.reshape(-1, self.d_model) 59 | output = super().forward(inp) 60 | return output.reshape(original_shape[0], original_shape[1], -1) 61 | -------------------------------------------------------------------------------- /custom_utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Utils to play with PyTorch. 3 | """ 4 | import torch.distributed as dist 5 | 6 | 7 | # pylint: disable=broad-except 8 | # pylint: disable=protected-access 9 | def get_torch_default_comm(): 10 | r""" 11 | The NCCL communicator is needed so that Fast MoE can perform customized 12 | communication operators in the C code. However, it is not a publicly 13 | available variable. Therefore, a hacking class of the `ProcessGroupNCCL` 14 | in Fast MoE's C code takes the `_default_pg` and tries to dig the 15 | communicator out from the object. As PyTorch's private interface varies from 16 | time to time, different hacking techniques are tried one-by-one to be 17 | compatible with various versions of PyTorch. 18 | """ 19 | try: 20 | comm = dist.distributed_c10d._get_default_group() 21 | return comm 22 | except Exception as _: 23 | pass 24 | try: 25 | comm = dist.distributed_c10d._default_pg 26 | if comm is not None: 27 | return comm 28 | except Exception as _: 29 | pass 30 | raise RuntimeError("Unsupported PyTorch version") 31 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import glob 3 | import pdb 4 | from collections import Counter, OrderedDict 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from utils.vocabulary import Vocab 9 | import torch.nn.functional as F 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | 13 | def pad_sequence_reverse(data): 14 | # data should be a list of 1D tensors 15 | 16 | assert data[0].dim() == 1 17 | device = data[0].device 18 | length_list = [] 19 | for item in data: 20 | length_list.append(item.shape[0]) 21 | max_length = max(length_list) 22 | 23 | # padding 24 | padded_data_list = [] 25 | for item in data: 26 | padded_item = torch.cat([torch.zeros(max_length - item.shape[0], dtype=item.dtype).to(device), item]).reshape(-1, 1) 27 | padded_data_list.append(padded_item) 28 | padded_data_list = torch.cat(padded_data_list, dim=1) 29 | return padded_data_list 30 | 31 | 32 | class LMOrderedIterator(object): 33 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): 34 | """ 35 | data -- LongTensor -- the LongTensor is strictly ordered 36 | """ 37 | self.bsz = bsz 38 | self.bptt = bptt 39 | self.ext_len = ext_len if ext_len is not None else 0 40 | 41 | self.device = device 42 | 43 | # Work out how cleanly we can divide the dataset into bsz parts. 
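        # Illustrative arithmetic (not from the original source): a corpus of
        # 1,000,003 tokens with bsz=50 gives n_step = 20000, and the trailing 3
        # tokens are trimmed before the view(bsz, -1) reshape below.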
44 | self.n_step = data.size(0) // bsz 45 | 46 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 47 | data = data.narrow(0, 0, self.n_step * bsz) 48 | 49 | # Evenly divide the data across the bsz batches. 50 | self.data = data.view(bsz, -1).t().contiguous().to(device) 51 | 52 | # Number of mini-batches 53 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 54 | 55 | def get_batch(self, i, bptt=None): 56 | if bptt is None: bptt = self.bptt 57 | seq_len = min(bptt, self.data.size(0) - 1 - i) 58 | 59 | end_idx = i + seq_len 60 | beg_idx = max(0, i - self.ext_len) 61 | 62 | data = self.data[beg_idx:end_idx] 63 | target = self.data[i+1:i+1+seq_len] 64 | 65 | return data, target, seq_len 66 | 67 | def get_fixlen_iter(self, start=0): 68 | for i in range(start, self.data.size(0) - 1, self.bptt): 69 | yield self.get_batch(i) 70 | 71 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 72 | max_len = self.bptt + max_deviation * std 73 | i = start 74 | while True: 75 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 76 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 77 | data, target, seq_len = self.get_batch(i, bptt) 78 | i += seq_len 79 | yield data, target, seq_len 80 | if i >= self.data.size(0) - 2: 81 | break 82 | 83 | def __iter__(self): 84 | return self.get_fixlen_iter() 85 | 86 | 87 | 88 | class SST2Iterator(object): 89 | def __init__(self, data, bsz): 90 | """ 91 | data: [encoded, labels] 92 | """ 93 | 94 | self.bsz = bsz 95 | 96 | self.encoded = data[0] 97 | self.labels = data[1] # Tensor 98 | 99 | self.n_step = self.labels.size(0) // bsz 100 | self.n_samples = self.labels.size(0) 101 | self.sequence_array = np.arange(self.n_samples) 102 | 103 | def get_batch(self, index_list): 104 | 105 | subencoded = [] 106 | mask_idx_pre = [] 107 | sublabels = [] 108 | 109 | for idx in index_list: 110 | subencoded.append(self.encoded[idx]) 111 | sublabels.append(self.labels[idx]) 112 | mask_idx_pre.append(torch.ones(self.encoded[idx].shape[0])) 113 | 114 | subencoded = pad_sequence_reverse(subencoded) 115 | mask_idx = 1 - pad_sequence_reverse(mask_idx_pre) 116 | length = mask_idx.shape[0] 117 | 118 | expand_mask_idx = mask_idx.unsqueeze(1).repeat(1, length, 1) # length, length, batch-size 119 | expand_mask_idx = ((expand_mask_idx + mask_idx)>0).byte() 120 | 121 | # mask_idx = pad_sequence(mask_idx) 122 | sublabels = torch.LongTensor(sublabels) 123 | 124 | return subencoded, expand_mask_idx, sublabels 125 | 126 | def get_varlen_iter(self, start=0): 127 | sample_array = np.random.permutation(self.n_samples) 128 | for i in range(self.n_step): 129 | sub_index = sample_array[i*self.bsz:i*self.bsz+self.bsz] 130 | yield self.get_batch(sub_index) 131 | 132 | def get_fixlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 133 | for i in range(self.n_step): 134 | sub_index = self.sequence_array[i*self.bsz:i*self.bsz+self.bsz] 135 | yield self.get_batch(sub_index) 136 | 137 | def __iter__(self): 138 | return self.get_fixlen_iter() 139 | 140 | 141 | class LMShuffledIterator(object): 142 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 143 | """ 144 | data -- list[LongTensor] -- there is no order among the LongTensors 145 | """ 146 | self.data = data 147 | 148 | self.bsz = bsz 149 | self.bptt = bptt 150 | self.ext_len = ext_len if ext_len is not None else 0 151 | 152 | self.device = device 153 | self.shuffle = shuffle 154 | 155 | def get_sent_stream(self): 156 | # index iterator 157 | 
epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 158 | else np.array(range(len(self.data))) 159 | 160 | # sentence iterator 161 | for idx in epoch_indices: 162 | yield self.data[idx] 163 | 164 | def stream_iterator(self, sent_stream): 165 | # streams for each data in the batch 166 | streams = [None] * self.bsz 167 | 168 | data = torch.LongTensor(self.bptt, self.bsz) 169 | target = torch.LongTensor(self.bptt, self.bsz) 170 | 171 | n_retain = 0 172 | 173 | while True: 174 | # data : [n_retain+bptt x bsz] 175 | # target : [bptt x bsz] 176 | data[n_retain:].fill_(-1) 177 | target.fill_(-1) 178 | 179 | valid_batch = True 180 | 181 | for i in range(self.bsz): 182 | n_filled = 0 183 | try: 184 | while n_filled < self.bptt: 185 | if streams[i] is None or len(streams[i]) <= 1: 186 | streams[i] = next(sent_stream) 187 | # number of new tokens to fill in 188 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 189 | # first n_retain tokens are retained from last batch 190 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 191 | streams[i][:n_new] 192 | target[n_filled:n_filled+n_new, i] = \ 193 | streams[i][1:n_new+1] 194 | streams[i] = streams[i][n_new:] 195 | n_filled += n_new 196 | except StopIteration: 197 | valid_batch = False 198 | break 199 | 200 | if not valid_batch: 201 | return 202 | 203 | data = data.to(self.device) 204 | target = target.to(self.device) 205 | 206 | yield data, target, self.bptt 207 | 208 | n_retain = min(data.size(0), self.ext_len) 209 | if n_retain > 0: 210 | data[:n_retain] = data[-n_retain:] 211 | data.resize_(n_retain + self.bptt, data.size(1)) 212 | 213 | def __iter__(self): 214 | # sent_stream is an iterator 215 | sent_stream = self.get_sent_stream() 216 | 217 | for batch in self.stream_iterator(sent_stream): 218 | yield batch 219 | 220 | 221 | class LMMultiFileIterator(LMShuffledIterator): 222 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, 223 | shuffle=False): 224 | 225 | self.paths = paths 226 | self.vocab = vocab 227 | 228 | self.bsz = bsz 229 | self.bptt = bptt 230 | self.ext_len = ext_len if ext_len is not None else 0 231 | 232 | self.device = device 233 | self.shuffle = shuffle 234 | 235 | def get_sent_stream(self, path): 236 | sents = self.vocab.encode_file(path, add_double_eos=True) 237 | if self.shuffle: 238 | np.random.shuffle(sents) 239 | sent_stream = iter(sents) 240 | 241 | return sent_stream 242 | 243 | def __iter__(self): 244 | if self.shuffle: 245 | np.random.shuffle(self.paths) 246 | 247 | for path in self.paths: 248 | # sent_stream is an iterator 249 | sent_stream = self.get_sent_stream(path) 250 | for batch in self.stream_iterator(sent_stream): 251 | yield batch 252 | 253 | 254 | class Corpus(object): 255 | def __init__(self, path, dataset, *args, **kwargs): 256 | self.dataset = dataset 257 | self.vocab = Vocab(*args, **kwargs) 258 | 259 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: 260 | self.vocab.count_file(os.path.join(path, 'train.txt')) 261 | self.vocab.count_file(os.path.join(path, 'valid.txt')) 262 | self.vocab.count_file(os.path.join(path, 'test.txt')) 263 | elif self.dataset == 'wt103': 264 | self.vocab.count_file(os.path.join(path, 'train.txt')) 265 | elif self.dataset == 'lm1b': 266 | train_path_pattern = os.path.join( 267 | path, '1-billion-word-language-modeling-benchmark-r13output', 268 | 'training-monolingual.tokenized.shuffled', 'news.en-*') 269 | train_paths = glob.glob(train_path_pattern) 270 | # the vocab will load from file when build_vocab() is called 271 | 
272 | elif self.dataset == 'csqa': 273 | self.vocab.count_csqa(os.path.join(path, 'train_rand_split.jsonl'), add_cls_token=True) 274 | self.vocab.count_csqa(os.path.join(path, 'dev_rand_split.jsonl'), add_cls_token=True) 275 | self.vocab.count_csqa(os.path.join(path, 'test_rand_split_no_answers.jsonl'), add_cls_token=True) 276 | 277 | elif self.dataset in ['sst2', 'sst2_v2']: 278 | self.vocab.count_sst2(os.path.join(path, 'train.tsv'), add_cls_token=True) 279 | self.vocab.count_sst2(os.path.join(path, 'dev.tsv'), add_cls_token=True) 280 | self.vocab.count_sst2(os.path.join(path, 'test.tsv'), add_cls_token=True) 281 | 282 | self.vocab.build_vocab() 283 | 284 | if self.dataset in ['ptb', 'wt2', 'wt103']: 285 | self.train = self.vocab.encode_file( 286 | os.path.join(path, 'train.txt'), ordered=True) 287 | self.valid = self.vocab.encode_file( 288 | os.path.join(path, 'valid.txt'), ordered=True) 289 | self.test = self.vocab.encode_file( 290 | os.path.join(path, 'test.txt'), ordered=True) 291 | elif self.dataset in ['enwik8', 'text8']: 292 | self.train = self.vocab.encode_file( 293 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False) 294 | self.valid = self.vocab.encode_file( 295 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) 296 | self.test = self.vocab.encode_file( 297 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False) 298 | elif self.dataset == 'lm1b': 299 | self.train = train_paths 300 | self.valid = self.vocab.encode_file( 301 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) 302 | self.test = self.vocab.encode_file( 303 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) 304 | elif self.dataset == 'csqa': 305 | self.train = self.vocab.encode_csqa_file( 306 | os.path.join(path, 'train_rand_split.jsonl'), ordered=True, add_cls_token=True) 307 | self.valid = self.vocab.encode_csqa_file( 308 | os.path.join(path, 'dev_rand_split.jsonl'), ordered=True, add_cls_token=True) 309 | elif self.dataset == 'sst2': 310 | self.train = self.vocab.encode_sst2_file( 311 | os.path.join(path, 'train.tsv'), add_cls_token=True) 312 | self.valid = self.vocab.encode_sst2_file( 313 | os.path.join(path, 'dev.tsv'), add_cls_token=True) 314 | elif self.dataset == 'sst2_v2': 315 | self.train = self.vocab.encode_sst2_file_v2( 316 | os.path.join(path, 'train.tsv'), add_cls_token_last=True) 317 | self.valid = self.vocab.encode_sst2_file_v2( 318 | os.path.join(path, 'dev.tsv'), add_cls_token_last=True) 319 | 320 | 321 | def get_iterator(self, split, *args, **kwargs): 322 | 323 | if split == 'train': 324 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 325 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 326 | elif self.dataset == 'lm1b': 327 | kwargs['shuffle'] = True 328 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 329 | elif self.dataset == 'csqa': 330 | data_iter = CSQAIterator(self.train, *args, **kwargs) 331 | elif self.dataset == 'sst2': 332 | data_iter = SST2Iterator(self.train, *args, **kwargs) 333 | # dataset = CSQADataset(self.train) 334 | # data_iter = DataLoader(dataset, *args, shuffle=True, 335 | # num_workers=4, drop_last=False, pin_memory=True) 336 | 337 | elif split in ['valid', 'test']: 338 | data = self.valid if split == 'valid' else self.test 339 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 340 | data_iter = LMOrderedIterator(data, *args, **kwargs) 341 | elif self.dataset == 'lm1b': 342 | data_iter = LMShuffledIterator(data, *args, **kwargs) 343 | 
elif self.dataset == 'csqa': 344 | data_iter = CSQAIterator(self.valid, *args, **kwargs) 345 | elif self.dataset == 'sst2': 346 | data_iter = SST2Iterator(self.valid, *args, **kwargs) 347 | 348 | # dataset = CSQADataset(self.valid) 349 | # data_iter = DataLoader(dataset, *args, shuffle=False, 350 | # num_workers=4, drop_last=False, pin_memory=True) 351 | return data_iter 352 | 353 | 354 | def get_lm_corpus(datadir, dataset): 355 | 356 | fn = os.path.join(datadir, 'cache.pt') 357 | if os.path.exists(fn): 358 | print('Loading cached dataset...') 359 | corpus = torch.load(fn) 360 | else: 361 | print('Producing dataset {}...'.format(dataset)) 362 | kwargs = {} 363 | if dataset in ['wt103', 'wt2']: 364 | kwargs['special'] = [''] 365 | kwargs['lower_case'] = False 366 | elif dataset == 'ptb': 367 | kwargs['special'] = [''] 368 | kwargs['lower_case'] = True 369 | elif dataset == 'lm1b': 370 | kwargs['special'] = [] 371 | kwargs['lower_case'] = False 372 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') 373 | elif dataset in ['csqa', 'sst2', 'sst2_v2']: 374 | kwargs['special'] = [''] 375 | elif dataset in ['enwik8', 'text8']: 376 | pass 377 | 378 | corpus = Corpus(datadir, dataset, **kwargs) 379 | torch.save(corpus, fn) 380 | 381 | return corpus 382 | 383 | if __name__ == '__main__': 384 | import argparse 385 | parser = argparse.ArgumentParser(description='unit test') 386 | parser.add_argument('--datadir', type=str, default='../data/text8', 387 | help='location of the data corpus') 388 | parser.add_argument('--dataset', type=str, default='text8', 389 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'], 390 | help='dataset name') 391 | args = parser.parse_args() 392 | 393 | corpus = get_lm_corpus(args.datadir, args.dataset) 394 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) 395 | -------------------------------------------------------------------------------- /fastermoe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/Random-MoE-as-Dropout/0272cead5067d40108b4209ba87d512949dd7580/fastermoe/__init__.py -------------------------------------------------------------------------------- /fastermoe/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def float_from_env(key, default=-1): 5 | if key in os.environ: 6 | return float(os.environ[key]) 7 | return default 8 | 9 | 10 | def switch_from_env(key, default=False): 11 | if key in os.environ: 12 | return os.environ[key] in ['1', 'ON'] 13 | return default 14 | -------------------------------------------------------------------------------- /fastermoe/expert_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_expert_param_size(e): 5 | return sum(map(lambda x: x.numel(), e.parameters())) 6 | 7 | 8 | def get_expert_params(e, out): 9 | offset = 0 10 | for n, p in e.named_parameters(): 11 | seg = out[offset:offset + p.numel()] 12 | offset += p.numel() 13 | seg.copy_(p.data.flatten()) 14 | 15 | 16 | def stash_expert_params(e, params): 17 | if not hasattr(e, 'expert_param_stash'): 18 | setattr(e, 'expert_param_stash', dict()) 19 | offset = 0 20 | for n, p in e.named_parameters(): 21 | if n not in e.expert_param_stash: 22 | e.expert_param_stash[n] = p.data.clone() 23 | with torch.no_grad(): 24 | seg = params[offset:offset + p.numel()] 25 | offset += p.numel() 26 | p.copy_(seg.reshape(p.shape)) 27 
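# An illustrative sketch of how the helpers in this file compose when an expert is
# "shadowed" onto another worker (`expert` stands for any nn.Module expert; `remote_buf`
# is a hypothetical flat parameter buffer received from a peer, not a name in this repo):
#
#   n = get_expert_param_size(expert)           # number of parameter scalars
#   buf = torch.empty(n)
#   get_expert_params(expert, buf)              # flatten the local weights into buf
#   stash_expert_params(expert, remote_buf)     # temporarily load the peer's weights
#   # ... run forward/backward with the shadowed weights ...
#   pop_expert_params(expert)                   # restore the stashed originals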
| 28 | 29 | def pop_expert_params(e): 30 | if not hasattr(e, 'expert_param_stash'): 31 | return 32 | for n, p in e.named_parameters(): 33 | with torch.no_grad(): 34 | p.copy_(e.expert_param_stash[n]) 35 | e.expert_param_stash.clear() 36 | 37 | 38 | def collect_expert_grads(e, grads): 39 | offset = 0 40 | for _, p in e.named_parameters(): 41 | seg = grads[offset:offset + p.numel()] 42 | offset += p.numel() 43 | if p.grad is not None: 44 | seg.copy_(p.grad.flatten()) 45 | p.grad = None 46 | else: 47 | seg.zero_() 48 | 49 | 50 | def set_grads(e, grads): 51 | offset = 0 52 | for n, p in e.named_parameters(): 53 | seg = grads[offset:offset + p.numel()] 54 | offset += p.numel() 55 | if p.grad is None: 56 | p.grad = seg.clone() 57 | else: 58 | p.grad += seg.reshape(p.shape) 59 | -------------------------------------------------------------------------------- /fastermoe/schedule.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The smart schedule proposed in FasterMoE. 3 | """ 4 | import torch 5 | from torch.autograd.function import Function 6 | 7 | from fmoe.functions import prepare_forward, ensure_comm 8 | from fmoe.functions import _local_scatter, _local_gather 9 | import fmoe_cuda as fmoe_native 10 | from fmoe.fastermoe import expert_utils 11 | 12 | from .shadow_policy import get_shadow_policy 13 | 14 | 15 | class MoEForward(Function): 16 | @staticmethod 17 | def forward( 18 | ctx, 19 | expert_fn, 20 | experts, 21 | inp, # models, 22 | pos_s, pos_g, 23 | local_expert_count, global_expert_count, 24 | stored_models, 25 | fwd_batch_size, out_batch_size, 26 | world_size): 27 | local_input_buf = _local_scatter(inp, pos_s) 28 | 29 | ctx.gibs = [None] * (world_size * 2) 30 | ctx.gobs = [None] * (world_size * 2) 31 | def _expert_forward(x, y, idx): 32 | nothing = lambda a: a 33 | x = x.data 34 | with torch.enable_grad(): 35 | x.requires_grad = True 36 | # To skip torch autograd's version check. 
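# A brief note on the trick used below: the expert is re-run under enable_grad on a
# detached copy of its input slice, and the resulting sub-graph is stashed in
# ctx.gibs / ctx.gobs so that backward() can replay each expert call one slice at a
# time. The identity pack/unpack hooks passed to saved_tensors_hooks appear to exist
# so that the buffers reused by the C++ scheduler can be written in place
# (y.copy_(y0)) without tripping autograd's saved-tensor version check.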
37 | with torch.autograd.graph.saved_tensors_hooks(nothing, nothing): 38 | y0 = expert_fn(x, [x.shape[0]]) 39 | ctx.gibs[idx] = x 40 | ctx.gobs[idx] = y0 41 | y.copy_(y0) 42 | 43 | ctx.experts = experts 44 | if stored_models.any(): 45 | ctx.expert_size = expert_utils.get_expert_param_size(experts) 46 | else: 47 | ctx.expert_size = 0 48 | get_param_fn = lambda out: expert_utils.get_expert_params(experts, out) 49 | pop_fn = lambda: expert_utils.pop_expert_params(experts) 50 | ctx.shadows = [None] * world_size 51 | def stash_fn(params, idx): 52 | expert_utils.stash_expert_params(experts, params) 53 | ctx.shadows[idx] = params 54 | 55 | local_output_buf, gib = fmoe_native.smart_sch_forward( 56 | local_input_buf, 57 | local_expert_count, global_expert_count, 58 | stored_models, fwd_batch_size, ctx.expert_size, 59 | world_size, _expert_forward, get_param_fn, stash_fn, pop_fn) 60 | 61 | out = _local_gather(local_output_buf, pos_g, out_batch_size, 62 | maybe_overlap=False) 63 | 64 | # gib and local_input_buf are necessary, because ctx.gibs are created 65 | # based on their memory 66 | variables = (pos_s, pos_g, local_expert_count, global_expert_count, 67 | stored_models, gib, local_input_buf) 68 | 69 | ctx.moe_args = fwd_batch_size, inp.shape[0], world_size 70 | ctx.save_for_backward(*variables) 71 | 72 | return out 73 | 74 | @staticmethod 75 | def backward(ctx, grad_out): 76 | (pos_s, pos_g, local_expert_count, global_expert_count, 77 | stored_models, _1, _2) = ctx.saved_tensors 78 | (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args 79 | 80 | def _expert_backward(grad_y, grad_x, idx): 81 | y = ctx.gobs[idx] 82 | x = ctx.gibs[idx] 83 | torch.autograd.backward([y], [grad_y]) 84 | grad_x.copy_(x.grad) 85 | 86 | experts = ctx.experts 87 | def stash_fn(idx): 88 | expert_utils.stash_expert_params(experts, ctx.shadows[idx]) 89 | pop_fn = lambda: expert_utils.pop_expert_params(experts) 90 | def collect_fn(idx, root): 91 | grad = ctx.shadows[idx] 92 | expert_utils.collect_expert_grads(experts, grad) 93 | fmoe_native.reduce_grad(grad, root, ctx.expert_size) 94 | set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx]) 95 | 96 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g) 97 | grad_in_buf = fmoe_native.smart_sch_backward( 98 | grad_out_buf, 99 | local_expert_count, global_expert_count, 100 | stored_models, 101 | pos_s.shape[0], fwd_batch_size, 102 | world_size, 103 | _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn) 104 | grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size) 105 | 106 | return (None, None, grad_in, None, None, None, None, None, None, None, None) 107 | 108 | 109 | policy_fn = None 110 | 111 | 112 | def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None): 113 | # TODO: Using multiple tensors as input is to be supported. 
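    # Rough shape guide (hypothetical sizes; assumes n_expert == 1 per worker, top-k == 2):
    #   inp                  [N, d_model]   tokens held by this rank
    #   gate                 [N, 2]         destination (global) expert index per token, per k
    #   pos                  [N * 2]        (token, k) slots sorted by destination expert
    #   local_expert_count   [world_size]   slots this rank sends to each global expert
    #   global_expert_count  [world_size]   slots each rank sends to the expert hosted here
    #   out_batch_size == N * top-k, so the caller can later combine the expert outputs
    #   with the corresponding gate scores.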
114 | assert(isinstance(inp, torch.Tensor)) 115 | # TODO: Support many experts on each process 116 | assert(n_expert == 1) 117 | ( 118 | pos, 119 | local_expert_count, 120 | global_expert_count, 121 | fwd_expert_count, 122 | fwd_batch_size, 123 | ) = prepare_forward(gate, n_expert, world_size) 124 | 125 | global policy_fn 126 | if policy_fn is None: 127 | policy_fn = get_shadow_policy(d_model=inp.shape[-1]) 128 | 129 | if stored_models is None: 130 | stored_models = policy_fn(local_expert_count, global_expert_count, 131 | n_expert, world_size) 132 | 133 | topk = 1 134 | if len(gate.shape) == 2: 135 | topk = gate.shape[1] 136 | out_batch_size = inp.shape[0] * topk 137 | 138 | return MoEForward.apply(expert_fn, experts, inp, 139 | torch.div(pos, topk, rounding_mode='floor'), pos, 140 | local_expert_count, global_expert_count, stored_models, 141 | fwd_batch_size, out_batch_size, world_size) 142 | -------------------------------------------------------------------------------- /fastermoe/shadow_policy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | from .config import float_from_env, switch_from_env 7 | from fmoe.functions import get_moe_group 8 | 9 | 10 | def global_policy(local_expert_count, _gec, num_expert, world_size): 11 | r""" 12 | This is the policy for two-layer MLPs, using the formula in the PPoPP paper. 13 | A few parameters are used in this policy. 14 | * `d_model`: feature length of the MLP input and output. 15 | * `alpha`: the ratio of the MLP's hidden size to `d_model`. 16 | * `bw_net`: bandwidth of the network (GBps) 17 | * `bw_mm`: computation throughput of performing GeMM (FLOPs) 18 | """ 19 | bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8) 20 | bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12) 21 | alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2) 22 | d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048) 23 | 24 | moe_group = get_moe_group() 25 | local_expert_count = local_expert_count.cuda() 26 | agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)] 27 | dist.all_gather(agecs, local_expert_count, group=moe_group) 28 | all_global_expert_count = torch.stack(agecs) 29 | 30 | # TODO: data type other than float 31 | data_size = 4 32 | 33 | fwd_expert_counts = all_global_expert_count.sum(1).cpu() 34 | B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True) 35 | 36 | alphaH2 = alpha * (d_model ** 2) 37 | B_w = B_ws[0] 38 | 39 | comm = float('+inf') 40 | send_feature_time = d_model * data_size / bw_net 41 | send_model_time = 2 * alphaH2 * data_size / bw_net 42 | comp_time = 4 * alphaH2 / bw_mm 43 | lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w 44 | 45 | res = torch.zeros(world_size * num_expert, dtype=torch.bool) 46 | shadow_time = 0 47 | 48 | for i, index in enumerate(indices): 49 | if i + 1 == indices.numel(): 50 | break 51 | B_k = B_ws[i + 1] 52 | shadow_time += send_model_time 53 | lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time 54 | 55 | if lat_new < lat_base: 56 | lat_base = lat_new 57 | res[index] = True 58 | else: 59 | break 60 | return res 61 | 62 | 63 | def no_shadow_policy(_lec, _gec, num_expert, world_size): 64 | res = torch.zeros(world_size * num_expert, dtype=bool) 65 | return res 66 | 67 | 68 | def get_shadow_policy(d_model=None): 69 | if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ: 70 | 
os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model) 71 | if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'): 72 | return no_policy 73 | return global_policy 74 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The fmoe.functions module contains functions that are directly warped up from 3 | C/CUDA functions to complete distributed communication, computation and gradient 4 | computation. 5 | """ 6 | 7 | import torch 8 | from torch.autograd import Function 9 | import fmoe_cuda 10 | from custom_utils import get_torch_default_comm 11 | 12 | 13 | _moe_group = None 14 | 15 | 16 | def ensure_comm(t, comm): 17 | if comm is None: 18 | comm = get_torch_default_comm() 19 | global _moe_group 20 | _moe_group = comm 21 | fmoe_cuda.ensure_nccl(comm, t) 22 | 23 | 24 | def get_moe_group(): 25 | return _moe_group 26 | 27 | 28 | def count_by_gate(gate, num_expert, world_size, require_pos=True): 29 | with torch.no_grad(): 30 | local_expert_count = torch.zeros( 31 | num_expert * world_size, device=gate.device, dtype=torch.int32 32 | ) 33 | fmoe_cuda.expert_count(gate, local_expert_count) 34 | local_expert_count = local_expert_count.long() 35 | 36 | if world_size > 1: 37 | global_expert_count = fmoe_cuda.expert_exchange( 38 | local_expert_count, num_expert, world_size 39 | ) 40 | else: 41 | global_expert_count = local_expert_count 42 | if not require_pos: 43 | pos = None 44 | else: 45 | lec_cum = torch.cumsum(local_expert_count, dim=0).int() 46 | pos_size = lec_cum[-1].item() 47 | pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long) 48 | fmoe_cuda.assign_pos(lec_cum, gate, pos) 49 | return pos, local_expert_count, global_expert_count 50 | 51 | 52 | def prepare_forward(gate, num_expert, world_size): 53 | r""" 54 | Prepare necessary information from gate output for MoE computation. 55 | 56 | Args: 57 | gate: a 1-d Long Tensor representing the target expert of each input 58 | sample. 59 | num_expert: number of experts on each worker. 60 | world_size: number of workers that hold different experts. 61 | comm: the communicator of all workers in the expert-parallel group. 62 | """ 63 | pos, local_expert_count, global_expert_count = count_by_gate(gate, 64 | num_expert, world_size) 65 | with torch.no_grad(): 66 | fwd_expert_count = global_expert_count.view(world_size, 67 | num_expert).sum(dim=0) 68 | fwd_batch_size = int(fwd_expert_count.sum().item()) 69 | return ( 70 | pos, 71 | local_expert_count.cpu(), 72 | global_expert_count.cpu(), 73 | fwd_expert_count.cpu(), 74 | fwd_batch_size, 75 | ) 76 | 77 | 78 | def _local_scatter(inp, pos): 79 | inp_buf = torch.index_select(inp, 0, pos) 80 | return inp_buf 81 | 82 | 83 | def _local_gather(inp, pos, out_batch_size, maybe_overlap=True): 84 | inp_buf = torch.zeros(out_batch_size, inp.shape[-1], 85 | dtype=inp.dtype, device=inp.device) 86 | if maybe_overlap: 87 | inp_buf.index_add_(0, pos, inp) 88 | else: 89 | inp_buf.index_copy_(0, pos, inp) 90 | return inp_buf 91 | 92 | 93 | class MOEScatter(Function): 94 | r""" 95 | Scatter input samples from [batch x sequences] to contiguous alone experts. 96 | If `world_size` is greater than 1, the samples will first be locally 97 | scattered, and then exchanged across workers. 
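    A minimal single-process illustration of the local half of this (hypothetical
    values; the cross-worker exchange itself goes through fmoe_cuda.global_scatter):

        inp = torch.arange(6.).view(3, 2)      # 3 tokens, d_model = 2
        pos = torch.tensor([2, 0, 1])          # token slots ordered by destination expert
        buf = _local_scatter(inp, pos)         # buf[i] == inp[pos[i]]
        out = _local_gather(buf, pos, 3)       # inverse reordering
        assert torch.equal(out, inp)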
98 | """ 99 | 100 | @staticmethod 101 | def forward( 102 | ctx, 103 | inp, 104 | pos, 105 | local_expert_count, 106 | global_expert_count, 107 | fwd_batch_size, 108 | world_size, 109 | ): 110 | local_input_buf = _local_scatter(inp, pos) 111 | if world_size > 1: 112 | global_input_buf = fmoe_cuda.global_scatter( 113 | local_input_buf, 114 | local_expert_count, 115 | global_expert_count, 116 | fwd_batch_size, 117 | world_size, 118 | ) 119 | else: 120 | global_input_buf = local_input_buf 121 | ctx.moe_args = inp.shape[0], pos.shape[0], world_size 122 | variables = (pos, local_expert_count, global_expert_count) 123 | ctx.save_for_backward(*variables) 124 | return global_input_buf 125 | 126 | @staticmethod 127 | def backward(ctx, global_grad_in): 128 | (pos, local_expert_count, global_expert_count) = ctx.saved_tensors 129 | (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args 130 | 131 | if world_size > 1: 132 | local_grad_in = fmoe_cuda.global_gather( 133 | global_grad_in, 134 | local_expert_count, 135 | global_expert_count, 136 | buf_batch_size, 137 | world_size, 138 | ) 139 | else: 140 | local_grad_in = global_grad_in 141 | grad_in = _local_gather(local_grad_in, pos, inp_batch_size) 142 | return grad_in, None, None, None, None, None 143 | 144 | class MOEGather(Function): 145 | r""" 146 | Gather output samples from contiguous alone experts back to [batch x 147 | sequences]. Works symmetrically with MOEScatter. 148 | """ 149 | 150 | @staticmethod 151 | def forward( 152 | ctx, 153 | global_output_buf, 154 | pos, 155 | local_expert_count, 156 | global_expert_count, 157 | local_batch_size, 158 | world_size, 159 | ): 160 | if world_size > 1: 161 | local_output_buf = fmoe_cuda.global_gather( 162 | global_output_buf, 163 | local_expert_count, 164 | global_expert_count, 165 | pos.shape[0], 166 | world_size, 167 | ) 168 | else: 169 | local_output_buf = global_output_buf 170 | output = _local_gather(local_output_buf, pos, local_batch_size, 171 | maybe_overlap=False) 172 | 173 | ctx.moe_args = (global_output_buf.shape[0], world_size) 174 | variables = (pos, local_expert_count, global_expert_count) 175 | ctx.save_for_backward(*variables) 176 | return output 177 | 178 | @staticmethod 179 | def backward(ctx, grad_out): 180 | pos, local_expert_count, global_expert_count = ctx.saved_tensors 181 | fwd_batch_size, world_size = ctx.moe_args 182 | grad_out_buf = _local_scatter(grad_out.contiguous(), pos) 183 | if world_size > 1: 184 | global_grad_out_buf = fmoe_cuda.global_scatter( 185 | grad_out_buf, 186 | local_expert_count, 187 | global_expert_count, 188 | fwd_batch_size, 189 | world_size, 190 | ) 191 | else: 192 | global_grad_out_buf = grad_out_buf 193 | return global_grad_out_buf, None, None, None, None, None 194 | 195 | 196 | class AllGather(Function): 197 | r""" 198 | A wrapper for the All-Gather function to support auto-differentiation. 199 | """ 200 | 201 | @staticmethod 202 | def forward(ctx, inp, rank, world_size, group): 203 | tensor_list = [torch.empty_like(inp) for _ in range(world_size)] 204 | torch.distributed.all_gather(tensor_list, inp, group=group) 205 | torch.cuda.synchronize() 206 | output = torch.cat(tensor_list, dim=0) 207 | ctx.args = rank, inp.shape[0] 208 | return output 209 | 210 | @staticmethod 211 | def backward(ctx, grad_out): 212 | rank, dim0 = ctx.args 213 | return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None 214 | 215 | 216 | class Slice(Function): 217 | r""" 218 | A wrapper for the Slice function to support auto-differentiation. 
219 | """ 220 | 221 | @staticmethod 222 | def forward(ctx, inp, rank, world_size, group): 223 | B: int = inp.shape[0] 224 | local_batch_size = B // world_size 225 | batch_start = local_batch_size * rank 226 | batch_end = min(batch_start + local_batch_size, B) 227 | inp = inp[batch_start:batch_end] 228 | ctx.args = world_size, group 229 | return inp 230 | 231 | @staticmethod 232 | def backward(ctx, grad_out): 233 | world_size, group = ctx.args 234 | tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)] 235 | torch.distributed.all_gather(tensor_list, grad_out, group=group) 236 | torch.cuda.synchronize() 237 | grad_out = torch.cat(tensor_list, dim=0) 238 | return grad_out, None, None, None 239 | -------------------------------------------------------------------------------- /gates/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Different implementations of the Gate are located in separate files here. 3 | """ 4 | from .zero_gate import ZeroGate 5 | from .naive_gate import NaiveGate 6 | from .noisy_gate import NoisyGate 7 | 8 | from .gshard_gate import GShardGate 9 | from .switch_gate import SwitchGate 10 | 11 | from .swipe_gate import SwipeGate 12 | -------------------------------------------------------------------------------- /gates/base_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Base gate with standard interface 3 | """ 4 | import torch.nn as nn 5 | 6 | 7 | class BaseGate(nn.Module): 8 | def __init__(self, num_expert, world_size): 9 | super().__init__() 10 | self.world_size = world_size 11 | self.num_expert = num_expert 12 | self.tot_expert = world_size * num_expert 13 | self.loss = None 14 | 15 | def forward(self, x): 16 | raise NotImplementedError('Base gate cannot be directly used for fwd') 17 | 18 | def set_loss(self, loss): 19 | self.loss = loss 20 | 21 | def get_loss(self, clear=True): 22 | loss = self.loss 23 | if clear: 24 | self.loss = None 25 | return loss 26 | 27 | @property 28 | def has_loss(self): 29 | return self.loss is not None 30 | -------------------------------------------------------------------------------- /gates/faster_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The example topology-aware gate for two-layer tree-like topology, proposed by 3 | the PPoPP'22 paper, FasterMoE. Limited number of tokens are sent across the 4 | upper-level slow connection, and other ones are re-directed to experts in the 5 | local network. 6 | 7 | The number of GPUs to form such a local network is defined by an environment 8 | variable `FMOE_TOPO_GPUS_PER_NODE`, and it is by default `8`. 9 | 10 | The fraction of tokens that are allowed to be sent across nodes is defined by 11 | an environement variable `FMOE_TOPO_OUTGOING_FRACTION`, and it is by default 12 | `0.14`. Users are supposed to set the proper value in their own environemnt, 13 | guided by some performance model, to achieve maximum throughput. 
14 | """ 15 | from .naive_gate import NaiveGate 16 | 17 | import os 18 | import sys 19 | import torch 20 | import torch.nn.functional as F 21 | from .utils import limit_by_capacity 22 | import fmoe_cuda 23 | from fmoe.functions import count_by_gate 24 | 25 | 26 | nw_per_node = 8 27 | try: 28 | nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE']) 29 | except Exception: 30 | pass 31 | 32 | 33 | class FasterGate(NaiveGate): 34 | def __init__(self, d_model, n_expert, world_size, node_rank): 35 | super().__init__(d_model, n_expert, world_size, top_k=2) 36 | self.ne_per_node = nw_per_node * n_expert 37 | self.ogn_ratio = .14 38 | try: 39 | self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION']) 40 | except Exception: 41 | pass 42 | self.node_rank = node_rank 43 | 44 | mask = [1] * world_size * n_expert 45 | for i in range(n_expert * world_size): 46 | if i // self.ne_per_node == self.node_rank: 47 | mask[i] = 0 48 | self.mask = torch.Tensor(mask).bool() 49 | self.policy_fn = None 50 | print('node rank {} mask {}'.format(node_rank, mask)) 51 | 52 | def forward(self, inp): 53 | if self.mask.device != inp.device: 54 | self.mask = self.mask.to(inp.device) 55 | 56 | gate_score = self.gate(inp) 57 | lim_mask = self.mask 58 | 59 | top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1) 60 | S = gate_score.shape[0] 61 | top_k = 2 62 | 63 | with torch.no_grad(): 64 | top1_idx = top2_idx.view((-1, top_k))[:, 0] 65 | top1_val = top2_val.view((-1, top_k))[:, 0] 66 | c_e = torch.scatter_add( 67 | torch.zeros(self.tot_expert, device=top1_idx.device), 68 | 0, 69 | top1_idx, 70 | torch.ones_like(top1_idx, dtype=torch.float), 71 | ) / S 72 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 73 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 74 | self.set_loss(loss) 75 | 76 | with torch.no_grad(): 77 | if self.policy_fn is None: 78 | stored_models = torch.zeros(self.num_expert * self.world_size, 79 | dtype=torch.bool) 80 | else: 81 | # TODO: Fix this after expert shadowing is ported 82 | _, lec, aec, gec, agec = count_by_gate(top2_idx, 83 | self.num_expert, self.world_size, require_pos=False) 84 | stored_models = self.policy_fn(aec, agec, 85 | self.num_expert, self.world_size, inp.shape[-1], True) 86 | lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device) 87 | 88 | ogn_mask = lim_mask[top1_idx] 89 | ogn_thres = int(inp.shape[0] * self.ogn_ratio) 90 | 91 | if ogn_mask.sum().item() < ogn_thres: 92 | topk_val, topk_idx = torch.topk(gate_score, k=self.top_k) 93 | topk_val = F.softmax(topk_val, dim=-1) 94 | return topk_idx, topk_val 95 | 96 | with torch.no_grad(): 97 | top1_val[~ogn_mask] = float('-inf') 98 | _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres) 99 | cand = gate_score.clone() 100 | cand[:, lim_mask] = float('-inf') 101 | _, topk_idx = torch.topk(cand, k=self.top_k) 102 | topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn] 103 | 104 | idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2) 105 | topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k) 106 | 107 | topk_val = F.softmax(topk_val, dim=-1) 108 | 109 | return topk_idx, topk_val 110 | 111 | 112 | def gen_faster_gate(rank): 113 | def _gen(d_model, n_expert, world_size, top_k=2): 114 | assert top_k == 2 115 | return FasterGate(d_model, n_expert, world_size, rank // nw_per_node) 116 | return _gen 117 | -------------------------------------------------------------------------------- /gates/gshard_gate.py: 
-------------------------------------------------------------------------------- 1 | r""" 2 | Balanced gate with GShard's policy (Google, 2020) 3 | """ 4 | import math 5 | import torch 6 | import torch.nn.functional as F 7 | from .naive_gate import NaiveGate 8 | from .utils import limit_by_capacity 9 | 10 | 11 | class GShardGate(NaiveGate): 12 | def __init__(self, d_model, num_expert, world_size, 13 | topk=2, capacity=(1.2, 2.4), random_routing=True): 14 | assert topk == 2, 'topk should be 2 in gshard' 15 | super().__init__(d_model, num_expert, world_size, top_k=2) 16 | self.capacity = capacity 17 | self.random_routing = random_routing 18 | 19 | def forward(self, x): 20 | naive_outs = super().forward(x, return_all_scores=True) 21 | topk_idx, topk_val, gate_score = naive_outs 22 | 23 | S = gate_score.shape[0] 24 | top_k = topk_idx.shape[0] // gate_score.shape[0] 25 | top1_idx = topk_idx.view((-1, top_k))[:, 0] 26 | c_e = torch.scatter_add( 27 | torch.zeros(self.tot_expert, device=top1_idx.device), 28 | 0, 29 | top1_idx, 30 | torch.ones_like(top1_idx, dtype=torch.float), 31 | ) / S 32 | m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0) 33 | loss = torch.mean(c_e * m_e) * (self.num_expert ** 2) 34 | self.set_loss(loss) 35 | 36 | cap_rate = self.capacity[0 if self.training else 1] 37 | capacity = math.ceil(cap_rate * x.shape[0]) 38 | _new_lec, _new_gec, topk_idx = limit_by_capacity( 39 | topk_idx, self.num_expert, self.world_size, capacity) 40 | 41 | if self.random_routing: 42 | rand_routing_prob = torch.rand(gate_score.size(0), device=x.device) 43 | mask = (2 * topk_val[:, 1] < rand_routing_prob) 44 | topk_idx[:, 1].masked_fill_(mask, -1) 45 | 46 | return topk_idx, topk_val 47 | -------------------------------------------------------------------------------- /gates/naive_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Naive gate 3 | """ 4 | from .base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class NaiveGate(BaseGate): 12 | r""" 13 | A naive gate implementation that defines the standard behavior of the gate 14 | which determines which experts the tokens are going to. 15 | Both the indicies and the score, or confidence, are output to the parent 16 | module. 17 | The load-balance strategies are also designed to be implemented within the 18 | `Gate` module. 19 | """ 20 | 21 | def __init__(self, d_model, num_expert, world_size, top_k=2): 22 | super().__init__(num_expert, world_size) 23 | self.gate = nn.Linear(d_model, self.tot_expert) 24 | self.top_k = top_k 25 | 26 | def forward(self, inp, return_all_scores=False): 27 | r""" 28 | The naive implementation simply calculates the top-k of a linear layer's 29 | output. 30 | """ 31 | gate = self.gate(inp) 32 | gate_top_k_val, gate_top_k_idx = torch.topk( 33 | gate, k=self.top_k, dim=-1, largest=True, sorted=False 34 | ) # [.. 
x top_k] 35 | gate_top_k_val = gate_top_k_val.view(-1, self.top_k) 36 | 37 | # (BxL) x 1 x top_k 38 | gate_score = F.softmax(gate_top_k_val, dim=-1) 39 | 40 | if return_all_scores: 41 | return gate_top_k_idx, gate_score, gate 42 | return gate_top_k_idx, gate_score 43 | -------------------------------------------------------------------------------- /gates/noisy_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Noisy gate for gshard and switch 3 | """ 4 | from .base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributions.normal import Normal 10 | import math 11 | 12 | 13 | class NoisyGate(BaseGate): 14 | def __init__(self, d_model, num_expert, world_size, top_k=2): 15 | super().__init__(num_expert, world_size) 16 | self.w_gate = nn.Parameter( 17 | torch.zeros(d_model, self.tot_expert), requires_grad=True 18 | ) 19 | self.w_noise = nn.Parameter( 20 | torch.zeros(d_model, self.tot_expert), requires_grad=True 21 | ) 22 | self.top_k = top_k 23 | self.softplus = nn.Softplus() 24 | self.softmax = nn.Softmax(1) 25 | 26 | self.noise_epsilon = 1e-2 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | # Approach is the same as in torch.nn.Linear 32 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 33 | 34 | torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5)) 35 | torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5)) 36 | 37 | 38 | def _gates_to_load(self, gates): 39 | """Compute the true load per expert, given the gates. 40 | The load is the number of examples for which the corresponding gate is >0. 41 | Args: 42 | gates: a `Tensor` of shape [batch_size, n] 43 | Returns: 44 | a float32 `Tensor` of shape [n] 45 | """ 46 | return (gates > 0).sum(0) 47 | 48 | def _prob_in_top_k( 49 | self, clean_values, noisy_values, noise_stddev, noisy_top_values 50 | ): 51 | """Helper function to NoisyTopKGating. 52 | Computes the probability that value is in top k, given different random noise. 53 | This gives us a way of backpropagating from a loss that balances the number 54 | of times each expert is in the top k experts per example. 55 | In the case of no noise, pass in None for noise_stddev, and the result will 56 | not be differentiable. 57 | Args: 58 | clean_values: a `Tensor` of shape [batch, n]. 59 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 60 | normally distributed noise with standard deviation noise_stddev. 61 | noise_stddev: a `Tensor` of shape [batch, n], or None 62 | noisy_top_values: a `Tensor` of shape [batch, m]. 63 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 64 | Returns: 65 | a `Tensor` of shape [batch, n]. 66 | """ 67 | 68 | batch = clean_values.size(0) 69 | m = noisy_top_values.size(1) 70 | top_values_flat = noisy_top_values.flatten() 71 | threshold_positions_if_in = ( 72 | torch.arange(batch, device=clean_values.device) * m + self.top_k 73 | ) 74 | threshold_if_in = torch.unsqueeze( 75 | torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 76 | ) 77 | is_in = torch.gt(noisy_values, threshold_if_in) 78 | threshold_positions_if_out = threshold_positions_if_in - 1 79 | threshold_if_out = torch.unsqueeze( 80 | torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 81 | ) 82 | # is each value currently in the top k. 
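        # In other words (a short derivation of the two branches below): since
        # noisy = clean + eps * noise_stddev with eps ~ N(0, 1), the probability that
        # this expert stays in the top-k after re-sampling the noise is
        #   P(noisy > threshold) = Phi((clean - threshold) / noise_stddev),
        # where Phi is the standard normal CDF (normal.cdf below) and the threshold is
        # the (k+1)-th largest noisy logit if the expert is currently inside the top-k,
        # and the k-th largest otherwise.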
83 | normal = Normal( 84 | torch.tensor([0.0], device=clean_values.device), 85 | torch.tensor([1.0], device=clean_values.device), 86 | ) 87 | prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev) 88 | prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev) 89 | prob = torch.where(is_in, prob_if_in, prob_if_out) 90 | return prob 91 | 92 | def cv_squared(self, x): 93 | """The squared coefficient of variation of a sample. 94 | Useful as a loss to encourage a positive distribution to be more uniform. 95 | Epsilons added for numerical stability. 96 | Returns 0 for an empty Tensor. 97 | Args: 98 | x: a `Tensor`. 99 | Returns: 100 | a `Scalar`. 101 | """ 102 | eps = 1e-10 103 | # if only num_expert = 1 104 | if x.shape[0] == 1: 105 | return torch.Tensor([0]) 106 | return x.float().var() / (x.float().mean() ** 2 + eps) 107 | 108 | def forward(self, inp): 109 | clean_logits = inp @ self.w_gate 110 | raw_noise_stddev = inp @ self.w_noise 111 | noise_stddev = ( 112 | self.softplus(raw_noise_stddev) + self.noise_epsilon 113 | ) * self.training 114 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 115 | logits = noisy_logits 116 | 117 | # calculate topk + 1 that will be needed for the noisy gates 118 | top_logits, top_indices = logits.topk( 119 | min(self.top_k + 1, self.tot_expert), dim=1 120 | ) 121 | top_k_logits = top_logits[:, : self.top_k] 122 | top_k_indices = top_indices[:, : self.top_k] 123 | top_k_gates = self.softmax(top_k_logits) 124 | 125 | zeros = torch.zeros_like(logits, requires_grad=True) 126 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 127 | 128 | if self.top_k < self.tot_expert: 129 | load = ( 130 | self._prob_in_top_k( 131 | clean_logits, noisy_logits, noise_stddev, top_logits 132 | ) 133 | ).sum(0) 134 | else: 135 | load = self._gates_to_load(gates) 136 | 137 | importance = gates.sum(0) 138 | loss = self.cv_squared(importance) + self.cv_squared(load) 139 | self.set_loss(loss) 140 | 141 | return ( 142 | top_k_indices.contiguous().view(-1), 143 | top_k_gates.contiguous().unsqueeze(1), 144 | ) 145 | -------------------------------------------------------------------------------- /gates/swipe_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Balanced gate using SWIPE algorithm 3 | """ 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .naive_gate import NaiveGate 10 | 11 | from fmoe.functions import count_by_gate 12 | import fmoe_cuda as fmoe_native 13 | 14 | 15 | class SwipeGate(NaiveGate): 16 | def __init__(self, d_model, num_expert, world_size, top_k=2): 17 | super().__init__(d_model, num_expert, world_size, top_k) 18 | 19 | def swipe_once(self, idx, capacity, bias): 20 | with torch.no_grad(): 21 | idx_new, capacity = fmoe_native.swipe_once(idx, capacity, 22 | self.num_expert, self.world_size, bias) 23 | idx_new = idx_new.to(idx.device) 24 | return idx_new, capacity 25 | 26 | 27 | def forward(self, inp): 28 | score = self.gate(inp) 29 | orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1) 30 | 31 | if not self.training: 32 | topk_val = F.softmax(orig_score, dim=-1) 33 | return orig_idx, topk_val 34 | 35 | capacity = torch.scalar_tensor(inp.shape[0] * self.top_k, 36 | dtype=torch.long) 37 | 38 | topk_idxs = [] 39 | topk_vals = [] 40 | idx_x = torch.arange(inp.shape[0], device=inp.device) 41 | for k in range(self.top_k): 42 | idx, capacity = 
self.swipe_once(orig_idx[:, k], capacity, 43 | k % self.num_expert) 44 | topk_vals.append(score[idx_x, idx]) 45 | topk_idxs.append(idx) 46 | topk_idx = torch.stack(topk_idxs).transpose(0, 1) 47 | topk_val = torch.stack(topk_vals).transpose(0, 1) 48 | topk_val = F.softmax(topk_val, dim=-1) 49 | return topk_idx, topk_val 50 | -------------------------------------------------------------------------------- /gates/switch_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Balanced gate with Switch Transformer's policy (Google, 2021) 3 | """ 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from .naive_gate import NaiveGate 9 | from .utils import limit_by_capacity 10 | 11 | 12 | class SwitchGate(NaiveGate): 13 | r""" 14 | A switch gate implementation 15 | """ 16 | 17 | def __init__(self, d_model, num_expert, world_size, topk=1, 18 | switch_eps=.1, capacity=(1.2, 2.4)): 19 | assert topk == 1, 'topk should be 1 in switch' 20 | super().__init__(d_model, num_expert, world_size, top_k=1) 21 | self.switch_eps = switch_eps 22 | self.capacity = capacity 23 | 24 | def forward(self, inp): 25 | r""" 26 | The switch firstly conduct softmax and then calculates the top-1 27 | """ 28 | score = self.gate(inp) 29 | 30 | if self.training: 31 | # random uniform number from [1-eps, 1+eps] 32 | noise = torch.rand_like(score) 33 | noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps 34 | score += noise 35 | 36 | # fp32 softmax for numerical stability 37 | score = F.softmax(score.float(), dim=-1) 38 | 39 | top1_score, top1_idx = torch.topk( 40 | score, k=1, dim=-1, largest=True 41 | ) # [.. x top_k] 42 | top1_score = top1_score.to(dtype=inp.dtype) 43 | 44 | cap_rate = self.capacity[0 if self.training else 1] 45 | capacity = math.ceil(cap_rate * inp.shape[0]) 46 | _new_lec, _new_gec, top1_idx = limit_by_capacity( 47 | top1_idx, self.num_expert, self.world_size, capacity) 48 | 49 | valid_idx = top1_idx[top1_idx > -1] 50 | fraction_expert = torch.scatter_add( 51 | torch.zeros(self.tot_expert, device=valid_idx.device), 52 | 0, 53 | valid_idx, 54 | torch.ones_like(valid_idx, dtype=torch.float), 55 | ) / valid_idx.numel() 56 | prob_expert = score.sum(dim=0) / valid_idx.numel() 57 | loss = (fraction_expert * prob_expert).sum() * self.tot_expert 58 | self.set_loss(loss) 59 | return top1_idx, top1_score 60 | -------------------------------------------------------------------------------- /gates/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Utilities that may be used in the gates 3 | """ 4 | import torch 5 | from fmoe.functions import count_by_gate 6 | import fmoe_cuda as fmoe_native 7 | 8 | 9 | def limit_by_capacity(topk_idx, num_expert, world_size, capacity): 10 | with torch.no_grad(): 11 | capacity = torch.ones(num_expert, dtype=torch.int32, 12 | device=topk_idx.device) * capacity 13 | 14 | pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size, 15 | require_pos=False) 16 | new_gec = fmoe_native.limit_by_capacity(gec, capacity, 17 | num_expert, world_size) 18 | if world_size > 1: 19 | new_lec = fmoe_native.expert_exchange(new_gec, num_expert, 20 | world_size) 21 | else: 22 | new_lec = new_gec 23 | 24 | topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, 25 | new_lec.to(torch.int32), num_expert, world_size) 26 | return new_lec, new_gec, topk_idx 27 | -------------------------------------------------------------------------------- /gates/zero_gate.py: 
-------------------------------------------------------------------------------- 1 | r""" 2 | Zero gate that direct all input to gate 0 3 | """ 4 | from .base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class ZeroGate(BaseGate): 12 | r""" 13 | Guide all input samples to gate 0. 14 | """ 15 | 16 | def __init__(self, _1, num_expert, world_size, top_k=2): 17 | super().__init__(num_expert, world_size) 18 | self.top_k = top_k 19 | 20 | def forward(self, inp): 21 | r""" 22 | All output to expert 1 23 | """ 24 | idx = torch.zeros( 25 | inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device 26 | ) 27 | gate_score = ( 28 | torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k 29 | ) 30 | return idx, gate_score.reshape(-1, 1, self.top_k) 31 | -------------------------------------------------------------------------------- /linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | FMoE's parallel linear layer 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | import math 8 | 9 | import fmoe_cuda 10 | 11 | 12 | class MOELinear(Function): 13 | r""" 14 | Computes linear operators within one GPU on different experts simutaneously. 15 | """ 16 | 17 | @staticmethod 18 | def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None): 19 | global_output_buf = fmoe_cuda.linear_forward( 20 | global_input_buf, fwd_expert_count, weight, bias 21 | ) 22 | variables = (global_input_buf, fwd_expert_count, weight, bias) 23 | ctx.save_for_backward(*variables) 24 | return global_output_buf 25 | 26 | @staticmethod 27 | def backward(ctx, grad_out): 28 | (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors 29 | grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward( 30 | grad_out, input_buf, fwd_expert_count, weight, bias 31 | ) 32 | 33 | if not torch.is_tensor(bias): 34 | grad_bias = None 35 | 36 | return grad_inp_buf, None, grad_weight, grad_bias 37 | 38 | 39 | 40 | class FMoELinear(nn.Module): 41 | r""" 42 | A linear layer that contains multiple experts. 43 | As multiple experts can be placed on the same worker, the computation can be 44 | performed in parallel to increase the performance. 45 | The FMoELinear module provides such function. 
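    A hypothetical usage sketch (shapes only; the actual computation is dispatched
    to the fmoe_cuda extension, so a CUDA build is assumed):

        layer = FMoELinear(num_expert=4, in_feat=256, out_feat=512).cuda()
        inp = torch.randn(10, 256, device='cuda')      # rows already grouped by expert
        fwd_expert_count = torch.tensor([3, 2, 4, 1])  # rows assigned to each local expert
        out = layer(inp, fwd_expert_count)             # -> [10, 512]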
46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_expert: int, 51 | in_feat: int, 52 | out_feat: int, 53 | bias: bool = True, 54 | rank: int = 0, 55 | ): 56 | super().__init__() 57 | self.num_expert = num_expert 58 | self.in_feat = in_feat 59 | self.out_feat = out_feat 60 | self.rank = rank 61 | self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat)) 62 | if bias: 63 | self.bias = nn.Parameter(torch.zeros(num_expert, out_feat)) 64 | else: 65 | self.register_parameter("bias", None) 66 | 67 | self.reset_parameters() 68 | 69 | def forward(self, inp, fwd_expert_count): 70 | r""" 71 | Call MOE function 72 | """ 73 | x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias) 74 | return x 75 | 76 | def extra_repr(self) -> str: 77 | return "num_expert={}, in_features={}, \ 78 | out_features={}, bias={}, rank={}".format( 79 | self.num_expert, 80 | self.in_feat, 81 | self.out_feat, 82 | self.bias is not None, 83 | self.rank, 84 | ) 85 | 86 | def reset_parameters(self): 87 | # Approach is the same as in torch.nn.Linear 88 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 89 | # bias is left to zero, similar as megatron 90 | 91 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 92 | 93 | -------------------------------------------------------------------------------- /new_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from fmoe.gates.base_gate import BaseGate 6 | from custom_gate import CustomNaiveGate_Attn 7 | 8 | import pdb 9 | import torch.nn.functional as F 10 | 11 | 12 | __all__ = ['set_top_k', 'set_router_mode', 'freeze_part_weight', 'adjust_moe_gate_number', 13 | 'show_dts_gate_number', 'set_temperature', 'set_threshold', 14 | 'SWA_Average', 'collect_top_k', 'THOR_Model'] 15 | 16 | 17 | def set_top_k(model, num=2): 18 | for name, m in model.named_modules(): 19 | if hasattr(m, 'top_k') and hasattr(m, 'gate'): 20 | if isinstance(m.gate, BaseGate) and not isinstance(m.gate, CustomNaiveGate_Attn): 21 | m.top_k = num 22 | m.gate.top_k = num 23 | print('Layer name: {}, Top-K = {}, {}'.format(name, m.top_k, m.gate.top_k)) 24 | 25 | def collect_top_k(model): 26 | top_k = None 27 | for name, m in model.named_modules(): 28 | if hasattr(m, 'top_k') and hasattr(m, 'gate'): 29 | if isinstance(m.gate, BaseGate) and not isinstance(m.gate, CustomNaiveGate_Attn): 30 | top_k = m.gate.top_k 31 | break 32 | return top_k 33 | 34 | def set_router_mode(model, args, flag=True): 35 | # for name, m in model.named_modules(): 36 | # if isinstance(m, BaseGate): 37 | # m.dense_moe_flag = flag 38 | # print('Layer name: {}, Average MoE = {}'.format(name, m.dense_moe_flag)) 39 | print('** Using Score-Based Average for Dense Inference') 40 | 41 | current_gate = 0 42 | for name, m in model.named_modules(): 43 | if hasattr(m, 'top_k') and hasattr(m, 'gate'): 44 | if isinstance(m.gate, BaseGate) and not isinstance(m.gate, CustomNaiveGate_Attn): 45 | if flag: 46 | m.top_k = args.moe_num_expert 47 | m.gate.top_k = args.moe_num_expert 48 | else: 49 | m.top_k = args.moe_top_k 50 | m.gate.top_k = args.moe_top_k 51 | current_gate = m.top_k 52 | print('Set {}, Top-K = {} {}'.format(name, m.top_k, m.gate.top_k)) 53 | return current_gate 54 | 55 | def kl_loss_sym(logits1, logits2): 56 | 57 | kl_loss = nn.KLDivLoss(reduction="none") 58 | 59 | loss = kl_loss(F.log_softmax(logits1, dim=1), F.softmax(logits2, dim=1)) + 
kl_loss(F.log_softmax(logits2, dim=1), F.softmax(logits1, dim=1)) 60 | 61 | return loss.mean(-1) 62 | 63 | def freeze_part_weight(model, args): 64 | if args.freeze_gate: 65 | print('* Freeze Router') 66 | for name, p in model.named_parameters(): 67 | if 'gate.gate' in name: 68 | p.requires_grad = False 69 | 70 | if args.freeze_main_network: 71 | print('* Freeze All') 72 | for name, p in model.named_parameters(): 73 | if '.experts.' in name: 74 | p.requires_grad = False 75 | 76 | if args.freeze_main_network_all: 77 | print('* Freeze Attention') 78 | for name, p in model.named_parameters(): 79 | if 'word_emb.emb_layers' in name: continue 80 | if 'crit.out_layers' in name: continue 81 | if 'layers.' in name: 82 | if not 'gate.gate' in name: 83 | p.requires_grad = False 84 | 85 | for name, p in model.named_parameters(): 86 | if p.requires_grad: 87 | print('* Trainable Parameters {}, shape = {}'.format(name, p.shape)) 88 | else: 89 | print('* Freeze Parameters {}, shape = {}'.format(name, p.shape)) 90 | 91 | def calculate_gate_number(steps, args, overall_steps, min_experts, max_experts): 92 | if args.dynamic_moe_mode == 'linear_increase': 93 | number_experts = max_experts - min_experts 94 | gate_num = round(number_experts * steps / overall_steps) + min_experts 95 | elif args.dynamic_moe_mode == 'linear_decrease': 96 | number_experts = min_experts - max_experts 97 | gate_num = round(number_experts * steps / overall_steps) + max_experts 98 | elif args.dynamic_moe_mode == 'cosine_decrease': 99 | number_experts = max_experts - min_experts 100 | cosine_value = np.cos(np.pi * steps / (2 * overall_steps)) 101 | gate_num = round(number_experts * cosine_value) + min_experts 102 | elif args.dynamic_moe_mode == 'cosine_increase': 103 | number_experts = min_experts - max_experts 104 | cosine_value = np.cos(np.pi * steps / (2 * overall_steps)) 105 | gate_num = round(number_experts * cosine_value) + max_experts 106 | elif args.dynamic_moe_mode == 'exp_increase': 107 | number_experts = min_experts - max_experts 108 | current_steps = steps // (overall_steps // 300) 109 | cosine_value = 0.99 ** current_steps 110 | gate_num = round(number_experts * cosine_value) + max_experts 111 | elif args.dynamic_moe_mode == 'multi_step_increase': 112 | custom_gate_number = [1,2,4,8,16] 113 | length = len(custom_gate_number) 114 | gate_num_index = int(length * steps / overall_steps) 115 | gate_num = custom_gate_number[gate_num_index] 116 | elif args.dynamic_moe_mode == 'multi_step_decrease': 117 | custom_gate_number = [16,8,4,2,1] 118 | length = len(custom_gate_number) 119 | gate_num_index = int(length * steps / overall_steps) 120 | gate_num = custom_gate_number[gate_num_index] 121 | 122 | gate_num = np.clip(gate_num, min_experts, max_experts) 123 | 124 | return gate_num 125 | 126 | def adjust_moe_gate_number(model, steps, args, current_gate): 127 | new_gate_num = calculate_gate_number(steps, args, args.dynamic_overall_steps, args.moe_top_k_min, args.moe_top_k_max) 128 | if new_gate_num != current_gate: 129 | print('* Set New Top-k = {}'.format(new_gate_num)) 130 | set_top_k(model, new_gate_num) 131 | current_gate = new_gate_num 132 | return current_gate 133 | 134 | 135 | ## Dense to Sparse 136 | def show_dts_gate_number(model): 137 | for name, m in model.named_modules(): 138 | if isinstance(m, BaseGate): 139 | mean_experts = m.sum_top_k / m.forward_n 140 | layer_temp = m.temperature 141 | layer_threshold = m.threshold 142 | print('* Mean-Experts = {:.0f}, Temperature = {:.4f}, Threshold = {:.4f}'.format(mean_experts, 
layer_temp, layer_threshold)) 143 | 144 | def set_temperature(model, iterations, all_iteration, max_temp, min_temp): 145 | temp = max_temp + iterations * (min_temp - max_temp) / all_iteration 146 | for name, m in model.named_modules(): 147 | if isinstance(m, BaseGate): 148 | m.temperature = temp 149 | 150 | def set_threshold(model, args): 151 | if args.gate_name == 'CustomDTSGate': 152 | print('* Set threshold for DTS Gate') 153 | for name, m in model.named_modules(): 154 | if isinstance(m, BaseGate): 155 | m.threshold = args.threshold 156 | 157 | 158 | 159 | ## Weight Average 160 | class SWA_Average(nn.Module): 161 | def __init__(self, model, t_start, t_end, device): 162 | super(SWA_Average, self).__init__() 163 | self.device = device 164 | self.average_model = copy.deepcopy(model) 165 | self.register_buffer('n_average', torch.tensor(0, dtype=torch.long, device=self.device)) 166 | self.t_start = t_start 167 | self.t_end = t_end 168 | 169 | def forward(self, data, target, *mems): 170 | return self.average_model(data, target, *mems) 171 | 172 | def avg_fn(self, averaged_model_parameter, model_parameter, num_averaged): 173 | return averaged_model_parameter + (model_parameter - averaged_model_parameter) / ( 174 | num_averaged + 1 175 | ) 176 | 177 | def update_parameters(self, current_model, step): 178 | if step >= self.t_start and step <= self.t_end: 179 | print('Update parameters with step {}, current_n_average = {}'.format(step, self.n_average)) 180 | for p_swa, p_model in zip(self.average_model.parameters(), current_model.parameters()): 181 | p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model.detach(), self.n_average)) 182 | self.n_average +=1 183 | 184 | class THOR_Model(nn.Module): 185 | def __init__(self, basic_model, kl_alpha): 186 | super(THOR_Model, self).__init__() 187 | self.module = basic_model 188 | self.kl_alpha = kl_alpha 189 | 190 | def reset_length(self, tgt_len, ext_len, mem_len): 191 | self.module.reset_length(tgt_len, ext_len, mem_len) 192 | 193 | def forward(self, data, target, *mems): 194 | if self.training: 195 | outputs = self.module(data, target, *mems) 196 | outputs2 = self.module(data, target, *mems) 197 | loss_kl = kl_loss_sym(outputs[0], outputs2[0]) 198 | new_loss = (outputs[1] + outputs2[1])/2 + self.kl_alpha * loss_kl 199 | outputs[1] = new_loss 200 | else: 201 | outputs = self.module(data, target, *mems) 202 | return outputs[1:] 203 | 204 | -------------------------------------------------------------------------------- /script/figure5/12layers_smoe_dropout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo 'Run training...' 
3 | python -u train.py \ 4 | --cuda \ 5 | --data ../data/enwik8/ \ 6 | --dataset enwik8 \ 7 | --n_layer 12 \ 8 | --d_model 256 \ 9 | --n_head 8 \ 10 | --d_head 64 \ 11 | --d_inner 512 \ 12 | --dropout 0.1 \ 13 | --dropatt 0.0 \ 14 | --optim adam \ 15 | --lr 0.00025 \ 16 | --warmup_step 0 \ 17 | --max_step 400000 \ 18 | --tgt_len 512 \ 19 | --mem_len 512 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 22 \ 22 | --multi_gpu \ 23 | --moe --moe-num-expert 16 --moe-top-k 2 \ 24 | --gate_name CustomNaiveGate \ 25 | --moe_index 0,1,2,3 \ 26 | --freeze_gate \ 27 | --dynamic_moe \ 28 | --dynamic_moe_mode linear_increase \ 29 | --dynamic_overall_steps 400000 \ 30 | --moe-top-k-min 8 \ 31 | --moe-top-k-max 16 \ 32 | --work_dir SMoE-Dropout 33 | -------------------------------------------------------------------------------- /script/figure5/8layers_smoe_dropout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo 'Run training...' 3 | python -u train.py \ 4 | --cuda \ 5 | --data ../data/enwik8/ \ 6 | --dataset enwik8 \ 7 | --n_layer 8 \ 8 | --d_model 256 \ 9 | --n_head 8 \ 10 | --d_head 64 \ 11 | --d_inner 512 \ 12 | --dropout 0.1 \ 13 | --dropatt 0.0 \ 14 | --optim adam \ 15 | --lr 0.00025 \ 16 | --warmup_step 0 \ 17 | --max_step 400000 \ 18 | --tgt_len 512 \ 19 | --mem_len 512 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 22 \ 22 | --multi_gpu \ 23 | --moe --moe-num-expert 16 --moe-top-k 2 \ 24 | --gate_name CustomNaiveGate \ 25 | --moe_index 0,1,2,3 \ 26 | --freeze_gate \ 27 | --dynamic_moe \ 28 | --dynamic_moe_mode linear_increase \ 29 | --dynamic_overall_steps 400000 \ 30 | --moe-top-k-min 8 \ 31 | --moe-top-k-max 16 \ 32 | --work_dir SMoE-Dropout 33 | -------------------------------------------------------------------------------- /script/table1/transformer_xl/directly_dense_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo 'Run training...' 3 | python -u train.py \ 4 | --cuda \ 5 | --data ../data/enwik8/ \ 6 | --dataset enwik8 \ 7 | --n_layer 4 \ 8 | --d_model 256 \ 9 | --n_head 8 \ 10 | --d_head 64 \ 11 | --d_inner 8192 \ 12 | --dropout 0.1 \ 13 | --dropatt 0.0 \ 14 | --optim adam \ 15 | --lr 0.00025 \ 16 | --warmup_step 0 \ 17 | --max_step 400000 \ 18 | --tgt_len 512 \ 19 | --mem_len 512 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 22 \ 22 | --multi_gpu \ 23 | --work_dir Directly_Dense_Training 24 | -------------------------------------------------------------------------------- /script/table1/transformer_xl/smoe_dropout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo 'Run training...' 
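# SMoE-Dropout recipe for Table 1: the router is frozen at its initialization
# (--freeze_gate) and the number of active experts grows linearly from
# --moe-top-k-min 8 to --moe-top-k-max 16 over --dynamic_overall_steps 400000
# (--dynamic_moe_mode linear_increase). For comparison, the dense baseline script
# above uses --d_inner 8192, i.e. 16 x the 512-wide expert FFN used here.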
3 | python -u train.py \ 4 | --cuda \ 5 | --data ../data/enwik8/ \ 6 | --dataset enwik8 \ 7 | --n_layer 4 \ 8 | --d_model 256 \ 9 | --n_head 8 \ 10 | --d_head 64 \ 11 | --d_inner 512 \ 12 | --dropout 0.1 \ 13 | --dropatt 0.0 \ 14 | --optim adam \ 15 | --lr 0.00025 \ 16 | --warmup_step 0 \ 17 | --max_step 400000 \ 18 | --tgt_len 512 \ 19 | --mem_len 512 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 22 \ 22 | --multi_gpu \ 23 | --moe --moe-num-expert 16 --moe-top-k 2 \ 24 | --gate_name CustomNaiveGate \ 25 | --moe_index 0,1,2,3 \ 26 | --freeze_gate \ 27 | --dynamic_moe \ 28 | --dynamic_moe_mode linear_increase \ 29 | --dynamic_overall_steps 400000 \ 30 | --moe-top-k-min 8 \ 31 | --moe-top-k-max 16 \ 32 | --work_dir SMoE-Dropout 33 | -------------------------------------------------------------------------------- /script/table2/sst2/dense_model.sh: -------------------------------------------------------------------------------- 1 | python -u train_sst2.py \ 2 | --cuda \ 3 | --data ../glue_data/SST-2 \ 4 | --dataset sst2 \ 5 | --n_layer 4 \ 6 | --d_model 256 \ 7 | --n_head 8 \ 8 | --d_head 64 \ 9 | --d_inner 8192 \ 10 | --dropout 0.1 \ 11 | --dropatt 0.0 \ 12 | --optim adam \ 13 | --lr 1e-4 \ 14 | --warmup_step 0 \ 15 | --max_step 5000 \ 16 | --eval-interval 500 \ 17 | --log-interval 100 \ 18 | --tgt_len 512 \ 19 | --mem_len 128 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 16 \ 22 | --work_dir dense_model \ 23 | --pretrained_weight $1 -------------------------------------------------------------------------------- /script/table2/sst2/smoe_dropout.sh: -------------------------------------------------------------------------------- 1 | python -u train_sst2.py \ 2 | --cuda \ 3 | --data ../glue_data/SST-2 \ 4 | --dataset sst2 \ 5 | --n_layer 4 \ 6 | --d_model 256 \ 7 | --n_head 8 \ 8 | --d_head 64 \ 9 | --d_inner 512 \ 10 | --dropout 0.1 \ 11 | --dropatt 0.0 \ 12 | --optim adam \ 13 | --lr 1e-4 \ 14 | --warmup_step 0 \ 15 | --max_step 5000 \ 16 | --eval-interval 500 \ 17 | --log-interval 100 \ 18 | --tgt_len 512 \ 19 | --mem_len 128 \ 20 | --eval_tgt_len 128 \ 21 | --batch_size 16 \ 22 | --work_dir smoe_dropout \ 23 | --pretrained_weight $1 \ 24 | --moe --moe-num-expert 16 --moe-top-k 2 \ 25 | --gate_name CustomNaiveGate \ 26 | --dynamic_moe \ 27 | --freeze_gate \ 28 | --dynamic_moe_mode linear_increase \ 29 | --dynamic_overall_steps 5000 \ 30 | --moe-top-k-min 16 \ 31 | --moe-top-k-max 16 -------------------------------------------------------------------------------- /train_sst2.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import pdb 3 | import argparse 4 | import time 5 | import math 6 | import os, sys 7 | import itertools 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | 15 | from data_utils import get_lm_corpus 16 | from mem_transformer_sst2 import MemTransformerLM 17 | from utils.exp_utils import create_exp_dir 18 | from utils.data_parallel import BalancedDataParallel 19 | from fmoe.gates.base_gate import BaseGate 20 | 21 | from new_utils import * 22 | 23 | import warnings 24 | warnings.filterwarnings(action= 'ignore') 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 27 | 28 | parser.add_argument('--pretrained_weight', default=None, type=str) 29 | 30 | parser.add_argument('--data', type=str, default='../data/wikitext-103', 31 | help='location of the data corpus') 32 | parser.add_argument('--dataset', type=str, 
default='wt103', 33 | choices=['wt103', 'lm1b', 'enwik8', 'text8', 'csqa', 'sst2', 'sst2_v2'], 34 | help='dataset name') 35 | parser.add_argument('--n_layer', type=int, default=12, 36 | help='number of total layers') 37 | parser.add_argument('--n_head', type=int, default=10, 38 | help='number of heads') 39 | parser.add_argument('--d_head', type=int, default=50, 40 | help='head dimension') 41 | parser.add_argument('--d_embed', type=int, default=-1, 42 | help='embedding dimension') 43 | parser.add_argument('--d_model', type=int, default=500, 44 | help='model dimension') 45 | parser.add_argument('--d_inner', type=int, default=1000, 46 | help='inner dimension in FF') 47 | parser.add_argument('--dropout', type=float, default=0.0, 48 | help='global dropout rate') 49 | parser.add_argument('--dropatt', type=float, default=0.0, 50 | help='attention probability dropout rate') 51 | parser.add_argument('--init', default='normal', type=str, 52 | help='parameter initializer to use.') 53 | parser.add_argument('--emb_init', default='normal', type=str, 54 | help='parameter initializer to use.') 55 | parser.add_argument('--init_range', type=float, default=0.1, 56 | help='parameters initialized by U(-init_range, init_range)') 57 | parser.add_argument('--emb_init_range', type=float, default=0.01, 58 | help='parameters initialized by U(-init_range, init_range)') 59 | parser.add_argument('--init_std', type=float, default=0.02, 60 | help='parameters initialized by N(0, init_std)') 61 | parser.add_argument('--proj_init_std', type=float, default=0.01, 62 | help='parameters initialized by N(0, init_std)') 63 | parser.add_argument('--optim', default='adam', type=str, 64 | choices=['adam', 'sgd', 'adagrad'], 65 | help='optimizer to use.') 66 | parser.add_argument('--lr', type=float, default=0.00025, 67 | help='initial learning rate (0.00025|5 for adam|sgd)') 68 | parser.add_argument('--mom', type=float, default=0.0, 69 | help='momentum for sgd') 70 | parser.add_argument('--scheduler', default='cosine', type=str, 71 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'], 72 | help='lr scheduler to use.') 73 | parser.add_argument('--warmup_step', type=int, default=0, 74 | help='upper epoch limit') 75 | parser.add_argument('--decay_rate', type=float, default=0.5, 76 | help='decay factor when ReduceLROnPlateau is used') 77 | parser.add_argument('--lr_min', type=float, default=0.0, 78 | help='minimum learning rate during annealing') 79 | parser.add_argument('--clip', type=float, default=0.25, 80 | help='gradient clipping') 81 | parser.add_argument('--clip_nonemb', action='store_true', 82 | help='only clip the gradient of non-embedding params') 83 | parser.add_argument('--max_step', type=int, default=100000, 84 | help='upper epoch limit') 85 | parser.add_argument('--batch_size', type=int, default=60, 86 | help='batch size') 87 | parser.add_argument('--batch_chunk', type=int, default=1, 88 | help='split batch into chunks to save memory') 89 | parser.add_argument('--tgt_len', type=int, default=70, 90 | help='number of tokens to predict') 91 | parser.add_argument('--eval_tgt_len', type=int, default=50, 92 | help='number of tokens to predict for evaluation') 93 | parser.add_argument('--ext_len', type=int, default=0, 94 | help='length of the extended context') 95 | parser.add_argument('--mem_len', type=int, default=0, 96 | help='length of the retained previous heads') 97 | parser.add_argument('--not_tied', action='store_true', 98 | help='do not tie the word embedding and softmax weights') 99 | parser.add_argument('--seed', 
type=int, default=1111, 100 | help='random seed') 101 | parser.add_argument('--cuda', action='store_true', 102 | help='use CUDA') 103 | parser.add_argument('--adaptive', action='store_true', 104 | help='use adaptive softmax') 105 | parser.add_argument('--div_val', type=int, default=1, 106 | help='divident value for adapative input and softmax') 107 | parser.add_argument('--pre_lnorm', action='store_true', 108 | help='apply LayerNorm to the input instead of the output') 109 | parser.add_argument('--varlen', action='store_true', 110 | help='use variable length') 111 | parser.add_argument('--multi_gpu', action='store_true', 112 | help='use multiple GPU') 113 | parser.add_argument('--log-interval', type=int, default=200, 114 | help='report interval') 115 | parser.add_argument('--eval-interval', type=int, default=4000, 116 | help='evaluation interval') 117 | parser.add_argument('--work_dir', default='LM-TFM', type=str, 118 | help='experiment directory.') 119 | parser.add_argument('--restart', action='store_true', 120 | help='restart training from the saved checkpoint') 121 | parser.add_argument('--restart_dir', type=str, default='', 122 | help='restart dir') 123 | parser.add_argument('--debug', action='store_true', 124 | help='run in debug mode (do not create exp dir)') 125 | parser.add_argument('--same_length', action='store_true', 126 | help='use the same attn length for all tokens') 127 | parser.add_argument('--attn_type', type=int, default=0, 128 | help='attention type. 0 for ours, 1 for Shaw et al,' 129 | '2 for Vaswani et al, 3 for Al Rfou et al.') 130 | parser.add_argument('--clamp_len', type=int, default=-1, 131 | help='use the same pos embeddings after clamp_len') 132 | parser.add_argument('--eta_min', type=float, default=0.0, 133 | help='min learning rate for cosine scheduler') 134 | parser.add_argument('--gpu0_bsz', type=int, default=-1, 135 | help='batch size on gpu 0') 136 | parser.add_argument('--max_eval_steps', type=int, default=-1, 137 | help='max eval steps') 138 | parser.add_argument('--sample_softmax', type=int, default=-1, 139 | help='number of samples in sampled softmax') 140 | parser.add_argument('--patience', type=int, default=0, 141 | help='patience') 142 | parser.add_argument('--finetune_v2', action='store_true', 143 | help='finetune v2') 144 | parser.add_argument('--finetune_v3', action='store_true', 145 | help='finetune v3') 146 | parser.add_argument('--fp16', action='store_true', 147 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).') 148 | parser.add_argument('--static-loss-scale', type=float, default=1, 149 | help='Static loss scale, positive power of 2 values can ' 150 | 'improve fp16 convergence.') 151 | parser.add_argument('--dynamic-loss-scale', action='store_true', 152 | help='Use dynamic loss scaling. 
If supplied, this argument' 153 | ' supersedes --static-loss-scale.') 154 | parser.add_argument('--moe', action='store_true', 155 | help='replace position-wise ffn with moe position-wise ffn') 156 | parser.add_argument('--moe-num-expert', type=int, default=64, 157 | help='number of experts in MoE') 158 | 159 | parser.add_argument('--moe-top-k', type=int, default=2, 160 | help='top_k experts in hard gate of moe') 161 | 162 | 163 | ## other settings 164 | parser.add_argument('--gate_name', type=str, default='NaiveGate', 165 | help='Router Type') 166 | parser.add_argument('--moe_index', type=str, default=None, help='MoE Index') 167 | ## Random Weight 168 | parser.add_argument('--freeze_gate', action='store_true') 169 | parser.add_argument('--freeze_main_network', action='store_true') 170 | parser.add_argument('--freeze_main_network_all', action='store_true') 171 | ## Gradually adjust Top-K number during training 172 | parser.add_argument('--dynamic_moe', action='store_true', 173 | help='dynamic change moe top-k') 174 | parser.add_argument('--dynamic_moe_mode', type=str, default='linear_increase') 175 | parser.add_argument('--dynamic_overall_steps', type=int, default=-1) 176 | parser.add_argument('--moe-top-k-min', type=int, default=2) 177 | parser.add_argument('--moe-top-k-max', type=int, default=16) 178 | 179 | ## Dense to Sparse 180 | parser.add_argument('--min_temp', type=int, default=0.3) 181 | parser.add_argument('--max_temp', type=int, default=2) 182 | parser.add_argument('--threshold', type=int, default=0.001) 183 | ## Dense Dropout 184 | parser.add_argument('--dense_drop', action='store_true') 185 | parser.add_argument('--expert_drop', type=float, default=0.5) 186 | parser.add_argument('--num_expert', type=int, default=64) 187 | ## SWAD/SWA 188 | parser.add_argument('--swad', action='store_true') 189 | parser.add_argument('--swad_start', type=int, default=0) 190 | parser.add_argument('--swad_end', type=int, default=400000) 191 | ## Dynamic Routing 192 | parser.add_argument('--dynamic_router_start', type=int, default=-1) 193 | 194 | args = parser.parse_args() 195 | args.tied = not args.not_tied 196 | assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k" 197 | 198 | if args.d_embed < 0: 199 | args.d_embed = args.d_model 200 | 201 | assert args.ext_len >= 0, 'extended context length must be non-negative' 202 | assert args.batch_size % args.batch_chunk == 0 203 | 204 | args.work_dir = '{}-{}'.format(args.work_dir, args.dataset) 205 | args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S')) 206 | logging = create_exp_dir(args.work_dir, 207 | scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug) 208 | 209 | # Set the random seed manually for reproducibility. 
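# The dynamic Top-K arguments above (--dynamic_moe, --dynamic_moe_mode,
# --moe-top-k-min / --moe-top-k-max, --dynamic_overall_steps) are consumed inside
# adjust_moe_gate_number(), which is imported from new_utils and not reproduced in this
# file. The helper below is only an illustration of what a 'linear_increase' schedule
# amounts to under these arguments; it is never called by the training code.
def _illustrative_linear_increase_top_k(step, args):
    """Linearly grow the number of active experts from moe_top_k_min to moe_top_k_max."""
    if not args.dynamic_moe or args.dynamic_overall_steps <= 0:
        return args.moe_top_k
    frac = min(max(step / args.dynamic_overall_steps, 0.0), 1.0)
    k = args.moe_top_k_min + frac * (args.moe_top_k_max - args.moe_top_k_min)
    return int(round(k))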
210 | np.random.seed(args.seed) 211 | torch.manual_seed(args.seed) 212 | if torch.cuda.is_available(): 213 | if not args.cuda: 214 | print('WARNING: You have a CUDA device, so you should probably run with --cuda') 215 | else: 216 | torch.cuda.manual_seed_all(args.seed) 217 | 218 | # Validate `--fp16` option 219 | if args.fp16: 220 | if not args.cuda: 221 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option') 222 | args.fp16 = False 223 | else: 224 | try: 225 | from apex.fp16_utils import FP16_Optimizer 226 | except: 227 | print('WARNING: apex not installed, ignoring --fp16 option') 228 | args.fp16 = False 229 | 230 | device = torch.device('cuda' if args.cuda else 'cpu') 231 | 232 | ############################################################################### 233 | # Load data 234 | ############################################################################### 235 | corpus = get_lm_corpus(args.data, args.dataset) 236 | ntokens = len(corpus.vocab) 237 | args.n_token = ntokens 238 | 239 | eval_batch_size = 10 240 | 241 | # for CSQA 242 | # tr_iter = corpus.get_iterator('train', args.batch_size) 243 | # va_iter = corpus.get_iterator('valid', args.batch_size) 244 | # te_iter = va_iter 245 | 246 | tr_iter = corpus.get_iterator('train', args.batch_size) 247 | va_iter = corpus.get_iterator('valid', args.batch_size) 248 | te_iter = va_iter 249 | 250 | 251 | # adaptive softmax / embedding 252 | cutoffs, tie_projs = [], [False] 253 | if args.adaptive: 254 | assert args.dataset in ['wt103', 'lm1b'] 255 | if args.dataset == 'wt103': 256 | cutoffs = [20000, 40000, 200000] 257 | tie_projs += [True] * len(cutoffs) 258 | elif args.dataset == 'lm1b': 259 | cutoffs = [60000, 100000, 640000] 260 | tie_projs += [False] * len(cutoffs) 261 | 262 | ############################################################################### 263 | # Build the model 264 | ############################################################################### 265 | def init_weight(weight): 266 | if args.init == 'uniform': 267 | nn.init.uniform_(weight, -args.init_range, args.init_range) 268 | elif args.init == 'normal': 269 | nn.init.normal_(weight, 0.0, args.init_std) 270 | 271 | def init_bias(bias): 272 | nn.init.constant_(bias, 0.0) 273 | 274 | def weights_init(m): 275 | classname = m.__class__.__name__ 276 | if classname.find('Linear') != -1: 277 | if hasattr(m, 'weight') and m.weight is not None: 278 | init_weight(m.weight) 279 | if hasattr(m, 'bias') and m.bias is not None: 280 | init_bias(m.bias) 281 | elif classname.find('AdaptiveEmbedding') != -1: 282 | if hasattr(m, 'emb_projs'): 283 | for i in range(len(m.emb_projs)): 284 | if m.emb_projs[i] is not None: 285 | nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std) 286 | elif classname.find('Embedding') != -1: 287 | if hasattr(m, 'weight'): 288 | init_weight(m.weight) 289 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: 290 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: 291 | init_weight(m.cluster_weight) 292 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: 293 | init_bias(m.cluster_bias) 294 | if hasattr(m, 'out_projs'): 295 | for i in range(len(m.out_projs)): 296 | if m.out_projs[i] is not None: 297 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std) 298 | elif classname.find('LayerNorm') != -1: 299 | if hasattr(m, 'weight'): 300 | nn.init.normal_(m.weight, 1.0, args.init_std) 301 | if hasattr(m, 'bias') and m.bias is not None: 302 | init_bias(m.bias) 303 | elif classname.find('TransformerLM') != -1: 304 | 
if hasattr(m, 'r_emb'): 305 | init_weight(m.r_emb) 306 | if hasattr(m, 'r_w_bias'): 307 | init_weight(m.r_w_bias) 308 | if hasattr(m, 'r_r_bias'): 309 | init_weight(m.r_r_bias) 310 | if hasattr(m, 'r_bias'): 311 | init_bias(m.r_bias) 312 | 313 | def update_dropout(m): 314 | classname = m.__class__.__name__ 315 | if classname.find('Dropout') != -1: 316 | if hasattr(m, 'p'): 317 | m.p = args.dropout 318 | 319 | def update_dropatt(m): 320 | if hasattr(m, 'dropatt'): 321 | m.dropatt.p = args.dropatt 322 | 323 | if args.moe_index is not None: 324 | moe_index = list(map(int, args.moe_index.split(','))) 325 | else: 326 | moe_index = None 327 | 328 | if args.restart: 329 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: 330 | model = torch.load(f) 331 | if not args.fp16: 332 | model = model.float() 333 | model.apply(update_dropout) 334 | model.apply(update_dropatt) 335 | else: 336 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model, 337 | args.d_head, args.d_inner, args.dropout, args.dropatt, 338 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, 339 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, 340 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, 341 | same_length=args.same_length, attn_type=args.attn_type, 342 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax, 343 | moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k, gate_name=args.gate_name, moe_index=moe_index, 344 | dense_drop=args.dense_drop, expert_drop=args.expert_drop, num_expert=args.num_expert) 345 | model.apply(weights_init) 346 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing 347 | args.n_all_param = sum([p.nelement() for p in model.parameters()]) 348 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) 349 | 350 | # for Dense to Sparse Method 351 | set_threshold(model, args) 352 | freeze_part_weight(model, args) 353 | 354 | if args.fp16: 355 | model = model.half() 356 | 357 | if args.multi_gpu: 358 | model = model.to(device) 359 | if args.gpu0_bsz >= 0: 360 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk, 361 | model, dim=1).to(device) 362 | else: 363 | para_model = nn.DataParallel(model, dim=1).to(device) 364 | else: 365 | para_model = model.to(device) 366 | 367 | if args.swad: 368 | assert not args.restart 369 | print('Initial SWAD Model') 370 | swa_model = SWA_Average(model, t_start=args.swad_start, t_end=args.swad_end, device=device) 371 | 372 | #### optimizer 373 | if args.optim.lower() == 'sgd': 374 | if args.sample_softmax > 0: 375 | dense_params, sparse_params = [], [] 376 | for param in model.parameters(): 377 | if not param.requires_grad: 378 | print(param.shape) 379 | continue 380 | if param.size() == model.word_emb.weight.size(): 381 | sparse_params.append(param) 382 | else: 383 | dense_params.append(param) 384 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2) 385 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom) 386 | else: 387 | optimizer = optim.SGD(filter(lambda p:p.requires_grad, model.parameters()), lr=args.lr, 388 | momentum=args.mom) 389 | elif args.optim.lower() == 'adam': 390 | if args.sample_softmax > 0: 391 | dense_params, sparse_params = [], [] 392 | for param in model.parameters(): 393 | if not param.requires_grad: 394 | print(param.shape) 395 | continue 396 | if param.size() == model.word_emb.weight.size(): 397 | 
sparse_params.append(param) 398 | else: 399 | dense_params.append(param) 400 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr) 401 | optimizer = optim.Adam(dense_params, lr=args.lr) 402 | else: 403 | optimizer = optim.Adam(filter(lambda p:p.requires_grad, model.parameters()), lr=args.lr) 404 | elif args.optim.lower() == 'adagrad': 405 | optimizer = optim.Adagrad(filter(lambda p:p.requires_grad, model.parameters()), lr=args.lr) 406 | 407 | #### scheduler 408 | if args.scheduler == 'cosine': 409 | # here we do not set eta_min to lr_min to be backward compatible 410 | # because in previous versions eta_min is default to 0 411 | # rather than the default value of lr_min 1e-6 412 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 413 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 414 | if args.sample_softmax > 0: 415 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse, 416 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 417 | elif args.scheduler == 'inv_sqrt': 418 | # originally used for Transformer (in Attention is all you need) 419 | def lr_lambda(step): 420 | # return a multiplier instead of a learning rate 421 | if step == 0 and args.warmup_step == 0: 422 | return 1. 423 | else: 424 | return 1. / (step ** 0.5) if step > args.warmup_step \ 425 | else step / (args.warmup_step ** 1.5) 426 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 427 | elif args.scheduler == 'dev_perf': 428 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 429 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 430 | if args.sample_softmax > 0: 431 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse, 432 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 433 | elif args.scheduler == 'constant': 434 | pass 435 | 436 | if args.cuda and args.fp16: 437 | # If args.dynamic_loss_scale is False, static_loss_scale will be used. 438 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale. 439 | optimizer = FP16_Optimizer(optimizer, 440 | static_loss_scale = args.static_loss_scale, 441 | dynamic_loss_scale = args.dynamic_loss_scale, 442 | dynamic_loss_args = {'init_scale': 2 ** 16}) 443 | 444 | if args.restart: 445 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')): 446 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f: 447 | opt_state_dict = torch.load(f) 448 | optimizer.load_state_dict(opt_state_dict) 449 | else: 450 | print('Optimizer was not saved. Start from scratch.') 451 | 452 | logging('=' * 100) 453 | for k, v in args.__dict__.items(): 454 | logging(' - {} : {}'.format(k, v)) 455 | logging('=' * 100) 456 | logging('#params = {}'.format(args.n_all_param)) 457 | logging('#non emb params = {}'.format(args.n_nonemb_param)) 458 | ############################################################################### 459 | # Training code 460 | ############################################################################### 461 | 462 | 463 | 464 | logging('=' * 100) 465 | logging('==== loading pretrained model from {} ===='.format(args.pretrained_weight)) 466 | logging('=' * 100) 467 | 468 | # Load the best saved model. 
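# Note: unlike train.py, what gets loaded below is not a "best saved model" from this run
# but the pretrained Transformer-XL checkpoint passed via --pretrained_weight (see the
# table2 scripts); SST-2 fine-tuning starts from it. The load is non-strict: parameters
# whose names are absent from the SST-2 model or whose shapes differ (e.g. task-specific
# output layers) are logged and skipped, and the rest are copied with
# load_state_dict(strict=False).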
469 | with open(args.pretrained_weight, 'rb') as f: 470 | pretrained_model = torch.load(f) 471 | pretrained_model_checkpoint = pretrained_model.state_dict() 472 | filtered_checkpoint = {} 473 | for key in pretrained_model_checkpoint.keys(): 474 | if not key in model.state_dict(): 475 | logging('Can not load {}'.format(key)) 476 | elif not pretrained_model_checkpoint[key].shape == model.state_dict()[key].shape: 477 | logging('Can not load {}, shape do not match'.format(key)) 478 | else: 479 | filtered_checkpoint[key] = pretrained_model_checkpoint[key] 480 | 481 | model.load_state_dict(filtered_checkpoint, strict=False) 482 | 483 | 484 | 485 | def evaluate(model, eval_iter): 486 | # Turn on evaluation mode which disables dropout. 487 | model.eval() 488 | 489 | # If the model does not use memory at all, make the ext_len longer. 490 | # Otherwise, make the mem_len longer and keep the ext_len the same. 491 | if args.mem_len == 0: 492 | model.reset_length(args.eval_tgt_len, 493 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len) 494 | else: 495 | model.reset_length(args.eval_tgt_len, 496 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len) 497 | 498 | # Evaluation 499 | total_len, total_acc = 0, 0. 500 | with torch.no_grad(): 501 | mems = tuple() 502 | for i, (data, mask, label) in enumerate(eval_iter): 503 | data = data.cuda() 504 | mask = mask.cuda() 505 | label = label.cuda() 506 | 507 | predict, mems = para_model(data, mask, *mems) 508 | 509 | total_acc += (predict.argmax(-1) == label).sum().item() 510 | total_len += label.shape[0] 511 | 512 | # Switch back to the training mode 513 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 514 | model.train() 515 | 516 | return 100 * total_acc / total_len 517 | 518 | def train(): 519 | # Turn on training mode which enables dropout. 
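    # Per-iteration flow of this loop, for orientation:
    #   1. For the dense-to-sparse gate (CustomDTSGate) the routing temperature is annealed
    #      from --max_temp down to --min_temp via set_temperature().
    #   2. With --dynamic_moe, adjust_moe_gate_number() (from new_utils) updates how many
    #      experts each token activates, and the current Top-K is recorded in all_top_k.
    #   3. The batch is run through para_model, a sentence-level cross-entropy loss is
    #      computed, gradients are clipped to --clip, and the optimizer steps.
    #   4. Learning-rate handling: linear warmup for the first --warmup_step steps, then the
    #      selected scheduler (cosine / inv_sqrt / dev_perf / constant).
    #   5. Every --eval-interval steps the model is evaluated twice via set_router_mode()
    #      (once in its "dense" mode, once with the current gate), the best checkpoints are
    #      saved, and the SWA copy is updated and evaluated when --swad is on.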
520 | global train_step, train_loss, best_val_acc, best_val_acc_dense, eval_start_time, log_start_time, current_gate, all_top_k, train_correct, train_n 521 | model.train() 522 | 523 | criterion = nn.CrossEntropyLoss() 524 | mems = tuple() 525 | 526 | train_iter = tr_iter.get_varlen_iter() 527 | for batch, (data, mask, label) in enumerate(train_iter): 528 | 529 | if args.gate_name == 'CustomDTSGate': 530 | set_temperature(model, train_step, args.max_step, args.max_temp, args.min_temp) 531 | 532 | if args.dynamic_moe: 533 | current_gate = adjust_moe_gate_number(model, train_step, args, current_gate) 534 | 535 | current_top_k = collect_top_k(model) 536 | all_top_k.append(current_top_k) 537 | 538 | model.zero_grad() 539 | data = data.cuda() 540 | mask = mask.cuda() 541 | label = label.cuda() 542 | 543 | predict, mems = para_model(data, mask, *mems) 544 | 545 | loss = criterion(predict, label) 546 | loss = loss.float() 547 | 548 | train_correct += (predict.argmax(-1) == label).sum().item() 549 | train_n += label.shape[0] 550 | 551 | if args.fp16: 552 | optimizer.backward(loss) 553 | else: 554 | loss.backward() 555 | train_loss += loss.float().item() 556 | 557 | if args.fp16: 558 | optimizer.clip_master_grads(args.clip) 559 | else: 560 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 561 | 562 | optimizer.step() 563 | if args.sample_softmax > 0: 564 | optimizer_sparse.step() 565 | 566 | 567 | # step-wise learning rate annealing 568 | train_step += 1 569 | if args.scheduler in ['cosine', 'constant', 'dev_perf']: 570 | # linear warmup stage 571 | if train_step < args.warmup_step: 572 | curr_lr = args.lr * train_step / args.warmup_step 573 | optimizer.param_groups[0]['lr'] = curr_lr 574 | if args.sample_softmax > 0: 575 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2 576 | else: 577 | if args.scheduler == 'cosine': 578 | scheduler.step(train_step) 579 | if args.sample_softmax > 0: 580 | scheduler_sparse.step(train_step) 581 | elif args.scheduler == 'inv_sqrt': 582 | scheduler.step(train_step) 583 | 584 | if train_step % args.log_interval == 1: 585 | cur_loss = train_loss / args.log_interval 586 | cur_acc = train_correct / train_n 587 | elapsed = time.time() - log_start_time 588 | 589 | if args.gate_name == 'CustomDTSGate': 590 | show_dts_gate_number(model) 591 | 592 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \ 593 | '| ms/batch {:5.2f} | loss {:5.2f} | Accuracy {:5.2f}'.format( 594 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'], 595 | elapsed * 1000 / args.log_interval, cur_loss, cur_acc*100) 596 | logging(log_str) 597 | train_loss = 0 598 | log_start_time = time.time() 599 | 600 | if train_step % args.eval_interval == 0: 601 | 602 | current_gate = set_router_mode(model, args, flag=True) 603 | val_acc_dense = evaluate(model, va_iter) 604 | current_gate = set_router_mode(model, args, flag=False) 605 | val_acc = evaluate(model, va_iter) 606 | 607 | if args.swad: 608 | swa_model.update_parameters(model, train_step) 609 | current_gate = set_router_mode(swa_model.average_model, args, flag=True) 610 | val_acc_dense_swa = evaluate(swa_model.average_model, va_iter) 611 | current_gate = set_router_mode(swa_model.average_model, args, flag=False) 612 | val_acc_swa = evaluate(swa_model.average_model, va_iter) 613 | logging('-' * 100) 614 | log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ 615 | '| SWA valid Accuracy {:5.2f}'.format( 616 | train_step // args.eval_interval, train_step, 617 | (time.time() - eval_start_time), val_acc_swa) 
618 | logging(log_str) 619 | logging('-' * 100) 620 | log_str_dense = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ 621 | '| SWA Dense valid Accuracy {:5.2f}'.format( 622 | train_step // args.eval_interval, train_step, 623 | (time.time() - eval_start_time), val_acc_dense_swa) 624 | logging(log_str_dense) 625 | logging('-' * 100) 626 | with open(os.path.join(args.work_dir, 'model_swa.pt'), 'wb') as f: 627 | torch.save(swa_model.average_model, f) 628 | 629 | logging('-' * 100) 630 | log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ 631 | '| valid Accuracy {:5.2f}'.format( 632 | train_step // args.eval_interval, train_step, 633 | (time.time() - eval_start_time), val_acc) 634 | logging(log_str) 635 | logging('-' * 100) 636 | log_str_dense = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ 637 | '| Dense valid Accuracy {:5.2f}'.format( 638 | train_step // args.eval_interval, train_step, 639 | (time.time() - eval_start_time), val_acc_dense) 640 | logging(log_str_dense) 641 | logging('-' * 100) 642 | # Save the model if the validation loss is the best we've seen so far. 643 | if not best_val_acc or val_acc > best_val_acc: 644 | if not args.debug: 645 | with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f: 646 | torch.save(model, f) 647 | with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f: 648 | torch.save(optimizer.state_dict(), f) 649 | best_val_acc = val_acc 650 | 651 | if not best_val_acc_dense or val_acc_dense > best_val_acc_dense: 652 | if not args.debug: 653 | with open(os.path.join(args.work_dir, 'model_dense.pt'), 'wb') as f: 654 | torch.save(model, f) 655 | with open(os.path.join(args.work_dir, 'optimizer_dense.pt'), 'wb') as f: 656 | torch.save(optimizer.state_dict(), f) 657 | best_val_acc_dense = val_acc_dense 658 | 659 | eval_start_time = time.time() 660 | 661 | if train_step == args.dynamic_router_start: 662 | args.freeze_gate = True 663 | freeze_part_weight(model, args) 664 | 665 | if train_step == args.max_step: 666 | break 667 | 668 | 669 | # Loop over epochs. 670 | train_step = 0 671 | train_loss = 0 672 | train_correct = 0 673 | train_n = 0 674 | best_val_acc = None 675 | best_val_acc_dense = None 676 | current_gate = args.moe_top_k 677 | log_start_time = time.time() 678 | eval_start_time = time.time() 679 | all_top_k = [] 680 | 681 | # At any point you can hit Ctrl + C to break out of training early. 682 | try: 683 | for epoch in itertools.count(start=1): 684 | train() 685 | if train_step == args.max_step: 686 | logging('-' * 100) 687 | logging('End of training') 688 | break 689 | except KeyboardInterrupt: 690 | logging('-' * 100) 691 | logging('Exiting from training early') 692 | 693 | 694 | # Load the best saved model. 695 | with open(os.path.join(args.work_dir, 'model_dense.pt'), 'rb') as f: 696 | model = torch.load(f) 697 | para_model = model.to(device) 698 | 699 | # Run on test data. 
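# The sweeps below reload the saved checkpoints (model_dense.pt, model.pt and, with
# --swad, model_swa.pt) and re-evaluate each one while forcing different numbers of
# active experts per token. set_top_k() is imported from new_utils and not shown in this
# file; a minimal sketch of what such a helper amounts to, assuming every gate module
# exposes a writable top_k attribute:
#
#     def set_top_k(model, k):
#         for _, m in model.named_modules():
#             if isinstance(m, BaseGate):
#                 m.top_k = k
#
# Also note that evaluate() returns an accuracy in percent for SST-2, so the values
# logged as "test loss" below are in fact accuracies (and both branches of each if/else
# emit the same string).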
700 | for gate_number in [1,2,4,8,16,32,64]: 701 | if gate_number <= args.moe_num_expert: 702 | set_top_k(model, gate_number) 703 | test_loss = evaluate(model, te_iter) 704 | logging('=' * 100) 705 | if args.dataset in ['enwik8', 'text8']: 706 | logging('Dense | End of training | Gate-Number {:.0f} | test loss {:5.2f}'.format( 707 | gate_number, test_loss)) 708 | else: 709 | logging('Dense | End of training | Gate-Number {:.0f} | test loss {:5.2f}'.format( 710 | gate_number, test_loss)) 711 | logging('=' * 100) 712 | 713 | 714 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 715 | model = torch.load(f) 716 | para_model = model.to(device) 717 | 718 | # Run on test data. 719 | for gate_number in [1,2,4,8,16,32,64]: 720 | if gate_number <= args.moe_num_expert: 721 | set_top_k(model, gate_number) 722 | test_loss = evaluate(model, te_iter) 723 | logging('=' * 100) 724 | if args.dataset in ['enwik8', 'text8']: 725 | logging('Top-2 | End of training | Gate-Number {:.0f} | test loss {:5.2f}'.format( 726 | gate_number, test_loss)) 727 | else: 728 | logging('Top-2 | End of training | Gate-Number {:.0f} | test loss {:5.2f}'.format( 729 | gate_number, test_loss)) 730 | logging('=' * 100) 731 | 732 | if args.swad: 733 | with open(os.path.join(args.work_dir, 'model_swa.pt'), 'rb') as f: 734 | model = torch.load(f) 735 | para_model = model.to(device) 736 | 737 | # Run on test data. 738 | for gate_number in [1,2,4,8,16,32,64]: 739 | if gate_number <= args.moe_num_expert: 740 | set_top_k(model, gate_number) 741 | test_loss = evaluate(model, te_iter) 742 | logging('=' * 100) 743 | if args.dataset in ['enwik8', 'text8']: 744 | logging('SWAD | End of training | Gate-Number {:.0f} | test loss {:5.2f}'.format( 745 | gate_number, test_loss)) 746 | else: 747 | logging('SWAD | End of training | Gate-Number {:.0f} | test loss {:5.2f}'.format( 748 | gate_number, test_loss)) 749 | logging('=' * 100) 750 | 751 | if len(all_top_k) and all_top_k[0] != None: 752 | all_top_k = np.array(all_top_k) 753 | print('* Mean Top-K During Training = {}-[{}]'.format(np.mean(all_top_k), all_top_k.shape[0])) 754 | -------------------------------------------------------------------------------- /utils/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AdaptiveLogSoftmax(nn.Module): 10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False): 11 | super(AdaptiveLogSoftmax, self).__init__() 12 | 13 | cutoffs = list(cutoffs) 14 | 15 | if (cutoffs != sorted(cutoffs)) \ 16 | or (min(cutoffs) <= 0) \ 17 | or (max(cutoffs) >= (n_classes - 1)) \ 18 | or (len(set(cutoffs)) != len(cutoffs)) \ 19 | or any([int(c) != c for c in cutoffs]): 20 | 21 | raise ValueError("cutoffs should be a sequence of unique, positive " 22 | "integers sorted in an increasing order, where " 23 | "each value is between 1 and n_classes-1") 24 | 25 | self.in_features = in_features 26 | self.n_classes = n_classes 27 | self.cutoffs = cutoffs + [n_classes] 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) 34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 35 | 36 | self.keep_order = keep_order 37 | 38 | 39 | def 
forward(self, hidden, target, weight, bias, keep_order=False): 40 | if hidden.size(0) != target.size(0): 41 | raise RuntimeError('Input and target should have the same size ' 42 | 'in the batch dimension.') 43 | 44 | head_weight = torch.cat( 45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0) 46 | head_bias = torch.cat( 47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0) 48 | 49 | head_logit = F.linear(hidden, head_weight, bias=head_bias) 50 | head_logprob = F.log_softmax(head_logit, dim=1) 51 | 52 | nll = torch.zeros_like(target, 53 | dtype=hidden.dtype, device=hidden.device) 54 | 55 | offset = 0 56 | cutoff_values = [0] + self.cutoffs 57 | for i in range(len(cutoff_values) - 1): 58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1] 59 | 60 | mask_i = (target >= l_idx) & (target < h_idx) 61 | indices_i = mask_i.nonzero().squeeze() 62 | 63 | if indices_i.numel() == 0: 64 | continue 65 | 66 | target_i = target.index_select(0, indices_i) - l_idx 67 | head_logprob_i = head_logprob.index_select(0, indices_i) 68 | 69 | if i == 0: 70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 71 | else: 72 | weight_i = weight[l_idx:h_idx] 73 | bias_i = bias[l_idx:h_idx] 74 | 75 | hidden_i = hidden.index_select(0, indices_i) 76 | 77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) 78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 79 | 80 | logprob_i = head_logprob_i[:, -i] \ 81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 82 | 83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 84 | nll.index_copy_(0, indices_i, -logprob_i) 85 | else: 86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 87 | 88 | offset += logprob_i.size(0) 89 | 90 | return nll 91 | -------------------------------------------------------------------------------- /utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). 
To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | 92 | -------------------------------------------------------------------------------- /utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | if print_: 11 | print(s) 12 | if log_: 13 | with open(log_path, 'a+') as f_log: 14 | f_log.write(s + '\n') 15 | 16 | def get_logger(log_path, **kwargs): 17 | return functools.partial(logging, log_path=log_path, **kwargs) 18 | 19 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 20 | if debug: 21 | print('Debug Mode : no experiment dir created') 22 | return functools.partial(logging, log_path=None, log_=False) 23 | 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | print('Experiment dir : {}'.format(dir_path)) 28 | if scripts_to_save is not None: 29 | script_path = os.path.join(dir_path, 'scripts') 30 | if not os.path.exists(script_path): 31 | os.makedirs(script_path) 32 | for script in scripts_to_save: 33 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 34 | shutil.copyfile(script, dst_file) 35 | 36 | return 
get_logger(log_path=os.path.join(dir_path, 'log.txt')) 37 | 38 | def save_checkpoint(model, optimizer, path, epoch): 39 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) 40 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) 41 | -------------------------------------------------------------------------------- /utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | class LogUniformSampler(object): 6 | def __init__(self, range_max, n_sample): 7 | """ 8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 10 | 11 | expected count can be approximated by 1 - (1 - p)^n 12 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 13 | 14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 15 | """ 16 | with torch.no_grad(): 17 | self.range_max = range_max 18 | log_indices = torch.arange(1., range_max+2., 1.).log_() 19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 20 | # print('P', self.dist.numpy().tolist()[-30:]) 21 | 22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 23 | 24 | self.n_sample = n_sample 25 | 26 | def sample(self, labels): 27 | """ 28 | labels: [b1, b2] 29 | Return 30 | true_log_probs: [b1, b2] 31 | samp_log_probs: [n_sample] 32 | neg_samples: [n_sample] 33 | """ 34 | 35 | # neg_samples = torch.empty(0).long() 36 | n_sample = self.n_sample 37 | n_tries = 2 * n_sample 38 | 39 | with torch.no_grad(): 40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 41 | device = labels.device 42 | neg_samples = neg_samples.to(device) 43 | true_log_probs = self.log_q[labels].to(device) 44 | samp_log_probs = self.log_q[neg_samples].to(device) 45 | return true_log_probs, samp_log_probs, neg_samples 46 | 47 | def sample_logits(embedding, bias, labels, inputs, sampler): 48 | """ 49 | embedding: an nn.Embedding layer 50 | bias: [n_vocab] 51 | labels: [b1, b2] 52 | inputs: [b1, b2, n_emb] 53 | sampler: you may use a LogUniformSampler 54 | Return 55 | logits: [b1, b2, 1 + n_sample] 56 | """ 57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 58 | n_sample = neg_samples.size(0) 59 | b1, b2 = labels.size(0), labels.size(1) 60 | all_ids = torch.cat([labels.view(-1), neg_samples]) 61 | all_w = embedding(all_ids) 62 | true_w = all_w[: -n_sample].view(b1, b2, -1) 63 | sample_w = all_w[- n_sample:].view(n_sample, -1) 64 | 65 | all_b = bias[all_ids] 66 | true_b = all_b[: -n_sample].view(b1, b2) 67 | sample_b = all_b[- n_sample:] 68 | 69 | hit = (labels[:, :, None] == neg_samples).detach() 70 | 71 | true_logits = torch.einsum('ijk,ijk->ij', 72 | [true_w, inputs]) + true_b - true_log_probs 73 | sample_logits = torch.einsum('lk,ijk->ijl', 74 | [sample_w, inputs]) + sample_b - samp_log_probs 75 | sample_logits.masked_fill_(hit, -1e30) 76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 77 | 78 | return logits 79 | 80 | 81 | # class LogUniformSampler(object): 82 | # def __init__(self, range_max, unique=False): 83 | # """ 84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 85 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max 
+ 1)` 86 | # """ 87 | # self.range_max = range_max 88 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 90 | 91 | # self.unique = unique 92 | 93 | # if self.unique: 94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 95 | 96 | # def sample(self, n_sample, labels): 97 | # pos_sample, new_labels = labels.unique(return_inverse=True) 98 | # n_pos_sample = pos_sample.size(0) 99 | # n_neg_sample = n_sample - n_pos_sample 100 | 101 | # if self.unique: 102 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 104 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 105 | # else: 106 | # sample_dist = self.dist 107 | 108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 109 | 110 | # sample = torch.cat([pos_sample, neg_sample]) 111 | # sample_prob = self.dist[sample] 112 | 113 | # return new_labels, sample, sample_prob 114 | 115 | 116 | if __name__ == '__main__': 117 | S, B = 3, 4 118 | n_vocab = 10000 119 | n_sample = 5 120 | H = 32 121 | 122 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 123 | 124 | # sampler = LogUniformSampler(n_vocab, unique=False) 125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 126 | 127 | sampler = LogUniformSampler(n_vocab, unique=True) 128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 129 | 130 | # print('true_probs', true_probs.numpy().tolist()) 131 | # print('samp_probs', samp_probs.numpy().tolist()) 132 | # print('neg_samples', neg_samples.numpy().tolist()) 133 | 134 | # print('sum', torch.sum(sampler.dist).item()) 135 | 136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 137 | 138 | embedding = nn.Embedding(n_vocab, H) 139 | bias = torch.zeros(n_vocab) 140 | inputs = torch.Tensor(S, B, H).normal_() 141 | 142 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 143 | print('logits', logits.detach().numpy().tolist()) 144 | print('logits shape', logits.size()) 145 | print('out_labels', out_labels.detach().numpy().tolist()) 146 | print('out_labels shape', out_labels.size()) 147 | 148 | -------------------------------------------------------------------------------- /utils/proj_adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import pdb 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 11 | 12 | class Projection(nn.Module): 13 | def __init__(self, out_feat, in_feat): 14 | self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat)) 15 | 16 | class ProjectedAdaptiveLogSoftmax(nn.Module): 17 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 18 | keep_order=False): 19 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 20 | 21 | self.n_token = n_token 22 | self.d_embed = d_embed 23 | self.d_proj = d_proj 24 | 25 | self.cutoffs = cutoffs + [n_token] 26 | self.cutoff_ends = [0] + self.cutoffs 27 | self.div_val = div_val 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | if self.n_clusters > 0: 34 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 
35 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 36 | 37 | self.out_layers = nn.ModuleList() 38 | self.out_projs = nn.ModuleList() 39 | 40 | if div_val == 1: 41 | for i in range(len(self.cutoffs)): 42 | if d_proj != d_embed: 43 | self.out_projs.append( 44 | Projection(d_proj, d_embed) 45 | ) 46 | else: 47 | self.out_projs.append(None) 48 | 49 | self.out_layers.append(nn.Linear(d_embed, n_token)) 50 | else: 51 | for i in range(len(self.cutoffs)): 52 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 53 | d_emb_i = d_embed // (div_val ** i) 54 | 55 | self.out_projs.append( 56 | Projection(d_proj, d_emb_i) 57 | ) 58 | 59 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 60 | 61 | self.keep_order = keep_order 62 | 63 | def _compute_logit(self, hidden, weight, bias, proj): 64 | if proj is None: 65 | logit = F.linear(hidden, weight, bias=bias) 66 | else: 67 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 68 | proj_hid = F.linear(hidden, proj.t().contiguous()) 69 | logit = F.linear(proj_hid, weight, bias=bias) 70 | # else: 71 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 72 | # if bias is not None: 73 | # logit = logit + bias 74 | 75 | return logit 76 | 77 | def forward(self, hidden, target, keep_order=False): 78 | ''' 79 | hidden :: [len*bsz x d_proj] 80 | target :: [len*bsz] 81 | ''' 82 | 83 | if hidden.size(0) != target.size(0): 84 | raise RuntimeError('Input and target should have the same size ' 85 | 'in the batch dimension.') 86 | 87 | if self.n_clusters == 0: 88 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 89 | self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None) 90 | nll = -F.log_softmax(logit, dim=-1) \ 91 | .gather(1, target.unsqueeze(1)).squeeze(1) 92 | else: 93 | # construct weights and biases 94 | weights, biases = [], [] 95 | for i in range(len(self.cutoffs)): 96 | if self.div_val == 1: 97 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 98 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 99 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 100 | else: 101 | weight_i = self.out_layers[i].weight 102 | bias_i = self.out_layers[i].bias 103 | 104 | if i == 0: 105 | weight_i = torch.cat( 106 | [weight_i, self.cluster_weight], dim=0) 107 | bias_i = torch.cat( 108 | [bias_i, self.cluster_bias], dim=0) 109 | 110 | weights.append(weight_i) 111 | biases.append(bias_i) 112 | 113 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight if self.out_projs[0] is not None else None 114 | 115 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 116 | head_logprob = F.log_softmax(head_logit, dim=1) 117 | 118 | nll = torch.zeros_like(target, 119 | dtype=hidden.dtype, device=hidden.device) 120 | 121 | offset = 0 122 | cutoff_values = [0] + self.cutoffs 123 | for i in range(len(cutoff_values) - 1): 124 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 125 | 126 | mask_i = (target >= l_idx) & (target < r_idx) 127 | indices_i = mask_i.nonzero().squeeze() 128 | 129 | if indices_i.numel() == 0: 130 | continue 131 | 132 | target_i = target.index_select(0, indices_i) - l_idx 133 | head_logprob_i = head_logprob.index_select(0, indices_i) 134 | 135 | if i == 0: 136 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 137 | else: 138 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight if self.out_projs[i] is not None else None 139 | 140 | hidden_i = hidden.index_select(0, indices_i) 141 
| 142 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 143 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 144 | 145 | logprob_i = head_logprob_i[:, -i] \ 146 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 147 | 148 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 149 | nll.index_copy_(0, indices_i, -logprob_i) 150 | else: 151 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 152 | 153 | offset += logprob_i.size(0) 154 | 155 | return nll 156 | 157 | 158 | 159 | class ProjectedAdaptiveLogSoftmax_new(nn.Module): 160 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 161 | keep_order=False): 162 | super(ProjectedAdaptiveLogSoftmax_new, self).__init__() 163 | 164 | self.n_token = n_token 165 | self.d_embed = d_embed 166 | self.d_proj = d_proj 167 | 168 | self.cutoffs = cutoffs + [n_token] 169 | self.cutoff_ends = [0] + self.cutoffs 170 | self.div_val = div_val 171 | 172 | self.shortlist_size = self.cutoffs[0] 173 | self.n_clusters = len(self.cutoffs) - 1 174 | self.head_size = self.shortlist_size + self.n_clusters 175 | 176 | if self.n_clusters > 0: 177 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 178 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 179 | 180 | self.out_layers = nn.ModuleList() 181 | self.out_projs = nn.ModuleList() 182 | 183 | if div_val == 1: 184 | for i in range(len(self.cutoffs)): 185 | if d_proj != d_embed: 186 | self.out_projs.append( 187 | Projection(d_proj, d_embed) 188 | ) 189 | else: 190 | self.out_projs.append(None) 191 | 192 | self.out_layers.append(nn.Linear(d_embed, n_token)) 193 | else: 194 | for i in range(len(self.cutoffs)): 195 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 196 | d_emb_i = d_embed // (div_val ** i) 197 | 198 | self.out_projs.append( 199 | Projection(d_proj, d_emb_i) 200 | ) 201 | 202 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 203 | 204 | self.keep_order = keep_order 205 | 206 | def _compute_logit(self, hidden, weight, bias, proj): 207 | if proj is None: 208 | logit = F.linear(hidden, weight, bias=bias) 209 | else: 210 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 211 | proj_hid = F.linear(hidden, proj.t().contiguous()) 212 | logit = F.linear(proj_hid, weight, bias=bias) 213 | # else: 214 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 215 | # if bias is not None: 216 | # logit = logit + bias 217 | 218 | return logit 219 | 220 | def forward(self, hidden, target, keep_order=False): 221 | ''' 222 | hidden :: [len*bsz x d_proj] 223 | target :: [len*bsz] 224 | ''' 225 | 226 | if hidden.size(0) != target.size(0): 227 | raise RuntimeError('Input and target should have the same size ' 228 | 'in the batch dimension.') 229 | 230 | if self.n_clusters == 0: 231 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 232 | self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None) 233 | nll = -F.log_softmax(logit, dim=-1) \ 234 | .gather(1, target.unsqueeze(1)).squeeze(1) 235 | else: 236 | # construct weights and biases 237 | weights, biases = [], [] 238 | for i in range(len(self.cutoffs)): 239 | if self.div_val == 1: 240 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 241 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 242 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 243 | else: 244 | weight_i = self.out_layers[i].weight 245 | bias_i = self.out_layers[i].bias 246 | 247 | if i == 0: 248 | weight_i = torch.cat( 
249 | [weight_i, self.cluster_weight], dim=0) 250 | bias_i = torch.cat( 251 | [bias_i, self.cluster_bias], dim=0) 252 | 253 | weights.append(weight_i) 254 | biases.append(bias_i) 255 | 256 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight if self.out_projs[0] is not None else None 257 | 258 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 259 | head_logprob = F.log_softmax(head_logit, dim=1) 260 | 261 | nll = torch.zeros_like(target, 262 | dtype=hidden.dtype, device=hidden.device) 263 | 264 | offset = 0 265 | cutoff_values = [0] + self.cutoffs 266 | for i in range(len(cutoff_values) - 1): 267 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 268 | 269 | mask_i = (target >= l_idx) & (target < r_idx) 270 | indices_i = mask_i.nonzero().squeeze() 271 | 272 | if indices_i.numel() == 0: 273 | continue 274 | 275 | target_i = target.index_select(0, indices_i) - l_idx 276 | head_logprob_i = head_logprob.index_select(0, indices_i) 277 | 278 | if i == 0: 279 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 280 | else: 281 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight if self.out_projs[i] is not None else None 282 | 283 | hidden_i = hidden.index_select(0, indices_i) 284 | 285 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 286 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 287 | 288 | logprob_i = head_logprob_i[:, -i] \ 289 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 290 | 291 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 292 | nll.index_copy_(0, indices_i, -logprob_i) 293 | else: 294 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 295 | 296 | offset += logprob_i.size(0) 297 | 298 | return nll, logit -------------------------------------------------------------------------------- /utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import csv 4 | import json 5 | import torch 6 | from collections import Counter, OrderedDict 7 | import torch.nn.functional as F 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | 11 | 12 | class Vocab(object): 13 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, 14 | delimiter=None, vocab_file=None): 15 | self.counter = Counter() 16 | self.special = special 17 | self.min_freq = min_freq 18 | self.max_size = max_size 19 | self.lower_case = lower_case 20 | self.delimiter = delimiter 21 | self.vocab_file = vocab_file 22 | 23 | def tokenize(self, line, add_eos=False, add_double_eos=False, add_cls_token=False, add_s=False, add_cls_token_last=False): 24 | line = line.strip() 25 | # convert to lower case 26 | if self.lower_case: 27 | line = line.lower() 28 | 29 | # empty delimiter '' will evaluate False 30 | if self.delimiter == '': 31 | symbols = line 32 | else: 33 | symbols = line.split(self.delimiter) 34 | 35 | if add_cls_token: 36 | return [''] + symbols + [''] 37 | elif add_cls_token_last: 38 | return [''] + symbols + [''] 39 | elif add_double_eos: # lm1b 40 | return [''] + symbols + [''] 41 | elif add_eos: 42 | return symbols + [''] 43 | elif add_s: 44 | return symbols + [''] 45 | else: 46 | return symbols 47 | 48 | def count_file(self, path, verbose=False, add_eos=False): 49 | if verbose: print('counting file {} ...'.format(path)) 50 | assert os.path.exists(path) 51 | 52 | sents = [] 53 | with open(path, 'r', encoding='utf-8') as f: 54 | for idx, line in 
enumerate(f): 55 | if verbose and idx > 0 and idx % 500000 == 0: 56 | print(' line {}'.format(idx)) 57 | symbols = self.tokenize(line, add_eos=add_eos) 58 | self.counter.update(symbols) 59 | sents.append(symbols) 60 | 61 | return sents 62 | 63 | def count_csqa(self, path, num_classes=5, verbose=False, add_eos=False, add_double_eos=False, add_cls_token=False): 64 | if verbose: print('counting file {} ...'.format(path)) 65 | assert os.path.exists(path) 66 | 67 | sents = [] 68 | with open(path, 'r', encoding='utf-8') as f: 69 | for idx, line in enumerate(f): 70 | if verbose and idx > 0 and idx % 500000 == 0: 71 | print(' line {}'.format(idx)) 72 | example = json.loads(line.strip()) 73 | question = example["question"]["stem"] 74 | assert len(example["question"]["choices"]) == num_classes 75 | # format: ` Q: Where would I not want a fox? A: hen house ` 76 | question = "Q: " + question 77 | question_toks = self.tokenize(question, add_eos=add_eos, add_double_eos=add_double_eos, add_cls_token=add_cls_token) 78 | for i, choice in enumerate(example["question"]["choices"]): 79 | src = "A: " + choice["text"] 80 | assert (ord(choice["label"]) - ord("A")) == i 81 | src_bin = self.tokenize(src, add_eos=add_eos) 82 | question_toks.extend(src_bin) 83 | self.counter.update(question_toks) 84 | sents.append(question_toks) 85 | return sents 86 | 87 | def count_sst2(self, path, verbose=False, add_eos=False, add_double_eos=False, add_cls_token=False): 88 | if verbose: print('counting file {} ...'.format(path)) 89 | assert os.path.exists(path) 90 | sents = [] 91 | with open(path, 'r', encoding='utf-8') as f: 92 | tsv_file = csv.reader(f, delimiter="\t") 93 | for line in tsv_file: 94 | if not line[1] in ['0', '1']: 95 | # print('* Ignore ', line) 96 | continue 97 | sentence, label = line[0], int(line[1]) 98 | assert label in [0,1] 99 | sentence_toks = self.tokenize(sentence, add_eos=add_eos, add_double_eos=add_double_eos, add_cls_token=add_cls_token) 100 | self.counter.update(sentence_toks) 101 | sents.append(sentence_toks) 102 | return sents 103 | 104 | def count_sents(self, sents, verbose=False): 105 | """ 106 | sents : a list of sentences, each a list of tokenized symbols 107 | """ 108 | if verbose: print('counting {} sents ...'.format(len(sents))) 109 | for idx, symbols in enumerate(sents): 110 | if verbose and idx > 0 and idx % 500000 == 0: 111 | print(' line {}'.format(idx)) 112 | self.counter.update(symbols) 113 | 114 | def _build_from_file(self, vocab_file): 115 | self.idx2sym = [] 116 | self.sym2idx = OrderedDict() 117 | 118 | with open(vocab_file, 'r', encoding='utf-8') as f: 119 | for line in f: 120 | symb = line.strip().split()[0] 121 | self.add_symbol(symb) 122 | self.unk_idx = self.sym2idx['<UNK>'] 123 | 124 | def build_vocab(self): 125 | if self.vocab_file: 126 | print('building vocab from {}'.format(self.vocab_file)) 127 | self._build_from_file(self.vocab_file) 128 | print('final vocab size {}'.format(len(self))) 129 | else: 130 | print('building vocab with min_freq={}, max_size={}'.format( 131 | self.min_freq, self.max_size)) 132 | self.idx2sym = [] 133 | self.sym2idx = OrderedDict() 134 | 135 | for sym in self.special: 136 | self.add_special(sym) 137 | 138 | for sym, cnt in self.counter.most_common(self.max_size): 139 | if cnt < self.min_freq: break 140 | self.add_symbol(sym) 141 | 142 | print('final vocab size {} from {} unique tokens'.format( 143 | len(self), len(self.counter))) 144 | 145 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 146 | add_double_eos=False): 147 | if
verbose: print('encoding file {} ...'.format(path)) 148 | assert os.path.exists(path) 149 | encoded = [] 150 | with open(path, 'r', encoding='utf-8') as f: 151 | for idx, line in enumerate(f): 152 | if verbose and idx > 0 and idx % 500000 == 0: 153 | print(' line {}'.format(idx)) 154 | symbols = self.tokenize(line, add_eos=add_eos, 155 | add_double_eos=add_double_eos) 156 | encoded.append(self.convert_to_tensor(symbols)) 157 | 158 | if ordered: 159 | encoded = torch.cat(encoded) 160 | 161 | return encoded 162 | 163 | def encode_csqa_file(self, path, ordered=False, num_classes=5, verbose=False, add_eos=False, 164 | add_double_eos=False, add_cls_token=False): 165 | if verbose: print('encoding file {} ...'.format(path)) 166 | assert os.path.exists(path) 167 | encoded = [[] for i in range(num_classes)] 168 | labels = [] 169 | 170 | with open(path, 'r', encoding='utf-8') as f: 171 | for idx, line in enumerate(f): 172 | if verbose and idx > 0 and idx % 500000 == 0: 173 | print(' line {}'.format(idx)) 174 | example = json.loads(line.strip()) 175 | if "answerKey" in example: 176 | label = ord(example["answerKey"]) - ord("A") 177 | labels.append(label) 178 | question = example["question"]["stem"] 179 | assert len(example["question"]["choices"]) == num_classes 180 | # format: ` Q: Where would I not want a fox? A: hen house ` 181 | question = "Q: " + question 182 | question_bin = self.tokenize(question, add_eos=add_eos, 183 | add_double_eos=add_double_eos, add_cls_token=add_cls_token) 184 | for i, choice in enumerate(example["question"]["choices"]): 185 | src = " A: " + choice["text"] 186 | assert (ord(choice["label"]) - ord("A")) == i 187 | src_bin = question_bin + self.tokenize(src, add_s=True) 188 | encoded[i].append(self.convert_to_tensor(src_bin)) 189 | 190 | labels = torch.LongTensor(labels) 191 | 192 | # pdb.set_trace() 193 | 194 | # if ordered: 195 | # for idx in range(num_classes): 196 | # encoded[idx] = pad_sequence(encoded[idx]) 197 | 198 | # encoded = pad_sequence(encoded) 199 | # print(encoded.shape) 200 | 201 | return [encoded, labels] 202 | 203 | def encode_sst2_file(self, path, verbose=False, add_eos=False, 204 | add_double_eos=False, add_cls_token=False): 205 | if verbose: print('encoding file {} ...'.format(path)) 206 | assert os.path.exists(path) 207 | encoded = [] 208 | labels = [] 209 | with open(path, 'r', encoding='utf-8') as f: 210 | tsv_file = csv.reader(f, delimiter="\t") 211 | for line in tsv_file: 212 | if not line[1] in ['0', '1']: 213 | print('* Ignore ', line) 214 | continue 215 | sentence, label = line[0], int(line[1]) 216 | assert label in [0,1] 217 | sentence_toks = self.tokenize(sentence, add_eos=add_eos, add_double_eos=add_double_eos, add_cls_token=add_cls_token) 218 | encoded.append(self.convert_to_tensor(sentence_toks)) 219 | labels.append(label) 220 | 221 | labels = torch.LongTensor(labels) 222 | return [encoded, labels] 223 | 224 | def encode_sst2_file_v2(self, path, verbose=False, add_eos=False, 225 | add_double_eos=False, add_cls_token_last=False): 226 | if verbose: print('encoding file {} ...'.format(path)) 227 | assert os.path.exists(path) 228 | encoded = [] 229 | labels = [] 230 | with open(path, 'r', encoding='utf-8') as f: 231 | tsv_file = csv.reader(f, delimiter="\t") 232 | for line in tsv_file: 233 | if not line[1] in ['0', '1']: 234 | print('* Ignore ', line) 235 | continue 236 | sentence, label = line[0], int(line[1]) 237 | assert label in [0,1] 238 | sentence_toks = self.tokenize(sentence, add_eos=add_eos, add_double_eos=add_double_eos, 
add_cls_token_last=add_cls_token_last) 239 | encoded.append(self.convert_to_tensor(sentence_toks)) 240 | labels.append(label) 241 | 242 | labels = torch.LongTensor(labels) 243 | return [encoded, labels] 244 | 245 | def encode_sents(self, sents, ordered=False, verbose=False): 246 | if verbose: print('encoding {} sents ...'.format(len(sents))) 247 | encoded = [] 248 | for idx, symbols in enumerate(sents): 249 | if verbose and idx > 0 and idx % 500000 == 0: 250 | print(' line {}'.format(idx)) 251 | encoded.append(self.convert_to_tensor(symbols)) 252 | 253 | if ordered: 254 | encoded = torch.cat(encoded) 255 | 256 | return encoded 257 | 258 | def add_special(self, sym): 259 | if sym not in self.sym2idx: 260 | self.idx2sym.append(sym) 261 | self.sym2idx[sym] = len(self.idx2sym) - 1 262 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 263 | 264 | def add_symbol(self, sym): 265 | if sym not in self.sym2idx: 266 | self.idx2sym.append(sym) 267 | self.sym2idx[sym] = len(self.idx2sym) - 1 268 | 269 | def get_sym(self, idx): 270 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 271 | return self.idx2sym[idx] 272 | 273 | def get_idx(self, sym): 274 | if sym in self.sym2idx: 275 | return self.sym2idx[sym] 276 | else: 277 | # print('encounter unk {}'.format(sym)) 278 | print(sym) 279 | assert '<eos>' not in sym 280 | assert hasattr(self, 'unk_idx') 281 | return self.sym2idx.get(sym, self.unk_idx) 282 | 283 | def get_symbols(self, indices): 284 | return [self.get_sym(idx) for idx in indices] 285 | 286 | def get_indices(self, symbols): 287 | return [self.get_idx(sym) for sym in symbols] 288 | 289 | def convert_to_tensor(self, symbols): 290 | return torch.LongTensor(self.get_indices(symbols)) 291 | 292 | def convert_to_sent(self, indices, exclude=None): 293 | if exclude is None: 294 | return ' '.join([self.get_sym(idx) for idx in indices]) 295 | else: 296 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 297 | 298 | def __len__(self): 299 | return len(self.idx2sym) 300 | --------------------------------------------------------------------------------
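The Vocab class above is the whole preprocessing path for the language-model data: count symbol frequencies, freeze the vocabulary, then map text to LongTensors of token ids. Below is a minimal usage sketch; the file paths and the special-token list are illustrative assumptions, not values taken from this repository's scripts.

from utils.vocabulary import Vocab

# Build a word-level vocabulary (hypothetical paths).
vocab = Vocab(special=['<eos>'], lower_case=True)
vocab.count_file('data/train.txt', add_eos=True)
vocab.build_vocab()

# ordered=True concatenates all sentences into one flat LongTensor of ids,
# the layout expected by the language-model data iterators.
train_ids = vocab.encode_file('data/train.txt', ordered=True, add_eos=True)
valid_ids = vocab.encode_file('data/valid.txt', ordered=True, add_eos=True)

print(len(vocab), train_ids.shape)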
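ProjectedAdaptiveLogSoftmax_new in utils/proj_adaptive_softmax.py differs from the original criterion chiefly in returning the logits alongside the per-token negative log-likelihood. The sketch below exercises the flat case (an empty cutoffs list, so no clusters and a full-vocabulary logit matrix); the dimensions are made-up illustration values, not settings from the repository's scripts.

import torch
from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax_new

n_token, d_embed, d_proj = 1000, 64, 64
crit = ProjectedAdaptiveLogSoftmax_new(n_token, d_embed, d_proj, cutoffs=[])

hidden = torch.randn(8, d_proj)           # [len*bsz, d_proj] hidden states
target = torch.randint(n_token, (8,))     # [len*bsz] target ids

nll, logit = crit(hidden, target)         # nll: [8], logit: [8, n_token]
loss = nll.mean()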