├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── examples ├── __init__.py └── fairseq │ ├── README.md │ ├── __init__.py │ ├── generate.py │ ├── interactive.py │ ├── models │ ├── __init__.py │ ├── bert.py │ ├── language_modeling.py │ └── machine_translation.py │ ├── tasks │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── basic_loader.py │ │ ├── mlm_loader.py │ │ └── utils.py │ └── pretraining.py │ ├── train.py │ └── utils │ ├── __init__.py │ └── sparse_clip.py ├── setup.py ├── tests ├── __init__.py ├── test_decoder.py ├── test_encoder.py └── test_encoder_decoder.py └── torchscale ├── __init__.py ├── architecture ├── __init__.py ├── config.py ├── decoder.py ├── encoder.py ├── encoder_decoder.py └── utils.py ├── component ├── __init__.py ├── droppath.py ├── embedding.py ├── feedforward_network.py ├── multihead_attention.py ├── multiway_network.py ├── relative_position_bias.py ├── xmoe │ ├── __init__.py │ ├── moe_layer.py │ └── routing.py └── xpos_relative_position.py └── model ├── BEiT3.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 
| 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. 
Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | 352 | 353 | # Byte-compiled / optimized / DLL files 354 | __pycache__/ 355 | *.py[cod] 356 | *$py.class 357 | 358 | # C extensions 359 | *.so 360 | 361 | # Distribution / packaging 362 | .Python 363 | build/ 364 | develop-eggs/ 365 | dist/ 366 | downloads/ 367 | eggs/ 368 | .eggs/ 369 | lib/ 370 | lib64/ 371 | parts/ 372 | sdist/ 373 | var/ 374 | wheels/ 375 | share/python-wheels/ 376 | *.egg-info/ 377 | .installed.cfg 378 | *.egg 379 | MANIFEST 380 | 381 | # PyInstaller 382 | # Usually these files are written by a python script from a template 383 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
384 | *.manifest 385 | *.spec 386 | 387 | # Installer logs 388 | pip-log.txt 389 | pip-delete-this-directory.txt 390 | 391 | # Unit test / coverage reports 392 | htmlcov/ 393 | .tox/ 394 | .nox/ 395 | .coverage 396 | .coverage.* 397 | .cache 398 | nosetests.xml 399 | coverage.xml 400 | *.cover 401 | *.py,cover 402 | .hypothesis/ 403 | .pytest_cache/ 404 | cover/ 405 | 406 | # Translations 407 | *.mo 408 | *.pot 409 | 410 | # Django stuff: 411 | *.log 412 | local_settings.py 413 | db.sqlite3 414 | db.sqlite3-journal 415 | 416 | # Flask stuff: 417 | instance/ 418 | .webassets-cache 419 | 420 | # Scrapy stuff: 421 | .scrapy 422 | 423 | # Sphinx documentation 424 | docs/_build/ 425 | 426 | # PyBuilder 427 | .pybuilder/ 428 | target/ 429 | 430 | # Jupyter Notebook 431 | .ipynb_checkpoints 432 | 433 | # IPython 434 | profile_default/ 435 | ipython_config.py 436 | 437 | # pyenv 438 | # For a library or package, you might want to ignore these files since the code is 439 | # intended to run in multiple environments; otherwise, check them in: 440 | # .python-version 441 | 442 | # pipenv 443 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 444 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 445 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 446 | # install all needed dependencies. 447 | #Pipfile.lock 448 | 449 | # poetry 450 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 451 | # This is especially recommended for binary packages to ensure reproducibility, and is more 452 | # commonly ignored for libraries. 453 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 454 | #poetry.lock 455 | 456 | # pdm 457 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 458 | #pdm.lock 459 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 460 | # in version control. 461 | # https://pdm.fming.dev/#use-with-ide 462 | .pdm.toml 463 | 464 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 465 | __pypackages__/ 466 | 467 | # Celery stuff 468 | celerybeat-schedule 469 | celerybeat.pid 470 | 471 | # SageMath parsed files 472 | *.sage.py 473 | 474 | # Environments 475 | .env 476 | .venv 477 | env/ 478 | venv/ 479 | ENV/ 480 | env.bak/ 481 | venv.bak/ 482 | 483 | # Spyder project settings 484 | .spyderproject 485 | .spyproject 486 | 487 | # Rope project settings 488 | .ropeproject 489 | 490 | # mkdocs documentation 491 | /site 492 | 493 | # mypy 494 | .mypy_cache/ 495 | .dmypy.json 496 | dmypy.json 497 | 498 | # Pyre type checker 499 | .pyre/ 500 | 501 | # pytype static type analyzer 502 | .pytype/ 503 | 504 | # Cython debug symbols 505 | cython_debug/ 506 | 507 | # PyCharm 508 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 509 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 510 | # and can be added to the global gitignore or merged into this file. For a more nuclear 511 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
512 | #.idea/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **LEX**: A Length-Extrapolatable Transformer 2 | 3 | ## Key Feature 4 | - [**XPos**](https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py): An Extrapolatable Position Embedding for Transformer decoder. 5 | - [**BCA**](https://github.com/sunyt32/torchscale/blob/main/torchscale/component/multihead_attention.py#L101): An efficient implementation for Block Causal Attention. 6 | 7 | ## Third-Party Implementation 8 | - XPos: [**Flash-Attention**](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py) 9 | 10 | 11 | 12 | # TorchScale - A Library for Transformers at (Any) Scale 13 | 14 |

15 | MIT License 16 | 17 |

18 | 19 | TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively. 20 | It implements fundamental research that improves the modeling generality and capability, as well as the training stability and efficiency, of scaling up Transformers. 21 | 22 | - Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond 23 | - Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal) 24 | - Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE) 25 | - Extrapolatability - [**LEX**](https://arxiv.org/abs/2212.10554): A Length-Extrapolatable Transformer 26 | 27 | ## News 28 | 29 | - November 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)] 30 | 31 | ## Installation 32 | 33 | To install: 34 | ``` 35 | pip install torchscale 36 | ``` 37 | 38 | Alternatively, you can develop it locally: 39 | ``` 40 | git clone https://github.com/microsoft/torchscale.git 41 | cd torchscale 42 | pip install -e . 43 | ``` 44 | 45 | ## Getting Started 46 | 47 | It takes only a few lines of code to create a model with the fundamental research features above enabled. Here is how to quickly obtain a BERT-like encoder: 48 | 49 | ```python 50 | >>> from torchscale.architecture.config import EncoderConfig 51 | >>> from torchscale.architecture.encoder import Encoder 52 | 53 | >>> config = EncoderConfig(vocab_size=64000) 54 | >>> model = Encoder(config) 55 | 56 | >>> print(model) 57 | ``` 58 | 59 | We also support the `Decoder` architecture and the `EncoderDecoder` architecture: 60 | 61 | ```python 62 | # Creating a decoder model 63 | >>> from torchscale.architecture.config import DecoderConfig 64 | >>> from torchscale.architecture.decoder import Decoder 65 | 66 | >>> config = DecoderConfig(vocab_size=64000) 67 | >>> decoder = Decoder(config) 68 | >>> print(decoder) 69 | 70 | # Creating an encoder-decoder model 71 | >>> from torchscale.architecture.config import EncoderDecoderConfig 72 | >>> from torchscale.architecture.encoder_decoder import EncoderDecoder 73 | 74 | >>> config = EncoderDecoderConfig(vocab_size=64000) 75 | >>> encdec = EncoderDecoder(config) 76 | >>> print(encdec) 77 | ``` 78 | 79 | ## Key Features 80 | 81 | - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555) 82 | * enabled by setting *deepnorm=True* in the `Config` class. 83 | * It adjusts both the residual connection and the initialization method according to the model architecture (i.e., encoder, decoder, or encoder-decoder). 84 | 85 | - [SubLN for the model generality and the training stability](https://arxiv.org/abs/2210.06423) 86 | * enabled by *subln=True*. This is enabled by default. 87 | * It introduces another LayerNorm to each sublayer and adjusts the initialization according to the model architecture. 88 | * Note that SubLN and DeepNorm cannot be used in a single model. 89 | 90 | - [X-MoE: efficient and finetunable sparse MoE modeling](https://arxiv.org/abs/2204.09179) 91 | * enabled by *use_xmoe=True*. 92 | * It replaces every *moe_freq*-th `FeedForwardNetwork` layer with an X-MoE layer.
93 | 94 | - [Multiway architecture for multimodality](https://arxiv.org/abs/2208.10442) 95 | * enabled by *multiway=True*. 96 | * It provides a pool of Transformer parameters used for different modalities. 97 | 98 | - [Extrapolatable position embedding (XPos)](https://arxiv.org/abs/2212.10554) 99 | * enabled by *xpos_rel_pos=True*. 100 | 101 | - [Blockwise Causal Attention (BCA)](https://arxiv.org/abs/2212.10554) 102 | * enabled by adjusting *block_size*. If *block_size=-1*, BCA is disabled. 103 | * Setting *block_size* to the pre-training length is recommended. 104 | 105 | - [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184) 106 | * We provide [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq (or other) repos. 107 | 108 | Most of the features above can be used by simply passing the corresponding parameters to the config. For example: 109 | 110 | ```python 111 | >>> from torchscale.architecture.config import EncoderConfig 112 | >>> from torchscale.architecture.encoder import Encoder 113 | 114 | >>> config = EncoderConfig(vocab_size=64000, deepnorm=True, multiway=True) 115 | >>> model = Encoder(config) 116 | 117 | >>> print(model) 118 | ``` 119 | 120 | ## Examples 121 | 122 | We provide examples of how to use TorchScale in the following scenarios/tasks: 123 | 124 | - Language 125 | 126 | * [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining) 127 | 128 | * [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation) 129 | 130 | * [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining) 131 | 132 | - Vision 133 | 134 | * ViT/BEiT [In progress] 135 | 136 | - Speech 137 | 138 | - Multimodal 139 | 140 | * [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress] 141 | 142 | We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome! 143 | 144 | ## Results 145 | 146 | ### Stability Evaluation 147 | 148 |

149 | [figure: stability evaluation training curves] 150 |

151 | 152 | With TorchScale, the training curve is smooth, while the baseline Transformer fails to converge. 153 | 154 | ### Scaling-up Experiments 155 | 156 |

157 | [figure: scaling-up experiment curves] 158 |

159 | 160 | TorchScale supports arbitrary depths and widths, successfully scaling-up the models without pain. 161 | 162 | ## Acknowledgments 163 | 164 | Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository. 165 | 166 | ## Citations 167 | 168 | If you find this repository useful, please consider citing our work: 169 | 170 | ``` 171 | @article{torchscale, 172 | author = {Shuming Ma and Hongyu Wang and Shaohan Huang and Wenhui Wang and Zewen Chi and Li Dong and Alon Benhaim and Barun Patra and Vishrav Chaudhary and Xia Song and Furu Wei}, 173 | title = {{TorchScale}: {Transformers} at Scale}, 174 | journal = {CoRR}, 175 | volume = {abs/2211.13184}, 176 | year = {2022} 177 | } 178 | ``` 179 | 180 | ``` 181 | @article{deepnet, 182 | author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei}, 183 | title = {{DeepNet}: Scaling {Transformers} to 1,000 Layers}, 184 | journal = {CoRR}, 185 | volume = {abs/2203.00555}, 186 | year = {2022}, 187 | } 188 | ``` 189 | 190 | ``` 191 | @article{magneto, 192 | author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei}, 193 | title = {Foundation {Transformers}}, 194 | journal = {CoRR}, 195 | volume = {abs/2210.06423}, 196 | year = {2022} 197 | } 198 | ``` 199 | 200 | ``` 201 | @inproceedings{xmoe, 202 | title={On the Representation Collapse of Sparse Mixture of Experts}, 203 | author={Zewen Chi and Li Dong and Shaohan Huang and Damai Dai and Shuming Ma and Barun Patra and Saksham Singhal and Payal Bajaj and Xia Song and Xian-Ling Mao and Heyan Huang and Furu Wei}, 204 | booktitle={Advances in Neural Information Processing Systems}, 205 | year={2022}, 206 | url={https://openreview.net/forum?id=mWaYC6CZf5} 207 | } 208 | ``` 209 | 210 | ## Contributing 211 | 212 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 213 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 214 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 215 | 216 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 217 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 218 | provided by the bot. You will only need to do this once across all repos using our CLA. 219 | 220 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 221 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 222 | contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments. 223 | 224 | ## Trademarks 225 | 226 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 227 | trademarks or logos is subject to and must follow 228 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 
229 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 230 | Any use of third-party trademarks or logos are subject to those third-party's policies. 231 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/README.md: -------------------------------------------------------------------------------- 1 | # Example: Integration with FairSeq 2 | 3 | ## Setup 4 | 5 | ```bash 6 | # Install the repo as a package: 7 | git clone https://github.com/msranlp/torchscale.git 8 | cd torchscale 9 | pip install -e . 10 | pip install git+https://github.com/shumingma/fairseq.git@moe 11 | pip install git+https://github.com/shumingma/infinibatch.git 12 | pip install iopath 13 | pip install --upgrade numpy 14 | ``` 15 | 16 | ## Example: BERT Pretraining 17 | 18 | ### Data Format 19 | 20 | We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on the fly from disk. It requires the data to be sharded into multiple small files (e.g. 10K lines per file), as well as a JSON file containing some metadata and the paths to these files. 21 | 22 | The overall data directory should be organized as follows: 23 | ``` 24 | Data/ 25 | ├── json/ 26 | │ ├── train.json 27 | │ └── valid.json 28 | ├── shard/ 29 | │ ├── train/ 30 | │ │ ├── 00000.txt 31 | │ │ ├── 00001.txt 32 | │ │ └── ... 33 | │ └── valid/ 34 | │ ├── 00000.txt 35 | │ ├── 00001.txt 36 | │ └── ... 37 | ├── dict.txt 38 | └── sentencepiece.bpe.model 39 | ``` 40 | 41 | We recommend that each sharded data file contain no more than 10K lines, with one sentence per line; two documents should be separated by an empty line. 42 | ``` 43 | Document 1 Line 1 44 | Document 1 Line 2 45 | Document 1 Line 3 46 | 47 | Document 2 Line 1 48 | Document 2 Line 2 49 | 50 | ...
51 | ``` 52 | 53 | Also, the JSON file should be in the format like this: 54 | ``` 55 | [ 56 | { 57 | "source": [ 58 | "shard/train/00000.txt", 59 | "shard/train/00001.txt", 60 | ... 61 | ], 62 | "source_lang": "en", 63 | "weight": 1.0 64 | } 65 | ] 66 | ``` 67 | 68 | You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentecepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by 69 | ``` 70 | spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt 71 | ``` 72 | 73 | ### Training Command 74 | ```bash 75 | cd examples/fairseq/ 76 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \ 77 | --task pretraining \ 78 | --tokens-per-sample 512 \ 79 | --mask-prob 0.15 \ 80 | --span-length 3.0 \ 81 | --leave-unmasked-prob 0.0 \ 82 | --random-token-prob 0.0 \ 83 | --criterion masked_lm \ 84 | --arch mlm_base \ 85 | --share-encoder-input-output-embed \ 86 | --required-batch-size-multiple 8 \ 87 | --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \ 88 | --dict-file ${PATH_TO_DATA}/dict.txt \ 89 | --optimizer adam \ 90 | --adam-betas '(0.9,0.98)' \ 91 | --adam-eps 1e-6 \ 92 | --clip-norm 2.0 \ 93 | --lr-scheduler polynomial_decay \ 94 | --lr 0.0005 \ 95 | --warmup-updates 10000 \ 96 | --total-num-update 125000 \ 97 | --max-update 125000 \ 98 | --max-sentences 32 \ 99 | --update-freq 1 \ 100 | --log-format simple \ 101 | --log-interval 100 \ 102 | --disable-validation \ 103 | --save-interval-updates 5000 \ 104 | --no-epoch-checkpoints \ 105 | --fp16 \ 106 | --fp16-init-scale 4 \ 107 | --fp16-scale-window 256 \ 108 | --min-loss-scale 0.0001 \ 109 | --seed 1 \ 110 | --save-dir ${PATH_TO_CKPT} \ 111 | --ddp-backend=no_c10d \ 112 | --distributed-no-spawn \ 113 | --reset-dataloader \ 114 | --batch-read-ahead 10000 \ 115 | --rel-pos-buckets 32 \ 116 | --max-rel-pos 128 \ 117 | --deepnorm 118 | ``` 119 | 120 | ## Example: GPT Pretraining 121 | 122 | ### Data Format 123 | 124 | We use the format as in the FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data). 
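As a concrete illustration, the linked example binarizes a plain-text corpus with `fairseq-preprocess` and then points training at the resulting `data-bin` directory, which is what `${PATH_TO_DATA}` refers to in the commands below. Here is a minimal sketch, assuming the WikiText-103 files from the linked FairSeq example (the paths are placeholders; replace them with your own corpus):

```bash
# Binarize a tokenized, plain-text corpus for language modeling.
TEXT=examples/language_model/wikitext-103
fairseq-preprocess \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20
```

The binarized `data-bin/wikitext-103` directory can then be passed as `${PATH_TO_DATA}` to the training commands below.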
125 | 126 | ### Dense Model 127 | 128 | ```bash 129 | cd examples/fairseq/ 130 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 131 | ${PATH_TO_DATA} \ 132 | --num-workers 2 \ 133 | --activation-fn gelu \ 134 | --share-decoder-input-output-embed \ 135 | --validate-interval-updates 1000 \ 136 | --save-interval-updates 1000 \ 137 | --no-epoch-checkpoints \ 138 | --memory-efficient-fp16 \ 139 | --fp16-init-scale 4 \ 140 | --arch lm_base \ 141 | --task language_modeling \ 142 | --sample-break-mode none \ 143 | --tokens-per-sample 128 \ 144 | --optimizer adam --adam-betas "(0.9, 0.98)" \ 145 | --adam-eps 1e-08 \ 146 | --clip-norm 0.0 \ 147 | --lr 5e-4 \ 148 | --lr-scheduler polynomial_decay \ 149 | --warmup-updates 750 \ 150 | --dropout 0.1 \ 151 | --attention-dropout 0.1 \ 152 | --weight-decay 0.01 \ 153 | --batch-size 4 \ 154 | --update-freq 1 \ 155 | --required-batch-size-multiple 1 \ 156 | --total-num-update 50000 \ 157 | --max-update 50000 \ 158 | --seed 1 \ 159 | --ddp-backend=c10d 160 | ``` 161 | 162 | ### Sparse (MoE) Model 163 | 164 | ```bash 165 | cd examples/fairseq/ 166 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 167 | ${PATH_TO_DATA} \ 168 | --num-workers 2 \ 169 | --activation-fn gelu \ 170 | --share-decoder-input-output-embed \ 171 | --validate-interval-updates 1000 \ 172 | --save-interval-updates 1000 \ 173 | --no-epoch-checkpoints \ 174 | --memory-efficient-fp16 \ 175 | --fp16-init-scale 4 \ 176 | --arch lm_base \ 177 | --task language_modeling \ 178 | --sample-break-mode none \ 179 | --tokens-per-sample 128 \ 180 | --optimizer adam --adam-betas "(0.9, 0.98)" \ 181 | --adam-eps 1e-08 \ 182 | --clip-norm 0.0 \ 183 | --lr 5e-4 \ 184 | --lr-scheduler polynomial_decay \ 185 | --warmup-updates 750 \ 186 | --dropout 0.1 \ 187 | --attention-dropout 0.1 \ 188 | --weight-decay 0.01 \ 189 | --batch-size 4 \ 190 | --update-freq 1 \ 191 | --required-batch-size-multiple 1 \ 192 | --total-num-update 50000 \ 193 | --max-update 50000 \ 194 | --seed 1 \ 195 | --ddp-backend=no_c10d \ 196 | --moe-expert-count 2 --moe-freq 2 \ 197 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \ 198 | --moe-eval-capacity-token-fraction -1.0 \ 199 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ 200 | --use-xmoe 201 | ``` 202 | 203 | ## Example: Machine Translation 204 | 205 | ### Data Format 206 | 207 | We follow the FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data. 
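As a concrete illustration, the linked example downloads, tokenizes, and binarizes the IWSLT'14 German-English data roughly as follows. This is a minimal sketch assuming a checkout of the FairSeq repo; the dataset, language pair, and paths come from that example and should be replaced with your own data:

```bash
# From a FairSeq checkout: download and tokenize IWSLT'14 De-En.
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..

# Binarize the parallel corpus; the resulting directory is ${PATH_TO_DATA} below.
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en \
    --workers 20
```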
208 | 209 | ### Dense Model 210 | 211 | ```bash 212 | cd examples/fairseq/ 213 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 214 | ${PATH_TO_DATA} \ 215 | --arch mt_base --share-decoder-input-output-embed \ 216 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 217 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 218 | --dropout 0.3 --weight-decay 0.0001 \ 219 | --max-tokens 4096 --fp16 220 | ``` 221 | 222 | ### Sparse (MoE) Model 223 | 224 | ```bash 225 | cd examples/fairseq/ 226 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 227 | ${PATH_TO_DATA} \ 228 | --arch mt_base --share-decoder-input-output-embed \ 229 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 230 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 231 | --dropout 0.3 --weight-decay 0.0001 \ 232 | --moe-expert-count 2 --moe-freq 2 \ 233 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \ 234 | --moe-eval-capacity-token-fraction -1.0 \ 235 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ 236 | --use-xmoe \ 237 | --max-tokens 4096 --fp16 238 | ``` 239 | -------------------------------------------------------------------------------- /examples/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # flake8: noqa 5 | import models 6 | import tasks 7 | from fairseq_cli.generate import cli_main 8 | 9 | if __name__ == "__main__": 10 | cli_main() 11 | -------------------------------------------------------------------------------- /examples/fairseq/interactive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # flake8: noqa 5 | import models 6 | import tasks 7 | from fairseq_cli.interactive import cli_main 8 | 9 | if __name__ == "__main__": 10 | cli_main() 11 | -------------------------------------------------------------------------------- /examples/fairseq/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import argparse 5 | import importlib 6 | import os 7 | 8 | MODEL_REGISTRY = {} 9 | MODEL_DATACLASS_REGISTRY = {} 10 | ARCH_MODEL_REGISTRY = {} 11 | ARCH_MODEL_NAME_REGISTRY = {} 12 | ARCH_MODEL_INV_REGISTRY = {} 13 | ARCH_CONFIG_REGISTRY = {} 14 | 15 | # automatically import any Python files in the models/ directory 16 | models_dir = os.path.dirname(__file__) 17 | for file in os.listdir(models_dir): 18 | path = os.path.join(models_dir, file) 19 | if ( 20 | not file.startswith("_") 21 | and not file.startswith(".") 22 | and (file.endswith(".py") or os.path.isdir(path)) 23 | ): 24 | model_name = file[: file.find(".py")] if file.endswith(".py") else file 25 | module = importlib.import_module("models." 
+ model_name) 26 | 27 | # extra `model_parser` for sphinx 28 | if model_name in MODEL_REGISTRY: 29 | parser = argparse.ArgumentParser(add_help=False) 30 | group_archs = parser.add_argument_group("Named architectures") 31 | group_archs.add_argument( 32 | "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] 33 | ) 34 | group_args = parser.add_argument_group("Additional command-line arguments") 35 | MODEL_REGISTRY[model_name].add_args(group_args) 36 | globals()[model_name + "_parser"] = parser 37 | -------------------------------------------------------------------------------- /examples/fairseq/models/bert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import logging 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from apex.normalization import FusedLayerNorm as LayerNorm 12 | from fairseq import utils 13 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 14 | from fairseq.models import BaseFairseqModel, register_model, register_model_architecture 15 | from fairseq.models.squad import SQuADHead 16 | from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding 17 | from fairseq.modules import PositionalEmbedding 18 | from omegaconf import II 19 | 20 | from torchscale.architecture.config import EncoderConfig 21 | 22 | from .machine_translation import MTEncoder as Encoder 23 | 24 | DEFAULT_MAX_SOURCE_POSITIONS = 1024 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | @dataclass 30 | class BertConfig(FairseqDataclass): 31 | activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( 32 | default="relu", metadata={"help": "activation function to use"} 33 | ) 34 | dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) 35 | attention_dropout: float = field( 36 | default=0.0, metadata={"help": "dropout probability for attention weights"} 37 | ) 38 | activation_dropout: float = field( 39 | default=0.0, metadata={"help": "dropout probability after activation in FFN."} 40 | ) 41 | encoder_embed_dim: int = field( 42 | default=512, metadata={"help": "encoder embedding dimension"} 43 | ) 44 | encoder_output_dim: int = field( 45 | default=512, metadata={"help": "encoder output dimension"} 46 | ) 47 | encoder_input_dim: int = field( 48 | default=512, metadata={"help": "encoder input dimension"} 49 | ) 50 | encoder_ffn_embed_dim: int = field( 51 | default=2048, metadata={"help": "encoder embedding dimension for FFN"} 52 | ) 53 | encoder_layers: int = field(default=6, metadata={"help": "num encoder layers"}) 54 | encoder_attention_heads: int = field( 55 | default=8, metadata={"help": "num encoder attention heads"} 56 | ) 57 | encoder_normalize_before: bool = field( 58 | default=False, metadata={"help": "apply layernorm before each encoder block"} 59 | ) 60 | no_encoder_final_norm: bool = field( 61 | default=False, 62 | metadata={"help": "don't add an extra layernorm after the last encoder block"}, 63 | ) 64 | no_token_positional_embeddings: bool = field( 65 | default=False, 66 | metadata={ 67 | "help": "if set, disables positional embeddings (outside self attention)" 68 | }, 69 | ) 70 | share_encoder_input_output_embed: bool = field( 71 | default=False, metadata={"help": "share encoder input and output embeddings"} 72 | ) 73 | encoder_learned_pos: bool = field( 74 | 
default=False, 75 | metadata={"help": "use learned positional embeddings in the encoder"}, 76 | ) 77 | layernorm_embedding: bool = field( 78 | default=False, metadata={"help": "add layernorm to embedding"} 79 | ) 80 | no_scale_embedding: bool = field( 81 | default=False, metadata={"help": "if True, dont scale embeddings"} 82 | ) 83 | checkpoint_activations: bool = field( 84 | default=False, metadata={"help": "checkpoint activations at each layer"} 85 | ) 86 | offload_activations: bool = field( 87 | default=False, 88 | metadata={"help": "move checkpointed activations to CPU after they are used."}, 89 | ) 90 | # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) 91 | encoder_layerdrop: float = field( 92 | default=0.0, metadata={"help": "LayerDrop probability for encoder"} 93 | ) 94 | encoder_layers_to_keep: Optional[str] = field( 95 | default=None, 96 | metadata={ 97 | "help": "which layers to *keep* when pruning as a comma-separated list" 98 | }, 99 | ) 100 | # config for Fully Sharded Data Parallel (FSDP) training 101 | min_params_to_wrap: int = field( 102 | default=DEFAULT_MIN_PARAMS_TO_WRAP, 103 | metadata={ 104 | "help": ( 105 | "minimum number of params for a layer to be wrapped with FSDP() when " 106 | "training with --ddp-backend=fully_sharded. Smaller values will " 107 | "improve memory efficiency, but may make torch.distributed " 108 | "communication less efficient due to smaller input sizes. This option " 109 | "is set to 0 (i.e., always wrap) when --checkpoint-activations or " 110 | "--offload-activations are passed." 111 | ) 112 | }, 113 | ) 114 | max_source_positions: int = field( 115 | default=1024, metadata={"help": "max source positions"} 116 | ) 117 | pooler_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( 118 | default="relu", metadata={"help": "activation function to use for pooler layer"} 119 | ) 120 | pooler_dropout: float = field( 121 | default=0.0, 122 | metadata={"help": "dropout probability in the masked_lm pooler layers"}, 123 | ) 124 | # options from other parts of the config 125 | # add_bos_token: bool = II("task.add_bos_token") 126 | # tokens_per_sample: int = II("task.tokens_per_sample") 127 | tpu: bool = II("common.tpu") 128 | rel_pos_buckets: int = field(default=0, metadata={"help": ""}) 129 | max_rel_pos: int = field(default=0, metadata={"help": ""}) 130 | moe_freq: int = field( 131 | default=0, 132 | metadata={"help": "Frequency at which we insert MoE Transformer layers"}, 133 | ) 134 | moe_expert_count: int = field( 135 | default=0, metadata={"help": "Number of experts in each MoE Layer"} 136 | ) 137 | moe_gating_use_fp32: bool = field( 138 | default=False, 139 | metadata={"help": "Use FP32 computations in MoE top2 gating function"}, 140 | ) 141 | moe_second_expert_policy: str = field( 142 | default="sampling", 143 | metadata={"help": "policy for second expert, options: all/sampling/random"}, 144 | ) 145 | moe_normalize_gate_prob_before_dropping: bool = field( 146 | default=False, 147 | metadata={ 148 | "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization" 149 | }, 150 | ) 151 | moe_expert_ffn_dim: Optional[int] = field( 152 | default=None, metadata={"help": "MoE expert FFN dimension"} 153 | ) 154 | moe_top1_expert: Optional[bool] = field( 155 | default=False, metadata={"help": "Use top1 gate instead of top2"} 156 | ) 157 | moe_eval_capacity_token_fraction: Optional[float] = field( 158 | default=0.25, 159 | metadata={ 160 | "help": ( 161 | 
"Default: 0.25, Fraction of tokens as capacity during validation, " 162 | "if set to negative, use same as training. range: (0.0, 1.0]." 163 | ) 164 | }, 165 | ) 166 | moe_normalize_expert_grad: Optional[str] = field( 167 | default="world_size", 168 | metadata={ 169 | "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'" 170 | }, 171 | ) 172 | record_a2a_perf_stats: Optional[bool] = field( 173 | default=False, 174 | metadata={"help": "records all to all perf stats during distributed training"}, 175 | ) 176 | dummy_a2a: Optional[bool] = field( 177 | default=False, 178 | metadata={ 179 | "help": "By passes all to all during distributed training by returning the input buffer as output" 180 | }, 181 | ) 182 | moe_batch_prioritized_routing: Optional[bool] = field( 183 | default=False, 184 | metadata={ 185 | "help": "if true orders token by the gate prob before capacity dropping." 186 | }, 187 | ) 188 | ddp_rank: int = II("distributed_training.distributed_rank") 189 | deepnorm: Optional[bool] = field( 190 | default=False, 191 | ) 192 | subln: Optional[bool] = field( 193 | default=False, 194 | ) 195 | 196 | 197 | @register_model("mlm", dataclass=BertConfig) 198 | class BertModel(BaseFairseqModel): 199 | def __init__(self, args, encoder): 200 | super().__init__() 201 | self.args = args 202 | self.encoder = encoder 203 | self.padding_idx = self.encoder.embed_tokens.padding_idx 204 | self.classification_heads = nn.ModuleDict() 205 | 206 | @classmethod 207 | def build_model(cls, args, task): 208 | """Build a new model instance.""" 209 | 210 | args.max_source_positions = getattr( 211 | args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS 212 | ) 213 | 214 | embed_tokens = cls.build_embedding( 215 | args, task.dictionary, args.encoder_embed_dim 216 | ) 217 | 218 | embed_positions = ( 219 | PositionalEmbedding( 220 | args.max_source_positions, 221 | args.encoder_embed_dim, 222 | task.dictionary.pad(), 223 | learned=args.encoder_learned_pos, 224 | ) 225 | if not args.no_token_positional_embeddings 226 | else None 227 | ) 228 | 229 | lm_head = cls.build_lm_head( 230 | args, 231 | args.encoder_embed_dim, 232 | len(task.dictionary), 233 | args.activation_fn, 234 | weight=embed_tokens.weight, 235 | ) 236 | 237 | config = EncoderConfig() 238 | config.override(args) 239 | 240 | encoder = Encoder( 241 | config, 242 | embed_tokens=embed_tokens, 243 | embed_positions=embed_positions, 244 | output_projection=lm_head, 245 | is_encoder_decoder=False, 246 | dictionary=task.dictionary, 247 | ) 248 | 249 | return cls(args, encoder) 250 | 251 | @classmethod 252 | def build_embedding(cls, args, dictionary, embed_dim, path=None): 253 | embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) 254 | return embed_tokens 255 | 256 | @classmethod 257 | def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight): 258 | return LMHead(embed_dim, output_dim, activation_fn, weight) 259 | 260 | def output_layer(self, features, masked_tokens=None): 261 | return self.encoder.output_projection(features, masked_tokens=masked_tokens) 262 | 263 | def register_classification_head( 264 | self, name, num_classes=None, inner_dim=None, **kwargs 265 | ): 266 | """Register a classification head.""" 267 | if name in self.classification_heads: 268 | prev_num_classes = self.classification_heads[name].out_proj.out_features 269 | prev_inner_dim = self.classification_heads[name].dense.out_features 270 | if num_classes != prev_num_classes or inner_dim != prev_inner_dim: 271 | logger.warning( 272 
| 're-registering head "{}" with num_classes {} (prev: {}) ' 273 | "and inner_dim {} (prev: {})".format( 274 | name, num_classes, prev_num_classes, inner_dim, prev_inner_dim 275 | ) 276 | ) 277 | self.classification_heads[name] = ClassificationHead( 278 | self.args.encoder_embed_dim, 279 | inner_dim or self.args.encoder_embed_dim, 280 | num_classes, 281 | self.args.pooler_activation_fn, 282 | self.args.pooler_dropout, 283 | ) 284 | 285 | def register_question_answering_head(self, name, num_classes=None): 286 | self.classification_heads[name] = SQuADHead( 287 | self.args.encoder_embed_dim, 288 | ) 289 | 290 | def upgrade_state_dict_named(self, state_dict, name): 291 | prefix = name + "." if name != "" else "" 292 | 293 | # upgrade children modules 294 | super().upgrade_state_dict_named(state_dict, name) 295 | 296 | # Handle new classification heads present in the state dict. 297 | current_head_names = ( 298 | [] 299 | if not hasattr(self, "classification_heads") 300 | else self.classification_heads.keys() 301 | ) 302 | keys_to_delete = [] 303 | for k in state_dict.keys(): 304 | if not k.startswith(prefix + "classification_heads."): 305 | continue 306 | 307 | head_name = k[len(prefix + "classification_heads.") :].split(".")[0] # noqa: E203 308 | num_classes = state_dict[ 309 | prefix + "classification_heads." + head_name + ".out_proj.weight" 310 | ].size(0) 311 | inner_dim = state_dict[ 312 | prefix + "classification_heads." + head_name + ".dense.weight" 313 | ].size(0) 314 | 315 | if getattr(self.args, "load_checkpoint_heads", False): 316 | if head_name not in current_head_names: 317 | self.register_classification_head(head_name, num_classes, inner_dim) 318 | else: 319 | if head_name not in current_head_names: 320 | logger.warning( 321 | "deleting classification head ({}) from checkpoint " 322 | "not present in current model: {}".format(head_name, k) 323 | ) 324 | keys_to_delete.append(k) 325 | elif ( 326 | num_classes 327 | != self.classification_heads[head_name].out_proj.out_features 328 | or inner_dim 329 | != self.classification_heads[head_name].dense.out_features 330 | ): 331 | logger.warning( 332 | "deleting classification head ({}) from checkpoint " 333 | "with different dimensions than current model: {}".format( 334 | head_name, k 335 | ) 336 | ) 337 | keys_to_delete.append(k) 338 | for k in keys_to_delete: 339 | del state_dict[k] 340 | 341 | # Copy any newly-added classification heads into the state dict 342 | # with their current weights. 343 | if hasattr(self, "classification_heads"): 344 | cur_state = self.classification_heads.state_dict() 345 | for k, v in cur_state.items(): 346 | if prefix + "classification_heads." + k not in state_dict: 347 | logger.info("Overwriting " + prefix + "classification_heads." + k) 348 | state_dict[prefix + "classification_heads." 
+ k] = v 349 | 350 | def forward( 351 | self, 352 | src_tokens=None, 353 | features_only=False, 354 | return_all_hiddens=False, 355 | classification_head_name=None, 356 | masked_tokens=None, 357 | **kwargs 358 | ): 359 | encoder_out = self.encoder( 360 | src_tokens, features_only=True, return_all_hiddens=return_all_hiddens 361 | ) 362 | x, extra = encoder_out["encoder_out"], encoder_out 363 | x = x.transpose(0, 1) 364 | 365 | if classification_head_name is not None: 366 | x = self.classification_heads[classification_head_name](x) 367 | elif not features_only: 368 | x = self.output_layer(x, masked_tokens=masked_tokens) 369 | 370 | return x, extra 371 | 372 | 373 | class ClassificationHead(nn.Module): 374 | """Head for sentence-level classification tasks.""" 375 | 376 | def __init__( 377 | self, 378 | input_dim, 379 | inner_dim, 380 | num_classes, 381 | activation_fn, 382 | pooler_dropout, 383 | ): 384 | super().__init__() 385 | self.dense = nn.Linear(input_dim, inner_dim) 386 | self.activation_fn = utils.get_activation_fn(activation_fn) 387 | self.dropout = nn.Dropout(p=pooler_dropout) 388 | self.out_proj = nn.Linear(inner_dim, num_classes) 389 | 390 | def forward(self, features, **kwargs): 391 | x = features[:, 0, :] # take token (equiv. to [CLS]) 392 | x = self.dropout(x) 393 | x = self.dense(x) 394 | x = self.activation_fn(x.float()).type_as(x) 395 | x = self.dropout(x) 396 | x = self.out_proj(x) 397 | return x 398 | 399 | 400 | class LMHead(nn.Module): 401 | """Head for masked language modeling.""" 402 | 403 | def __init__(self, embed_dim, output_dim, activation_fn, weight=None): 404 | super().__init__() 405 | self.dense = nn.Linear(embed_dim, embed_dim) 406 | self.activation_fn = utils.get_activation_fn(activation_fn) 407 | self.layer_norm = LayerNorm(embed_dim) 408 | 409 | if weight is None: 410 | weight = nn.Linear(embed_dim, output_dim, bias=False).weight 411 | self.weight = weight 412 | self.bias = nn.Parameter(torch.zeros(output_dim)) 413 | 414 | def forward(self, features, masked_tokens=None, **kwargs): 415 | # Only project the masked tokens while training, 416 | # saves both memory and computation 417 | if masked_tokens is not None: 418 | features = features[masked_tokens, :] 419 | 420 | x = self.dense(features) 421 | x = self.activation_fn(x.float()).type_as(x) 422 | x = self.layer_norm(x) 423 | # project back to size of vocabulary with bias 424 | x = F.linear(x, self.weight) + self.bias 425 | return x 426 | 427 | 428 | @register_model_architecture("mlm", "mlm_base") 429 | def base_unilm_architecture(args): 430 | if hasattr(args, "encoder_final_norm"): 431 | args.no_encoder_final_norm = not args.encoder_final_norm 432 | 433 | args.dropout = getattr(args, "dropout", 0.1) 434 | args.attention_dropout = getattr(args, "attention_dropout", 0.0) 435 | args.activation_dropout = getattr(args, "activation_dropout", 0.0) 436 | args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) 437 | 438 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) 439 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) 440 | args.encoder_layers = getattr(args, "encoder_layers", 12) 441 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) 442 | args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) 443 | args.activation_fn = getattr(args, "activation_fn", "gelu") 444 | args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") 445 | 446 | args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) 447 | 
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) 448 | 449 | # args.add_bos_token = getattr(args, "add_bos_token", False) 450 | args.no_token_positional_embeddings = getattr( 451 | args, "no_token_positional_embeddings", False 452 | ) 453 | args.share_encoder_input_output_embed = getattr( 454 | args, "share_encoder_input_output_embed", True 455 | ) 456 | args.encoder_output_dim = getattr( 457 | args, "encoder_output_dim", args.encoder_embed_dim 458 | ) 459 | args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim) 460 | 461 | # Model training is not stable without this 462 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) 463 | args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False) 464 | 465 | args.no_scale_embedding = getattr(args, "no_scale_embedding", True) 466 | args.layernorm_embedding = getattr(args, "layernorm_embedding", True) 467 | args.checkpoint_activations = getattr(args, "checkpoint_activations", False) 468 | args.offload_activations = getattr(args, "offload_activations", False) 469 | if args.offload_activations: 470 | args.checkpoint_activations = True 471 | -------------------------------------------------------------------------------- /examples/fairseq/models/language_modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import logging 10 | from dataclasses import dataclass, field 11 | from typing import Optional 12 | 13 | import torch 14 | from fairseq import distributed_utils, utils 15 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 16 | from fairseq.models import ( 17 | FairseqIncrementalDecoder, 18 | FairseqLanguageModel, 19 | register_model, 20 | register_model_architecture, 21 | ) 22 | from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding 23 | from fairseq.modules import PositionalEmbedding 24 | from omegaconf import II 25 | 26 | from torchscale.architecture.config import DecoderConfig 27 | from torchscale.architecture.decoder import Decoder 28 | 29 | DEFAULT_MAX_TARGET_POSITIONS = 4096 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | @dataclass 34 | class LanguageConfig(FairseqDataclass): 35 | activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( 36 | default="relu", metadata={"help": "activation function to use"} 37 | ) 38 | dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) 39 | attention_dropout: float = field( 40 | default=0.0, metadata={"help": "dropout probability for attention weights"} 41 | ) 42 | activation_dropout: float = field( 43 | default=0.0, metadata={"help": "dropout probability after activation in FFN."} 44 | ) 45 | relu_dropout: float = field( 46 | default=0.0, metadata={"help": "dropout probability after activation in FFN."} 47 | ) 48 | decoder_embed_dim: int = field( 49 | default=512, metadata={"help": "decoder embedding dimension"} 50 | ) 51 | decoder_output_dim: int = field( 52 | default=512, metadata={"help": "decoder output dimension"} 53 | ) 54 | decoder_input_dim: int = field( 55 | default=512, metadata={"help": "decoder input dimension"} 56 | ) 57 | decoder_ffn_embed_dim: int = field( 58 | default=2048, 
metadata={"help": "decoder embedding dimension for FFN"} 59 | ) 60 | decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) 61 | decoder_attention_heads: int = field( 62 | default=8, metadata={"help": "num decoder attention heads"} 63 | ) 64 | decoder_normalize_before: bool = field( 65 | default=False, metadata={"help": "apply layernorm before each decoder block"} 66 | ) 67 | no_token_positional_embeddings: bool = field( 68 | default=False, 69 | metadata={ 70 | "help": "if set, disables positional embeddings (outside self attention)" 71 | }, 72 | ) 73 | share_decoder_input_output_embed: bool = field( 74 | default=False, metadata={"help": "share decoder input and output embeddings"} 75 | ) 76 | decoder_learned_pos: bool = field( 77 | default=False, 78 | metadata={"help": "use learned positional embeddings in the decoder"}, 79 | ) 80 | layernorm_embedding: bool = field( 81 | default=False, metadata={"help": "add layernorm to embedding"} 82 | ) 83 | no_scale_embedding: bool = field( 84 | default=False, metadata={"help": "if True, dont scale embeddings"} 85 | ) 86 | checkpoint_activations: bool = field( 87 | default=False, metadata={"help": "checkpoint activations at each layer"} 88 | ) 89 | offload_activations: bool = field( 90 | default=False, 91 | metadata={"help": "move checkpointed activations to CPU after they are used."}, 92 | ) 93 | # config for Fully Sharded Data Parallel (FSDP) training 94 | min_params_to_wrap: int = field( 95 | default=DEFAULT_MIN_PARAMS_TO_WRAP, 96 | metadata={ 97 | "help": ( 98 | "minimum number of params for a layer to be wrapped with FSDP() when " 99 | "training with --ddp-backend=fully_sharded. Smaller values will " 100 | "improve memory efficiency, but may make torch.distributed " 101 | "communication less efficient due to smaller input sizes. This option " 102 | "is set to 0 (i.e., always wrap) when --checkpoint-activations or " 103 | "--offload-activations are passed." 104 | ) 105 | }, 106 | ) 107 | moe_freq: int = field( 108 | default=0, 109 | metadata={"help": "Frequency at which we insert MoE Transformer layers"}, 110 | ) 111 | moe_expert_count: int = field( 112 | default=0, metadata={"help": "Number of experts in each MoE Layer"} 113 | ) 114 | moe_gating_use_fp32: bool = field( 115 | default=False, 116 | metadata={"help": "Use FP32 computations in MoE top2 gating function"}, 117 | ) 118 | moe_second_expert_policy: str = field( 119 | default="sampling", 120 | metadata={"help": "policy for second expert, options: all/sampling/random"}, 121 | ) 122 | moe_normalize_gate_prob_before_dropping: bool = field( 123 | default=False, 124 | metadata={ 125 | "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization" 126 | }, 127 | ) 128 | moe_expert_ffn_dim: Optional[int] = field( 129 | default=None, metadata={"help": "MoE expert FFN dimension"} 130 | ) 131 | moe_top1_expert: Optional[bool] = field( 132 | default=False, metadata={"help": "Use top1 gate instead of top2"} 133 | ) 134 | moe_eval_capacity_token_fraction: Optional[float] = field( 135 | default=0.25, 136 | metadata={ 137 | "help": ( 138 | "Default: 0.25, Fraction of tokens as capacity during validation, " 139 | "if set to negative, use same as training. range: (0.0, 1.0]." 
140 | ) 141 | }, 142 | ) 143 | moe_normalize_expert_grad: Optional[str] = field( 144 | default="world_size", 145 | metadata={ 146 | "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'" 147 | }, 148 | ) 149 | record_a2a_perf_stats: Optional[bool] = field( 150 | default=False, 151 | metadata={"help": "records all to all perf stats during distributed training"}, 152 | ) 153 | dummy_a2a: Optional[bool] = field( 154 | default=False, 155 | metadata={ 156 | "help": "By passes all to all during distributed training by returning the input buffer as output" 157 | }, 158 | ) 159 | moe_batch_prioritized_routing: Optional[bool] = field( 160 | default=False, 161 | metadata={ 162 | "help": "if true orders token by the gate prob before capacity dropping." 163 | }, 164 | ) 165 | use_xmoe: Optional[bool] = field( 166 | default=False, 167 | ) 168 | 169 | # options from other parts of the config 170 | add_bos_token: bool = II("task.add_bos_token") 171 | tokens_per_sample: int = II("task.tokens_per_sample") 172 | max_target_positions: Optional[int] = II("task.max_target_positions") 173 | tpu: bool = II("common.tpu") 174 | memory_efficient_fp16: bool = II("common.memory_efficient_fp16") 175 | fp16: bool = II("common.fp16") 176 | fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads") 177 | ddp_backend: str = II("distributed_training.ddp_backend") 178 | world_size: int = II("distributed_training.distributed_world_size") 179 | distributed_rank: int = II("distributed_training.distributed_rank") 180 | ddp_rank: int = II("distributed_training.distributed_rank") 181 | deepnorm: Optional[bool] = field( 182 | default=False, 183 | ) 184 | subln: Optional[bool] = field( 185 | default=False, 186 | ) 187 | xpos_rel_pos: Optional[bool] = field( 188 | default=False, 189 | metadata={"help": "use XPos as the relative position embhedding"}, 190 | ) 191 | block_size: Optional[int] = field( 192 | default=2048, 193 | ) 194 | rel_pos_buckets: Optional[int] = field( 195 | default=0, 196 | ) 197 | max_rel_pos: Optional[int] = field( 198 | default=0, 199 | ) 200 | 201 | 202 | @register_model("lm", dataclass=LanguageConfig) 203 | class LanguageModel(FairseqLanguageModel): 204 | def __init__(self, args, decoder): 205 | self.args = args 206 | super().__init__(decoder) 207 | 208 | @classmethod 209 | def build_model(cls, args, task): 210 | 211 | if getattr(args, "max_target_positions", None) is None: 212 | args.max_target_positions = getattr( 213 | args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS 214 | ) 215 | 216 | embed_tokens = cls.build_embedding( 217 | args, task.source_dictionary, args.decoder_embed_dim 218 | ) 219 | 220 | embed_positions = ( 221 | PositionalEmbedding( 222 | args.max_target_positions, 223 | args.decoder_embed_dim, 224 | task.dictionary.pad(), 225 | learned=args.decoder_learned_pos, 226 | ) 227 | if not args.no_token_positional_embeddings 228 | else None 229 | ) 230 | 231 | if args.share_decoder_input_output_embed: 232 | output_projection = torch.nn.Linear( 233 | embed_tokens.weight.shape[1], 234 | embed_tokens.weight.shape[0], 235 | bias=False, 236 | ) 237 | output_projection.weight = embed_tokens.weight 238 | else: 239 | output_projection = torch.nn.Linear( 240 | args.decoder_embed_dim, len(task.dictionary), bias=False 241 | ) 242 | torch.nn.init.normal_( 243 | output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5 244 | ) 245 | 246 | if getattr(args, "moe_freq", 0) > 0 and ( 247 | getattr(args, "fp16", False) 248 | and not getattr(args, "memory_efficient_fp16", 
False) 249 | and getattr(args, "ddp_backend", None) != "fully_sharded" 250 | ): 251 | assert ( 252 | args.fp16_no_flatten_grads 253 | ), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm" 254 | 255 | args.ddp_rank = distributed_utils.get_data_parallel_rank() 256 | 257 | config = DecoderConfig() 258 | config.override(args) 259 | 260 | decoder = LMDecoder( 261 | config, 262 | embed_tokens, 263 | embed_positions, 264 | output_projection, 265 | is_encoder_decoder=False, 266 | dictionary=task.dictionary, 267 | ) 268 | 269 | return cls(args, decoder) 270 | 271 | @classmethod 272 | def build_embedding(cls, args, dictionary, embed_dim, path=None): 273 | return Embedding(len(dictionary), embed_dim, dictionary.pad()) 274 | 275 | 276 | class LMDecoder(Decoder, FairseqIncrementalDecoder): 277 | def forward(self, src_tokens, **kwargs): 278 | self_attn_padding_mask = src_tokens.eq(self.dictionary.pad()) 279 | return super().forward(src_tokens, self_attn_padding_mask, **kwargs) 280 | 281 | def max_positions(self): 282 | if self.embed_positions is not None: 283 | return self.embed_positions.max_positions 284 | else: 285 | return DEFAULT_MAX_TARGET_POSITIONS 286 | 287 | def reorder_incremental_state_scripting( 288 | self, 289 | incremental_state, 290 | new_order, 291 | ): 292 | for module in incremental_state: 293 | for key in incremental_state[module]: 294 | result = incremental_state[module][key].index_select(0, new_order) 295 | incremental_state[module][key] = result 296 | 297 | 298 | @register_model_architecture("lm", "lm_base") 299 | def base_lm_architecture(args): 300 | # backward compatibility for older model checkpoints 301 | if hasattr(args, "no_tie_adaptive_proj"): 302 | # previous models defined --no-tie-adaptive-proj, so use the existence of 303 | # that option to determine if this is an "old" model checkpoint 304 | args.no_decoder_final_norm = True # old models always set this to True 305 | if args.no_tie_adaptive_proj is False: 306 | args.tie_adaptive_proj = True 307 | if hasattr(args, "decoder_final_norm"): 308 | args.no_decoder_final_norm = not args.decoder_final_norm 309 | 310 | args.dropout = getattr(args, "dropout", 0.1) 311 | args.attention_dropout = getattr(args, "attention_dropout", 0.0) 312 | 313 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) 314 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) 315 | args.decoder_layers = getattr(args, "decoder_layers", 6) 316 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) 317 | args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) 318 | args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) 319 | args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) 320 | args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) 321 | args.activation_fn = getattr(args, "activation_fn", "relu") 322 | 323 | args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) 324 | args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) 325 | 326 | args.base_layers = getattr(args, "base_layers", 0) 327 | args.base_sublayers = getattr(args, "base_sublayers", 1) 328 | args.base_shuffle = getattr(args, "base_shuffle", False) 329 | 330 | args.add_bos_token = getattr(args, "add_bos_token", False) 331 | args.no_token_positional_embeddings = getattr( 332 | args, "no_token_positional_embeddings", True 333 | ) 334 | args.xpos_rel_pos = getattr( 335 | args, "xpos_rel_pos", True 
336 | ) 337 | args.block_size = getattr( 338 | args, "block_size", 2048 339 | ) 340 | args.share_decoder_input_output_embed = getattr( 341 | args, "share_decoder_input_output_embed", False 342 | ) 343 | args.character_embeddings = getattr(args, "character_embeddings", False) 344 | 345 | args.decoder_output_dim = getattr( 346 | args, "decoder_output_dim", args.decoder_embed_dim 347 | ) 348 | args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) 349 | 350 | # Model training is not stable without this 351 | args.decoder_normalize_before = True 352 | args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) 353 | 354 | args.adaptive_input = getattr(args, "adaptive_input", False) 355 | args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) 356 | args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) 357 | 358 | args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) 359 | args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) 360 | 361 | args.no_scale_embedding = getattr(args, "no_scale_embedding", False) 362 | args.layernorm_embedding = getattr(args, "layernorm_embedding", False) 363 | args.checkpoint_activations = getattr(args, "checkpoint_activations", False) 364 | args.offload_activations = getattr(args, "offload_activations", False) 365 | if args.offload_activations: 366 | args.checkpoint_activations = True 367 | 368 | @register_model_architecture("lm", "lm_medium") 369 | def lm_medium(args): 370 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) 371 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) 372 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) 373 | args.decoder_layers = getattr(args, "decoder_layers", 24) 374 | base_lm_architecture(args) 375 | 376 | @register_model_architecture("lm", "lm_large") 377 | def lm_large(args): 378 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) 379 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120) 380 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20) 381 | args.decoder_layers = getattr(args, "decoder_layers", 36) 382 | base_lm_architecture(args) 383 | 384 | 385 | @register_model_architecture("lm", "lm_xl") 386 | def lm_xl(args): 387 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1600) 388 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6400) 389 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 25) 390 | args.decoder_layers = getattr(args, "decoder_layers", 48) 391 | base_lm_architecture(args) 392 | 393 | 394 | @register_model_architecture("lm", "lm_base_abs") 395 | def lm_base_abs(args): 396 | args.xpos_rel_pos = getattr( 397 | args, "xpos_rel_pos", False 398 | ) 399 | args.no_token_positional_embeddings = getattr( 400 | args, "no_token_positional_embeddings", False 401 | ) 402 | base_lm_architecture(args) 403 | 404 | @register_model_architecture("lm", "lm_base_bucket") 405 | def lm_base_bucket(args): 406 | args.xpos_rel_pos = getattr( 407 | args, "xpos_rel_pos", False 408 | ) 409 | args.rel_pos_buckets = getattr( 410 | args, "rel_pos_buckets", 128 411 | ) 412 | args.max_rel_pos = getattr( 413 | args, "max_rel_pos", 2048 414 | ) 415 | base_lm_architecture(args) 416 | 417 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import argparse 5 | import importlib 6 | import os 7 | 8 | # register dataclass 9 | TASK_DATACLASS_REGISTRY = {} 10 | TASK_REGISTRY = {} 11 | TASK_CLASS_NAMES = set() 12 | 13 | # automatically import any Python files in the tasks/ directory 14 | tasks_dir = os.path.dirname(__file__) 15 | for file in os.listdir(tasks_dir): 16 | path = os.path.join(tasks_dir, file) 17 | if ( 18 | not file.startswith("_") 19 | and not file.startswith(".") 20 | and (file.endswith(".py") or os.path.isdir(path)) 21 | ): 22 | task_name = file[: file.find(".py")] if file.endswith(".py") else file 23 | module = importlib.import_module("tasks." + task_name) 24 | 25 | # expose `task_parser` for sphinx 26 | if task_name in TASK_REGISTRY: 27 | parser = argparse.ArgumentParser(add_help=False) 28 | group_task = parser.add_argument_group("Task name") 29 | # fmt: off 30 | group_task.add_argument('--task', metavar=task_name, 31 | help='Enable this task with: ``--task=' + task_name + '``') 32 | # fmt: on 33 | group_args = parser.add_argument_group("Additional command-line arguments") 34 | TASK_REGISTRY[task_name].add_args(group_args) 35 | globals()[task_name + "_parser"] = parser 36 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/basic_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | from infinibatch.iterators import CheckpointableIterator 6 | 7 | from . 
import utils 8 | 9 | 10 | class BaseBatchGen(CheckpointableIterator): 11 | """ 12 | This is a base class for batch generators that use infinibatch 13 | """ 14 | 15 | def __init__(self): 16 | self._iter = None 17 | self.epoch = 1 18 | self.next_epoch_idx = 1 19 | self.sharded_checkpoint = True 20 | self.should_close_after_finished = True 21 | 22 | def _build_iter(self): 23 | """ 24 | Build infinibatch iterator and assign to self._iter 25 | """ 26 | raise NotImplementedError() 27 | 28 | def _move_to_tensor(self, batch): 29 | def to_tensor(x): 30 | return torch.tensor(x) 31 | 32 | return utils.apply_to_sample(to_tensor, batch) 33 | 34 | @property 35 | def iterator(self): 36 | if self._iter is None: 37 | raise NotImplementedError("_build_iter() must be called first") 38 | return self._iter 39 | 40 | def __iter__(self): 41 | if self._iter is None: 42 | raise NotImplementedError("_build_iter() must be called first") 43 | return self._iter 44 | 45 | def __next__(self): 46 | return next(self._iter) 47 | 48 | def setstate(self, value): 49 | self._iter.setstate(value) 50 | 51 | def getstate(self): 52 | return self._iter.getstate() 53 | 54 | def close(self): 55 | self._iter.close() 56 | 57 | def __len__(self) -> int: 58 | return 819200000 59 | 60 | def next_epoch_itr( 61 | self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True 62 | ): 63 | return self 64 | 65 | def end_of_epoch(self) -> bool: 66 | return False 67 | 68 | def state_dict(self): 69 | """Returns a dictionary containing a whole state of the iterator.""" 70 | return self.getstate() 71 | 72 | def load_state_dict(self, state_dict): 73 | """Copies the state of the iterator from the given *state_dict*.""" 74 | self.setstate(state_dict) 75 | 76 | @property 77 | def first_batch(self): 78 | return "DUMMY" 79 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/mlm_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import copy 5 | import itertools 6 | import os 7 | 8 | import numpy as np 9 | from infinibatch import iterators 10 | 11 | from .basic_loader import BaseBatchGen 12 | from .utils import NativeCheckpointableIterator, WeightIterator 13 | 14 | 15 | class MLMLoader(BaseBatchGen): 16 | def __init__( 17 | self, 18 | args, 19 | dataset, 20 | dictionary, 21 | tokenizer, 22 | max_tokens=None, 23 | max_sentences=None, 24 | max_positions=None, 25 | ignore_invalid_inputs=False, 26 | required_batch_size_multiple=1, 27 | seed=1, 28 | num_shards=1, 29 | shard_id=0, 30 | ): 31 | super().__init__() 32 | self.args = args 33 | self.data = dataset.data 34 | self.data_dir = dataset.data_dir 35 | self.shuffle = dataset.shuffle 36 | self.dictionary = dictionary 37 | self.tokenizer = tokenizer 38 | 39 | self.max_tokens = max_tokens 40 | self.max_sentences = max_sentences 41 | self.max_positions = max_positions 42 | self.tokens_per_sample = args.tokens_per_sample 43 | self.sample_break_mode = args.sample_break_mode 44 | self.ignore_invalid_inputs = ignore_invalid_inputs 45 | self.required_batch_size_multiple = required_batch_size_multiple 46 | self.seed = str(seed) 47 | self.num_shards = num_shards 48 | self.shard_id = shard_id 49 | 50 | self.batch_read_ahead = args.batch_read_ahead 51 | 52 | self._build_iter() 53 | 54 | def _build_iter(self): 55 | tokenized_lines = self._multilingual_tokenize() 56 | self.padded_batches =
self._batchify(tokenized_lines) 57 | 58 | prefetch_batches = iterators.PrefetchIterator( 59 | self.padded_batches, 60 | buffer_size=10000, 61 | buffer_in_main_process=True, 62 | log_empty_buffer_warning=True and self.shard_id == 0, 63 | ) 64 | 65 | prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor) 66 | 67 | self._iter = prefetch_batches 68 | 69 | def _multilingual_tokenize(self): 70 | multilingual_iters = [] 71 | weights = [] 72 | 73 | for data in self.data: 74 | multilingual_iters.append(self._tokenize(data)) 75 | if "weight" in data: 76 | weights.append(float(data["weight"])) 77 | else: 78 | weights.append(int(data["count"])) 79 | 80 | if len(multilingual_iters) == 1: 81 | return multilingual_iters[0] 82 | 83 | sampling_iterator = WeightIterator(weights) 84 | control_iterator = NativeCheckpointableIterator(sampling_iterator) 85 | tokenized_lines = iterators.MultiplexIterator( 86 | control_iterator, multilingual_iters 87 | ) 88 | 89 | return tokenized_lines 90 | 91 | def _tokenize(self, data): 92 | """ 93 | data: 94 | { 95 | 'source': list[Path], 96 | 'source_lang': str, 97 | 'count': int, 98 | 'weight': float, 99 | 'name': str, 100 | } 101 | """ 102 | dataset = list( 103 | zip( 104 | data["source"], 105 | itertools.repeat(data["source_lang"]), 106 | ) 107 | ) 108 | 109 | if self.shuffle: 110 | chunk_files = iterators.InfinitePermutationSourceIterator( 111 | dataset, 112 | seed=self.seed, 113 | shuffle=self.shuffle, 114 | num_instances=self.num_shards, 115 | instance_rank=self.shard_id, 116 | ) 117 | else: 118 | chunk_files = iterators.ChunkedSourceIterator( 119 | dataset, 120 | num_instances=self.num_shards, 121 | instance_rank=self.shard_id, 122 | ) 123 | 124 | tokenized_lines = iterators.SelectManyIterator( 125 | chunk_files, lambda files: self._read_from_files(*files) 126 | ) 127 | tokenized_lines = iterators.SamplingRandomMapIterator( 128 | tokenized_lines, self._prepare, self.seed 129 | ) 130 | 131 | return tokenized_lines 132 | 133 | def _batchify(self, lines): 134 | 135 | if self.max_sentences is not None: 136 | if self.batch_read_ahead > 0: 137 | lines = iterators.BlockwiseShuffleIterator( 138 | lines, self.batch_read_ahead, self.seed 139 | ) 140 | batches = iterators.FixedBatchIterator(lines, self.max_sentences) 141 | else: 142 | 143 | def dynamic_batch_size(sample): 144 | lengths = [len(x) for x in sample] 145 | batch_size = self.max_tokens // max(lengths) 146 | batch_size = ( 147 | batch_size 148 | // self.required_batch_size_multiple 149 | * self.required_batch_size_multiple 150 | ) 151 | return max(1, batch_size) 152 | 153 | batches = iterators.BucketedReadaheadBatchIterator( 154 | lines, 155 | read_ahead=self.batch_read_ahead, 156 | key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None, 157 | batch_size=dynamic_batch_size, 158 | shuffle=self.shuffle, 159 | seed=self.seed, 160 | ) 161 | 162 | def collate(batch): 163 | batch_size = len(batch) 164 | 165 | mlm_source_max_length = max([len(x[0]) for x in batch]) 166 | mlm_target_max_length = max([len(x[1]) for x in batch]) 167 | s2s_source_max_length = max([len(x[2]) for x in batch]) 168 | s2s_target_max_length = max([len(x[3]) for x in batch]) 169 | 170 | mlm_source_ids = np.full( 171 | shape=(batch_size, mlm_source_max_length), 172 | dtype=np.int32, 173 | fill_value=self.dictionary.pad(), 174 | ) 175 | mlm_target_ids = np.full( 176 | shape=(batch_size, mlm_target_max_length), 177 | dtype=np.int32, 178 | fill_value=self.dictionary.pad(), 179 | ) 180 | s2s_source_ids = np.full( 181 | 
shape=(batch_size, s2s_source_max_length), 182 | dtype=np.int32, 183 | fill_value=self.dictionary.pad(), 184 | ) 185 | s2s_target_ids = np.full( 186 | shape=(batch_size, s2s_target_max_length - 1), 187 | dtype=np.int32, 188 | fill_value=self.dictionary.pad(), 189 | ) 190 | s2s_prev_input_ids = np.full( 191 | shape=(batch_size, s2s_target_max_length - 1), 192 | dtype=np.int32, 193 | fill_value=self.dictionary.pad(), 194 | ) 195 | 196 | for i, ( 197 | mlm_input_ids, 198 | mlm_label_ids, 199 | s2s_input_ids, 200 | s2s_label_ids, 201 | ) in enumerate(batch): 202 | mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids 203 | mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids 204 | s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids 205 | s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:] 206 | s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1] 207 | 208 | ret_batch = { 209 | "net_input": { 210 | "src_tokens": mlm_source_ids.astype(np.int64), 211 | }, 212 | "target": mlm_target_ids.astype(np.int64), 213 | "nsentences": batch_size, 214 | "ntokens": sum([len(x[0]) for x in batch]), 215 | } 216 | 217 | return ret_batch 218 | 219 | padded_batches = iterators.MapIterator(batches, collate) 220 | 221 | return padded_batches 222 | 223 | def _prepare(self, _random, doc): 224 | nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc) 225 | nonnoise_spans, noise_spans = self._span_corruption(_random, doc) 226 | return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans 227 | 228 | def _mask_lm(self, _random, doc): 229 | def mask_tokens(): 230 | return "<mask>" 231 | 232 | length = len(doc) 233 | mask_tokens_num = int(length * self.args.mask_prob) 234 | mask_tokens_num = min(max(mask_tokens_num, 1), length - 1) 235 | possible_mask_positions = _random.sample(range(length), k=mask_tokens_num) 236 | possible_mask_positions = sorted(possible_mask_positions) 237 | 238 | nonmasked_tokens = copy.deepcopy(doc) 239 | masked_tokens = [self.dictionary.pad() for _ in range(len(doc))] 240 | 241 | for position in possible_mask_positions: 242 | # masked_tokens.append(nonmasked_tokens[position]) 243 | masked_tokens[position] = nonmasked_tokens[position] 244 | nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()] 245 | 246 | return nonmasked_tokens, masked_tokens 247 | 248 | def _span_corruption(self, _random, doc): 249 | def mask_tokens(i): 250 | return f"<mask_{i}>" 251 | 252 | length = len(doc) 253 | noise_tokens_num = int(length * self.args.mask_prob) 254 | noise_tokens_num = min(max(noise_tokens_num, 1), length - 1) 255 | noise_spans_num = int(noise_tokens_num / self.args.span_length) 256 | noise_spans_num = max(noise_spans_num, 1) 257 | nonnoise_tokens_num = length - noise_tokens_num 258 | 259 | if noise_spans_num == 1: 260 | noise_split_positions = [0, noise_tokens_num] 261 | else: 262 | possible_split_positions = list(range(1, noise_tokens_num)) 263 | _random.shuffle(possible_split_positions) 264 | noise_split_positions = sorted( 265 | possible_split_positions[: noise_spans_num - 1] 266 | ) 267 | noise_split_positions = [0] + noise_split_positions + [noise_tokens_num] 268 | 269 | possible_insert_positions = list(range(nonnoise_tokens_num)) 270 | _random.shuffle(possible_insert_positions) 271 | noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num]) 272 | 273 | nonnoise_spans, noise_spans = [], [] 274 | last_end = 0 275 | for i in range(noise_spans_num): 276 | start_pos = noise_insert_positions[i] + noise_split_positions[i] 277 | end_pos =
noise_insert_positions[i] + noise_split_positions[i + 1] 278 | mask_id = self.dictionary.indices[mask_tokens(i)] 279 | 280 | if getattr(self.args, "remove_target_sentinel", False): 281 | noise_spans.append(doc[start_pos:end_pos]) 282 | else: 283 | noise_spans.append([mask_id] + doc[start_pos:end_pos]) 284 | 285 | if getattr(self.args, "remove_source_sentinel", False): 286 | nonnoise_spans.extend(doc[last_end:start_pos]) 287 | else: 288 | nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id]) 289 | 290 | last_end = end_pos 291 | 292 | nonnoise_spans.extend(doc[last_end:]) 293 | noise_spans = sum(noise_spans, []) 294 | 295 | return nonnoise_spans, noise_spans 296 | 297 | def _read_from_files(self, source_file, source_lang): 298 | # data = [] 299 | file_path = os.path.join(self.data_dir, source_file) 300 | 301 | if not os.path.exists(file_path): 302 | print("| file {} not exists".format(file_path), flush=True) 303 | return iter([]) # skip bad file 304 | 305 | with open(file_path, "r", encoding="utf8") as f: 306 | lines = f.read().strip().split("\n") 307 | 308 | doc = [self.dictionary.bos()] 309 | for line in lines: 310 | if line == "": 311 | if self.sample_break_mode == "complete_doc": 312 | # data.append(doc) 313 | yield doc 314 | doc = [self.dictionary.bos()] 315 | continue 316 | 317 | tokenized_line = self.tokenizer.EncodeAsPieces(line) 318 | tokenized_id = [ 319 | self.dictionary.index(token) for token in tokenized_line 320 | ] + [self.dictionary.eos_index] 321 | 322 | if len(tokenized_id) > self.tokens_per_sample: 323 | continue 324 | if len(doc) + len(tokenized_id) > self.tokens_per_sample: 325 | # data.append(doc) 326 | yield doc 327 | doc = [self.dictionary.bos()] 328 | doc.extend(tokenized_id) 329 | 330 | if len(doc) > 1 and len(doc) <= self.tokens_per_sample: 331 | # data.append(doc) 332 | yield doc 333 | 334 | # return data 335 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import collections 5 | from random import Random 6 | from typing import Dict, Iterable, Optional 7 | 8 | import numpy as np 9 | from infinibatch import iterators 10 | 11 | 12 | def apply_to_sample(f, sample): 13 | if hasattr(sample, "__len__") and len(sample) == 0: 14 | return {} 15 | 16 | def _apply(x): 17 | if isinstance(x, np.ndarray): 18 | return f(x) 19 | elif isinstance(x, collections.OrderedDict): 20 | # OrderedDict has attributes that needs to be preserved 21 | od = collections.OrderedDict( 22 | (key, _apply(value)) for key, value in x.items() 23 | ) 24 | od.__dict__ = x.__dict__ 25 | return od 26 | elif isinstance(x, dict): 27 | return {key: _apply(value) for key, value in x.items()} 28 | elif isinstance(x, list): 29 | return [_apply(x) for x in x] 30 | elif isinstance(x, tuple): 31 | return tuple(_apply(x) for x in x) 32 | elif isinstance(x, set): 33 | return {_apply(x) for x in x} 34 | else: 35 | return x 36 | 37 | return _apply(sample) 38 | 39 | 40 | class NativeCheckpointableIterator(iterators.CheckpointableIterator): 41 | def __init__(self, iterable: Iterable): 42 | self._input_iterable = iterable 43 | self.setstate(None) 44 | 45 | def getstate(self) -> Dict: 46 | return {"num_items_yielded": self._num_items_yielded} 47 | 48 | def setstate(self, checkpoint: Optional[Dict]): 49 | self._iterator = iter(self._input_iterable) 50 | 
self._num_items_yielded = ( 51 | iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"]) 52 | if checkpoint is not None 53 | else 0 54 | ) 55 | 56 | def __next__(self): 57 | item = next(self._iterator) 58 | self._num_items_yielded += 1 59 | return item 60 | 61 | def close(self): 62 | pass 63 | 64 | 65 | class WeightIterator(object): 66 | def __init__(self, weights, seed): 67 | self.weights = weights 68 | self.seed = seed 69 | self.control_index = list(range(len(weights))) 70 | self.setstate(None) 71 | 72 | def __iter__(self): 73 | return self 74 | 75 | def getstate(self): 76 | return {"random_state": self._random_state} 77 | 78 | def setstate(self, checkpoint): 79 | self._random_state = checkpoint["random_state"] if checkpoint else None 80 | self._random = ( 81 | None # this will trigger the lazy initialization in self.__next__ 82 | ) 83 | 84 | def __next__(self): 85 | if self._random is None: 86 | self._random = Random(self.seed) 87 | if self._random_state is not None: 88 | self._random.setstate(self._random_state) 89 | idx = self._random.choices(self.control_index, self.weights)[0] 90 | self._random_state = self._random.getstate() 91 | return idx 92 | 93 | def close(self): 94 | pass 95 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/pretraining.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import json 5 | import logging 6 | import os 7 | from argparse import Namespace 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # 11 | # This source code is licensed under the MIT license found in the 12 | # LICENSE file in the root directory of this source tree. 13 | from dataclasses import dataclass, field 14 | 15 | import sentencepiece as spm 16 | from fairseq import utils 17 | from fairseq.data import Dictionary 18 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 19 | from fairseq.tasks import FairseqTask, register_task 20 | from omegaconf import II, MISSING 21 | 22 | from .data.mlm_loader import MLMLoader 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) 27 | SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) 28 | 29 | 30 | @dataclass 31 | class PretrainingConfig(FairseqDataclass): 32 | data: str = field( 33 | default=MISSING, 34 | metadata={ 35 | "help": "colon separated path to data directories list, \ 36 | will be iterated upon during epochs in round-robin manner" 37 | }, 38 | ) 39 | sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( 40 | default="complete", 41 | metadata={ 42 | "help": 'If omitted or "none", fills each sample with tokens-per-sample ' 43 | 'tokens. If set to "complete", splits samples only at the end ' 44 | "of sentence, but may include multiple sentences per sample. " 45 | '"complete_doc" is similar but respects doc boundaries. ' 46 | 'If set to "eos", includes only one sentence per sample.' 
47 | }, 48 | ) 49 | tokens_per_sample: int = field( 50 | default=1024, 51 | metadata={"help": "max number of tokens per sample for LM dataset"}, 52 | ) 53 | mask_prob: float = field( 54 | default=0.15, 55 | metadata={"help": "probability of replacing a token with mask"}, 56 | ) 57 | leave_unmasked_prob: float = field( 58 | default=0.1, 59 | metadata={"help": "probability that a masked token is unmasked"}, 60 | ) 61 | random_token_prob: float = field( 62 | default=0.1, 63 | metadata={"help": "probability of replacing a token with a random token"}, 64 | ) 65 | freq_weighted_replacement: bool = field( 66 | default=False, 67 | metadata={"help": "sample random replacement words based on word frequencies"}, 68 | ) 69 | mask_whole_words: bool = field( 70 | default=False, 71 | metadata={"help": "mask whole words; you may also want to set --bpe"}, 72 | ) 73 | mask_multiple_length: int = field( 74 | default=1, 75 | metadata={"help": "repeat the mask indices multiple times"}, 76 | ) 77 | mask_stdev: float = field( 78 | default=0.0, 79 | metadata={"help": "stdev of the mask length"}, 80 | ) 81 | shorten_method: SHORTEN_METHOD_CHOICES = field( 82 | default="none", 83 | metadata={ 84 | "help": "if not none, shorten sequences that exceed --tokens-per-sample" 85 | }, 86 | ) 87 | shorten_data_split_list: str = field( 88 | default="", 89 | metadata={ 90 | "help": "comma-separated list of dataset splits to apply shortening to, " 91 | 'e.g., "train,valid" (default: all dataset splits)' 92 | }, 93 | ) 94 | seed: int = II("common.seed") 95 | span_length: float = field( 96 | default=3.0, 97 | metadata={"help": "average span length for masking"}, 98 | ) 99 | remove_source_sentinel: bool = field( 100 | default=False, 101 | metadata={"help": "remove the source sentinel for the span corruption task"}, 102 | ) 103 | remove_target_sentinel: bool = field( 104 | default=False, 105 | metadata={"help": "remove the target sentinel for the span corruption task"}, 106 | ) 107 | batch_read_ahead: int = field( 108 | default=100000, 109 | metadata={"help": "batch read ahead size for infinibatch"}, 110 | ) 111 | required_batch_size_multiple: int = II("dataset.required_batch_size_multiple") 112 | spm_model: str = field( 113 | default="", 114 | metadata={"help": "sentencepiece model to tokenize the data"}, 115 | ) 116 | dict_file: str = field( 117 | default="", 118 | metadata={"help": ""}, 119 | ) 120 | 121 | 122 | @register_task("pretraining", dataclass=PretrainingConfig) 123 | class PLMTask(FairseqTask): 124 | def __init__(self, cfg, dictionary, tokenizer): 125 | super().__init__(cfg) 126 | self.cfg = cfg 127 | self.dictionary = dictionary 128 | self.tokenizer = tokenizer 129 | self.seed = cfg.seed 130 | self.mask_idx = dictionary.index("<mask>") 131 | 132 | @classmethod 133 | def setup_task(cls, cfg, **kwargs): 134 | paths = utils.split_paths(cfg.data) 135 | assert len(paths) > 0 136 | if cfg.dict_file != "": 137 | dictionary = Dictionary.load(cfg.dict_file) 138 | else: 139 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 140 | 141 | # add mask token 142 | dictionary.add_symbol("<mask>") 143 | for i in range(100): 144 | dictionary.add_symbol(f"<mask_{i}>") 145 | 146 | dictionary.pad_to_multiple_(cfg.required_batch_size_multiple) 147 | logger.info("dictionary: {} types".format(len(dictionary))) 148 | 149 | # tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=cfg.spm_model)) 150 | tokenizer = spm.SentencePieceProcessor() 151 | tokenizer.Load(cfg.spm_model) 152 | return cls(cfg, dictionary, tokenizer) 153 | 154 | def
load_dataset(self, split, epoch=1, combine=False, **kwargs): 155 | self.datasets[split] = { 156 | "data": json.load(open(f"{self.cfg.data}/json/{split}.json")), 157 | "data_dir": self.cfg.data, 158 | "shuffle": True if split == "train" else False, 159 | } 160 | self.datasets[split] = Namespace(**self.datasets[split]) 161 | 162 | def dataset(self, split): 163 | if split not in self.datasets: 164 | raise KeyError("Dataset not loaded: " + split) 165 | 166 | return self.datasets[split] 167 | 168 | def get_batch_iterator( 169 | self, 170 | dataset, 171 | max_tokens=None, 172 | max_sentences=None, 173 | max_positions=None, 174 | ignore_invalid_inputs=False, 175 | required_batch_size_multiple=1, 176 | seed=1, 177 | num_shards=1, 178 | shard_id=0, 179 | num_workers=0, 180 | epoch=1, 181 | data_buffer_size=0, 182 | disable_iterator_cache=False, 183 | ): 184 | return MLMLoader( 185 | self.cfg, 186 | dataset, 187 | self.dictionary, 188 | self.tokenizer, 189 | max_tokens=max_tokens, 190 | max_sentences=max_sentences, 191 | max_positions=max_positions, 192 | ignore_invalid_inputs=ignore_invalid_inputs, 193 | required_batch_size_multiple=required_batch_size_multiple, 194 | seed=seed, 195 | num_shards=num_shards, 196 | shard_id=shard_id, 197 | ) 198 | 199 | @property 200 | def source_dictionary(self): 201 | return self.dictionary 202 | 203 | @property 204 | def target_dictionary(self): 205 | return self.dictionary 206 | -------------------------------------------------------------------------------- /examples/fairseq/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # flake8: noqa 5 | import models 6 | import tasks 7 | from fairseq_cli.train import cli_main 8 | 9 | if __name__ == "__main__": 10 | cli_main() 11 | -------------------------------------------------------------------------------- /examples/fairseq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/utils/sparse_clip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | import warnings 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm 10 | 11 | 12 | @torch.no_grad() 13 | def clip_grad_norm_( 14 | params, max_norm, moe_expert_count, aggregate_norm_fn=None 15 | ) -> torch.Tensor: 16 | def grad_exists(p): 17 | return p is not None and getattr(p, "grad", None) is not None 18 | 19 | if isinstance(params, torch.Tensor): 20 | params = [params] 21 | params = list(params) 22 | params = list(filter(grad_exists, params)) 23 | grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], [] 24 | denom = math.sqrt(max(dist.get_global_world_size(), moe_expert_count)) 25 | for p in params: 26 | if hasattr(p, "expert"): 27 | expert_grads.append(p.grad.detach() / denom) 28 | elif hasattr(p, "base_expert"): 29 | base_expert_grads.append(p.grad.detach()) 30 | elif hasattr(p, "_is_sharded"): 31 | sharded_grads.append(p.grad.detach()) 32 | else: 33 | grads.append(p.grad.detach()) 34 | if len(grads) == 0: 35 | if len(params) > 0: 
36 | total_norm = params[0].new_tensor(0.0) 37 | else: 38 | total_norm = torch.tensor(0.0) 39 | elif len(grads) == 1: 40 | total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) 41 | else: 42 | if multi_tensor_l2norm_available: 43 | total_norm = multi_tensor_total_norm(grads) 44 | else: 45 | if torch.cuda.is_available(): 46 | warnings.warn( 47 | "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " 48 | "you may get better performance by installing NVIDIA's apex library" 49 | ) 50 | device = torch.cuda.current_device() 51 | elif grads[0].device.type == "xla": 52 | device = grads[0].device 53 | else: 54 | device = torch.device("cpu") 55 | total_norm = torch.norm( 56 | torch.stack( 57 | [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] 58 | ) 59 | ) 60 | 61 | # calculate split_norm and all_reduce with other workers 62 | norms = [total_norm] 63 | for split_grads in [expert_grads, sharded_grads]: 64 | if len(split_grads) == 0: 65 | continue 66 | split_norm = torch.norm( 67 | torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]) 68 | ) 69 | if dist.is_initialized(): 70 | split_norm.pow_(2) 71 | dist.all_reduce(split_norm) 72 | split_norm.sqrt_() 73 | norms.append(split_norm) 74 | if len(norms) > 1: 75 | total_norm = torch.norm(torch.stack(norms)) 76 | 77 | if aggregate_norm_fn is not None: 78 | total_norm = aggregate_norm_fn(total_norm) 79 | 80 | if max_norm > 0: 81 | max_norm = float(max_norm) 82 | clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) 83 | for g in grads + expert_grads + sharded_grads + base_expert_grads: 84 | g.mul_(clip_coef) 85 | return total_norm 86 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | from io import open 5 | 6 | from setuptools import find_packages, setup 7 | 8 | setup( 9 | name="torchscale", 10 | version="0.1.1", 11 | author="TorchScale Team", 12 | author_email="Shuming.Ma@microsoft.com", 13 | description="Transformers at any scale", 14 | long_description=open("README.md", "r", encoding="utf-8").read(), 15 | long_description_content_type="text/markdown", 16 | keywords="Transformers at any scale", 17 | license="MIT", 18 | url="https://github.com/msranlp/torchscale", 19 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 20 | install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"], 21 | python_requires=">=3.8.0", 22 | classifiers=[ 23 | "Programming Language :: Python :: 3", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /tests/test_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import pytest 5 | import torch 6 | 7 | from torchscale.architecture.config import DecoderConfig 8 | from torchscale.architecture.decoder import Decoder 9 | 10 | testcases = [ 11 | {}, 12 | {"vocab_size": 64000}, 13 | {"activation_fn": "relu"}, 14 | {"drop_path_rate": 
0.1}, 15 | {"decoder_normalize_before": False}, 16 | {"no_scale_embedding": False}, 17 | {"layernorm_embedding": True}, 18 | {"rel_pos_buckets": 32, "max_rel_pos": 256}, 19 | {"deepnorm": True, "subln": False, "decoder_normalize_before": False}, 20 | {"bert_init": True}, 21 | {"multiway": True}, 22 | {"share_decoder_input_output_embed": True}, 23 | {"checkpoint_activations": True}, 24 | {"fsdp": True}, 25 | ] 26 | 27 | 28 | @pytest.mark.parametrize("args", testcases) 29 | def test_decoder(args): 30 | config = DecoderConfig(**args) 31 | model = Decoder(config) 32 | prev_output_tokens = torch.ones(2, 10) 33 | token_embeddings = torch.rand(2, 10, config.decoder_embed_dim) 34 | model( 35 | prev_output_tokens=prev_output_tokens, 36 | token_embeddings=token_embeddings, 37 | features_only=True, 38 | ) 39 | -------------------------------------------------------------------------------- /tests/test_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import pytest 5 | import torch 6 | 7 | from torchscale.architecture.config import EncoderConfig 8 | from torchscale.architecture.encoder import Encoder 9 | 10 | testcases = [ 11 | {}, 12 | {"vocab_size": 64000}, 13 | {"activation_fn": "relu"}, 14 | {"drop_path_rate": 0.1}, 15 | {"encoder_normalize_before": False}, 16 | {"no_scale_embedding": False}, 17 | {"layernorm_embedding": True}, 18 | {"rel_pos_buckets": 32, "max_rel_pos": 256}, 19 | {"deepnorm": True, "subln": False, "encoder_normalize_before": False}, 20 | {"bert_init": True}, 21 | {"multiway": True}, 22 | {"share_encoder_input_output_embed": True}, 23 | {"checkpoint_activations": True}, 24 | {"fsdp": True}, 25 | ] 26 | 27 | 28 | @pytest.mark.parametrize("args", testcases) 29 | def test_encoder(args): 30 | config = EncoderConfig(**args) 31 | model = Encoder(config) 32 | token_embeddings = torch.rand(2, 10, config.encoder_embed_dim) 33 | model(src_tokens=None, token_embeddings=token_embeddings) 34 | -------------------------------------------------------------------------------- /tests/test_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import pytest 5 | import torch 6 | 7 | from torchscale.architecture.config import EncoderDecoderConfig 8 | from torchscale.architecture.encoder_decoder import EncoderDecoder 9 | from torchscale.component.embedding import PositionalEmbedding, TextEmbedding 10 | 11 | testcases = [ 12 | {}, 13 | {"vocab_size": 64000}, 14 | {"activation_fn": "relu"}, 15 | {"drop_path_rate": 0.1}, 16 | {"encoder_normalize_before": False, "decoder_normalize_before": False}, 17 | {"no_scale_embedding": False}, 18 | {"layernorm_embedding": True}, 19 | {"rel_pos_buckets": 32, "max_rel_pos": 256}, 20 | { 21 | "deepnorm": True, 22 | "subln": False, 23 | "encoder_normalize_before": False, 24 | "decoder_normalize_before": False, 25 | }, 26 | {"bert_init": True}, 27 | {"multiway": True}, 28 | {"share_decoder_input_output_embed": True}, 29 | {"share_all_embeddings": True}, 30 | {"checkpoint_activations": True}, 31 | {"fsdp": True}, 32 | ] 33 | 34 | 35 | @pytest.mark.parametrize("args", testcases) 36 | def test_decoder(args): 37 | config = EncoderDecoderConfig(**args) 38 | model = EncoderDecoder( 39 | config, 40 | encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim), 41 | 
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim), 42 | encoder_embed_positions=PositionalEmbedding( 43 | config.max_source_positions, config.encoder_embed_dim 44 | ), 45 | decoder_embed_positions=PositionalEmbedding( 46 | config.max_target_positions, config.decoder_embed_dim 47 | ), 48 | ) 49 | 50 | src_tokens = torch.ones(2, 20).long() 51 | prev_output_tokens = torch.ones(2, 10).long() 52 | 53 | model( 54 | src_tokens=src_tokens, 55 | prev_output_tokens=prev_output_tokens, 56 | features_only=True, 57 | ) 58 | -------------------------------------------------------------------------------- /torchscale/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/architecture/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/architecture/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | 5 | class EncoderConfig(object): 6 | def __init__(self, **kwargs): 7 | self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768) 8 | self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12) 9 | self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072) 10 | self.encoder_layers = kwargs.pop("encoder_layers", 12) 11 | self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True) 12 | self.activation_fn = kwargs.pop("activation_fn", "gelu") 13 | self.dropout = kwargs.pop("dropout", 0.0) 14 | self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) 15 | self.attention_dropout = kwargs.pop("attention_dropout", 0.0) 16 | self.activation_dropout = kwargs.pop("activation_dropout", 0.0) 17 | self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) 18 | self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) 19 | self.moe_freq = kwargs.pop("moe_freq", 0) 20 | self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) 21 | self.moe_expert_count = kwargs.pop("moe_expert_count", 0) 22 | self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) 23 | self.moe_eval_capacity_token_fraction = kwargs.pop( 24 | "moe_eval_capacity_token_fraction", 0.25 25 | ) 26 | self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") 27 | self.moe_normalize_gate_prob_before_dropping = kwargs.pop( 28 | "moe_normalize_gate_prob_before_dropping", False 29 | ) 30 | self.use_xmoe = kwargs.pop("use_xmoe", False) 31 | self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) 32 | self.max_rel_pos = kwargs.pop("max_rel_pos", 0) 33 | self.deepnorm = kwargs.pop("deepnorm", False) 34 | self.subln = kwargs.pop("subln", True) 35 | self.bert_init = kwargs.pop("bert_init", False) 36 | self.multiway = kwargs.pop("multiway", False) 37 | self.share_encoder_input_output_embed = kwargs.pop( 38 | "share_encoder_input_output_embed", False 39 | ) 40 | self.max_source_positions = kwargs.pop("max_source_positions", 1024) 41 | self.no_output_layer = kwargs.pop("no_output_layer", False) 42 | # Text 43 | self.vocab_size = kwargs.pop("vocab_size", -1) 44 
| # Vision 45 | self.img_size = kwargs.pop("img_size", 224) 46 | self.patch_size = kwargs.pop("patch_size", 16) 47 | self.in_chans = kwargs.pop("in_chans", 3) 48 | # Fairscale 49 | self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) 50 | self.fsdp = kwargs.pop("fsdp", False) 51 | self.ddp_rank = kwargs.pop("ddp_rank", 0) 52 | 53 | if self.deepnorm: 54 | self.encoder_normalize_before = False 55 | self.subln = False 56 | if self.subln: 57 | self.encoder_normalize_before = True 58 | self.deepnorm = False 59 | if self.use_xmoe: 60 | self.moe_normalize_gate_prob_before_dropping = True 61 | self.moe_second_expert_policy = "random" 62 | assert self.moe_freq > 0 and self.moe_expert_count > 0 63 | 64 | def override(self, args): 65 | for hp in self.__dict__.keys(): 66 | if getattr(args, hp, None) is not None: 67 | self.__dict__[hp] = getattr(args, hp, None) 68 | 69 | 70 | class DecoderConfig(object): 71 | def __init__(self, **kwargs): 72 | self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) 73 | self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12) 74 | self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072) 75 | self.decoder_layers = kwargs.pop("decoder_layers", 12) 76 | self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) 77 | self.activation_fn = kwargs.pop("activation_fn", "gelu") 78 | self.dropout = kwargs.pop("dropout", 0.0) 79 | self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) 80 | self.attention_dropout = kwargs.pop("attention_dropout", 0.0) 81 | self.activation_dropout = kwargs.pop("activation_dropout", 0.0) 82 | self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) 83 | self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) 84 | self.moe_freq = kwargs.pop("moe_freq", 0) 85 | self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) 86 | self.moe_expert_count = kwargs.pop("moe_expert_count", 0) 87 | self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) 88 | self.moe_eval_capacity_token_fraction = kwargs.pop( 89 | "moe_eval_capacity_token_fraction", 0.25 90 | ) 91 | self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") 92 | self.moe_normalize_gate_prob_before_dropping = kwargs.pop( 93 | "moe_normalize_gate_prob_before_dropping", False 94 | ) 95 | self.use_xmoe = kwargs.pop("use_xmoe", False) 96 | self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) 97 | self.block_size = kwargs.pop("block_size", 2048) 98 | self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) 99 | self.max_rel_pos = kwargs.pop("max_rel_pos", 0) 100 | self.deepnorm = kwargs.pop("deepnorm", False) 101 | self.subln = kwargs.pop("subln", True) 102 | self.bert_init = kwargs.pop("bert_init", False) 103 | self.multiway = kwargs.pop("multiway", False) 104 | self.share_decoder_input_output_embed = kwargs.pop( 105 | "share_decoder_input_output_embed", False 106 | ) 107 | self.max_target_positions = kwargs.pop("max_target_positions", 1024) 108 | self.no_output_layer = kwargs.pop("no_output_layer", False) 109 | # Text 110 | self.vocab_size = kwargs.pop("vocab_size", -1) 111 | # Fairscale 112 | self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) 113 | self.fsdp = kwargs.pop("fsdp", False) 114 | self.ddp_rank = kwargs.pop("ddp_rank", 0) 115 | 116 | if self.deepnorm: 117 | self.decoder_normalize_before = False 118 | self.subln = False 119 | if self.subln: 120 | self.decoder_normalize_before = True 121 | self.deepnorm = False 122 | if self.use_xmoe: 123 | 
self.moe_normalize_gate_prob_before_dropping = True 124 | self.moe_second_expert_policy = "random" 125 | assert self.moe_freq > 0 and self.moe_expert_count > 0 126 | 127 | def override(self, args): 128 | for hp in self.__dict__.keys(): 129 | if getattr(args, hp, None) is not None: 130 | self.__dict__[hp] = getattr(args, hp, None) 131 | 132 | 133 | class EncoderDecoderConfig(object): 134 | def __init__(self, **kwargs): 135 | self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768) 136 | self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12) 137 | self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072) 138 | self.encoder_layers = kwargs.pop("encoder_layers", 12) 139 | self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True) 140 | self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) 141 | self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12) 142 | self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072) 143 | self.decoder_layers = kwargs.pop("decoder_layers", 12) 144 | self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) 145 | self.activation_fn = kwargs.pop("activation_fn", "gelu") 146 | self.dropout = kwargs.pop("dropout", 0.0) 147 | self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) 148 | self.attention_dropout = kwargs.pop("attention_dropout", 0.0) 149 | self.activation_dropout = kwargs.pop("activation_dropout", 0.0) 150 | self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) 151 | self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) 152 | self.moe_freq = kwargs.pop("moe_freq", 0) 153 | self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) 154 | self.moe_expert_count = kwargs.pop("moe_expert_count", 0) 155 | self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) 156 | self.moe_eval_capacity_token_fraction = kwargs.pop( 157 | "moe_eval_capacity_token_fraction", 0.25 158 | ) 159 | self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") 160 | self.moe_normalize_gate_prob_before_dropping = kwargs.pop( 161 | "moe_normalize_gate_prob_before_dropping", False 162 | ) 163 | self.use_xmoe = kwargs.pop("use_xmoe", False) 164 | self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) 165 | self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) 166 | self.max_rel_pos = kwargs.pop("max_rel_pos", 0) 167 | self.deepnorm = kwargs.pop("deepnorm", False) 168 | self.subln = kwargs.pop("subln", True) 169 | self.bert_init = kwargs.pop("bert_init", False) 170 | self.multiway = kwargs.pop("multiway", False) 171 | self.share_all_embeddings = kwargs.pop("share_all_embeddings", False) 172 | self.share_decoder_input_output_embed = kwargs.pop( 173 | "share_decoder_input_output_embed", False 174 | ) 175 | self.max_source_positions = kwargs.pop("max_source_positions", 1024) 176 | self.max_target_positions = kwargs.pop("max_target_positions", 1024) 177 | self.no_output_layer = kwargs.pop("no_output_layer", False) 178 | # Text 179 | self.vocab_size = kwargs.pop("vocab_size", -1) 180 | # Fairscale 181 | self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) 182 | self.fsdp = kwargs.pop("fsdp", False) 183 | self.ddp_rank = kwargs.pop("ddp_rank", 0) 184 | 185 | if self.deepnorm: 186 | self.encoder_normalize_before = False 187 | self.decoder_normalize_before = False 188 | self.subln = False 189 | if self.subln: 190 | self.encoder_normalize_before = True 191 | self.decoder_normalize_before = True 192 | 
self.deepnorm = False 193 | if self.use_xmoe: 194 | self.moe_normalize_gate_prob_before_dropping = True 195 | self.moe_second_expert_policy = "random" 196 | assert self.moe_freq > 0 and self.moe_expert_count > 0 197 | 198 | def override(self, args): 199 | for hp in self.__dict__.keys(): 200 | if getattr(args, hp, None) is not None: 201 | self.__dict__[hp] = getattr(args, hp, None) 202 | -------------------------------------------------------------------------------- /torchscale/architecture/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from apex.normalization import FusedLayerNorm as LayerNorm 11 | from fairscale.nn import checkpoint_wrapper, wrap 12 | 13 | from torchscale.architecture.utils import init_bert_params 14 | from torchscale.component.droppath import DropPath 15 | from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts 16 | from torchscale.component.multihead_attention import MultiheadAttention 17 | from torchscale.component.xpos_relative_position import XPos 18 | from torchscale.component.relative_position_bias import RelativePositionBias 19 | from torchscale.component.xmoe.moe_layer import MOELayer 20 | from torchscale.component.xmoe.routing import Top1Gate, Top2Gate 21 | 22 | 23 | class DecoderLayer(nn.Module): 24 | def __init__( 25 | self, 26 | args, 27 | depth, 28 | is_moe_layer=False, 29 | is_encoder_decoder=False, 30 | ): 31 | super().__init__() 32 | self.args = args 33 | self.embed_dim = args.decoder_embed_dim 34 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True) 35 | 36 | if args.drop_path_rate > 0: 37 | drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[ 38 | depth 39 | ] 40 | self.drop_path = DropPath(drop_path_prob) 41 | else: 42 | self.drop_path = None 43 | 44 | self.self_attn = self.build_self_attention(self.embed_dim, args) 45 | 46 | self.normalize_before = args.decoder_normalize_before 47 | 48 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) 49 | 50 | if not is_encoder_decoder: 51 | self.encoder_attn = None 52 | self.encoder_attn_layer_norm = None 53 | else: 54 | self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) 55 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) 56 | 57 | self.is_moe_layer = is_moe_layer 58 | self.ffn_dim = args.decoder_ffn_embed_dim 59 | 60 | if not self.is_moe_layer: 61 | self.ffn = self.build_ffn( 62 | self.embed_dim, 63 | self.args, 64 | ) 65 | else: 66 | if args.moe_top1_expert: 67 | gate = Top1Gate( 68 | self.embed_dim, 69 | args.moe_expert_count, 70 | use_fp32=args.moe_gating_use_fp32, 71 | moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction, 72 | use_xmoe=args.use_xmoe, 73 | ) 74 | else: 75 | gate = Top2Gate( 76 | self.embed_dim, 77 | args.moe_expert_count, 78 | args.moe_gating_use_fp32, 79 | args.moe_second_expert_policy, 80 | args.moe_normalize_gate_prob_before_dropping, 81 | args.moe_eval_capacity_token_fraction, 82 | use_xmoe=args.use_xmoe, 83 | ) 84 | experts = make_experts(args, self.embed_dim, self.ffn_dim) 85 | self.moe_layer = MOELayer(gate, experts, args) 86 | 87 | self.final_layer_norm = LayerNorm(self.embed_dim) 88 | 89 | if args.deepnorm: 90 | if is_encoder_decoder: 91 | self.alpha = math.pow(3.0 * args.decoder_layers, 0.25) 92 
| else: 93 | self.alpha = math.pow(2.0 * args.decoder_layers, 0.25) 94 | else: 95 | self.alpha = 1.0 96 | 97 | def build_ffn(self, embed_dim, args): 98 | return FeedForwardNetwork( 99 | embed_dim, 100 | self.ffn_dim, 101 | args.activation_fn, 102 | args.dropout, 103 | args.activation_dropout, 104 | args.subln, 105 | ) 106 | 107 | def build_self_attention(self, embed_dim, args): 108 | return MultiheadAttention( 109 | args, 110 | embed_dim, 111 | args.decoder_attention_heads, 112 | dropout=args.attention_dropout, 113 | self_attention=True, 114 | encoder_decoder_attention=False, 115 | subln=args.subln, 116 | ) 117 | 118 | def build_encoder_attention(self, embed_dim, args): 119 | return MultiheadAttention( 120 | args, 121 | embed_dim, 122 | args.decoder_attention_heads, 123 | dropout=args.attention_dropout, 124 | self_attention=False, 125 | encoder_decoder_attention=True, 126 | subln=args.subln, 127 | ) 128 | 129 | def residual_connection(self, x, residual): 130 | return residual * self.alpha + x 131 | 132 | def forward( 133 | self, 134 | x, 135 | encoder_out=None, 136 | encoder_padding_mask=None, 137 | incremental_state=None, 138 | self_attn_mask=None, 139 | self_attn_padding_mask=None, 140 | self_attn_rel_pos=None, 141 | cross_attn_rel_pos=None, 142 | ): 143 | residual = x 144 | if self.normalize_before: 145 | x = self.self_attn_layer_norm(x) 146 | 147 | x, attn = self.self_attn( 148 | query=x, 149 | key=x, 150 | value=x, 151 | key_padding_mask=self_attn_padding_mask, 152 | incremental_state=incremental_state, 153 | attn_mask=self_attn_mask, 154 | rel_pos=self_attn_rel_pos, 155 | ) 156 | x = self.dropout_module(x) 157 | 158 | if self.drop_path is not None: 159 | x = self.drop_path(x) 160 | 161 | x = self.residual_connection(x, residual) 162 | if not self.normalize_before: 163 | x = self.self_attn_layer_norm(x) 164 | 165 | if self.encoder_attn is not None and encoder_out is not None: 166 | residual = x 167 | if self.normalize_before: 168 | x = self.encoder_attn_layer_norm(x) 169 | 170 | x, attn = self.encoder_attn( 171 | query=x, 172 | key=encoder_out, 173 | value=encoder_out, 174 | key_padding_mask=encoder_padding_mask, 175 | incremental_state=None, 176 | rel_pos=cross_attn_rel_pos, 177 | ) 178 | x = self.dropout_module(x) 179 | 180 | if self.drop_path is not None: 181 | x = self.drop_path(x) 182 | 183 | x = self.residual_connection(x, residual) 184 | if not self.normalize_before: 185 | x = self.encoder_attn_layer_norm(x) 186 | 187 | residual = x 188 | if self.normalize_before: 189 | x = self.final_layer_norm(x) 190 | if not self.is_moe_layer: 191 | x = self.ffn(x) 192 | l_aux = None 193 | else: 194 | x = x.transpose(0, 1) 195 | x, l_aux = self.moe_layer(x) 196 | x = x.transpose(0, 1) 197 | 198 | if self.drop_path is not None: 199 | x = self.drop_path(x) 200 | 201 | x = self.residual_connection(x, residual) 202 | if not self.normalize_before: 203 | x = self.final_layer_norm(x) 204 | 205 | return x, attn, None, l_aux 206 | 207 | 208 | class Decoder(nn.Module): 209 | def __init__( 210 | self, 211 | args, 212 | embed_tokens=None, 213 | embed_positions=None, 214 | output_projection=None, 215 | is_encoder_decoder=False, 216 | **kwargs 217 | ): 218 | super().__init__(**kwargs) 219 | self.args = args 220 | 221 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True) 222 | 223 | embed_dim = args.decoder_embed_dim 224 | self.embed_dim = embed_dim 225 | self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) 226 | 227 | self.embed_tokens = embed_tokens 228 | 
self.embed_positions = embed_positions 229 | 230 | if ( 231 | output_projection is None 232 | and not args.no_output_layer 233 | and args.vocab_size > 0 234 | ): 235 | self.output_projection = self.build_output_projection(args) 236 | else: 237 | self.output_projection = output_projection 238 | 239 | if args.layernorm_embedding: 240 | self.layernorm_embedding = LayerNorm(embed_dim) 241 | else: 242 | self.layernorm_embedding = None 243 | 244 | self.layers = nn.ModuleList([]) 245 | 246 | moe_freq = args.moe_freq 247 | for i in range(args.decoder_layers): 248 | is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 249 | self.layers.append( 250 | self.build_decoder_layer( 251 | args, 252 | depth=i, 253 | is_moe_layer=is_moe_layer, 254 | is_encoder_decoder=is_encoder_decoder, 255 | ) 256 | ) 257 | 258 | self.num_layers = len(self.layers) 259 | 260 | if args.decoder_normalize_before: 261 | self.layer_norm = LayerNorm(embed_dim) 262 | else: 263 | self.layer_norm = None 264 | 265 | self.output_projection = output_projection 266 | 267 | self.block_size = args.block_size 268 | self.half_block_size = self.block_size // 2 269 | 270 | self.self_attn_xpos = None 271 | self.cross_attn_xpos = None 272 | self.self_attn_relative_position = None 273 | self.cross_attn_relative_position = None 274 | if args.xpos_rel_pos: 275 | self.self_attn_xpos = XPos( 276 | args.decoder_embed_dim // args.decoder_attention_heads 277 | ) 278 | if is_encoder_decoder: 279 | self.cross_attn_xpos = XPos( 280 | args.decoder_embed_dim // args.decoder_attention_heads 281 | ) 282 | elif args.rel_pos_buckets > 0 and args.max_rel_pos > 0: 283 | self.self_attn_relative_position = RelativePositionBias( 284 | num_buckets=args.rel_pos_buckets, 285 | max_distance=args.max_rel_pos, 286 | n_heads=args.decoder_attention_heads, 287 | ) 288 | if is_encoder_decoder: 289 | self.cross_attn_relative_position = RelativePositionBias( 290 | num_buckets=args.rel_pos_buckets, 291 | max_distance=args.max_rel_pos, 292 | n_heads=args.decoder_attention_heads, 293 | ) 294 | 295 | if args.bert_init: 296 | self.apply(init_bert_params) 297 | 298 | if args.deepnorm: 299 | if is_encoder_decoder: 300 | init_scale = math.pow(12.0 * args.decoder_layers, 0.25) 301 | else: 302 | init_scale = math.pow(8.0 * args.decoder_layers, 0.25) 303 | for name, p in self.named_parameters(): 304 | if ( 305 | "fc1" in name 306 | or "fc2" in name 307 | or "out_proj" in name 308 | or "v_proj" in name 309 | ): 310 | p.data.div_(init_scale) 311 | 312 | if args.subln: 313 | if is_encoder_decoder: 314 | init_scale = math.sqrt(math.log(args.decoder_layers * 3)) 315 | else: 316 | init_scale = math.sqrt(math.log(args.decoder_layers * 2)) 317 | for name, p in self.named_parameters(): 318 | if "encoder_attn" in name: 319 | continue 320 | if ( 321 | "fc1" in name 322 | or "fc2" in name 323 | or "out_proj" in name 324 | or "v_proj" in name 325 | ): 326 | p.data.mul_(init_scale) 327 | 328 | def build_output_projection( 329 | self, 330 | args, 331 | ): 332 | if args.share_decoder_input_output_embed: 333 | output_projection = torch.nn.Linear( 334 | self.embed_tokens.weight.shape[1], 335 | self.embed_tokens.weight.shape[0], 336 | bias=False, 337 | ) 338 | output_projection.weight = self.embed_tokens.weight 339 | else: 340 | output_projection = torch.nn.Linear( 341 | args.decoder_embed_dim, args.vocab_size, bias=False 342 | ) 343 | torch.nn.init.normal_( 344 | output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5 345 | ) 346 | return output_projection 347 | 348 | def 
build_decoder_layer( 349 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False 350 | ): 351 | layer = DecoderLayer( 352 | args, 353 | depth, 354 | is_moe_layer=is_moe_layer, 355 | is_encoder_decoder=is_encoder_decoder, 356 | ) 357 | if args.checkpoint_activations: 358 | layer = checkpoint_wrapper(layer) 359 | if args.fsdp: 360 | layer = wrap(layer) 361 | return layer 362 | 363 | def forward_embedding( 364 | self, 365 | tokens, 366 | token_embedding=None, 367 | incremental_state=None, 368 | ): 369 | positions = None 370 | if self.embed_positions is not None: 371 | positions = self.embed_positions( 372 | tokens, incremental_state=incremental_state 373 | ) 374 | 375 | if incremental_state is not None: 376 | tokens = tokens[:, -1:] 377 | if positions is not None: 378 | positions = positions[:, -1:] 379 | 380 | if token_embedding is None: 381 | token_embedding = self.embed_tokens(tokens) 382 | 383 | x = embed = self.embed_scale * token_embedding 384 | 385 | if positions is not None: 386 | x += positions 387 | 388 | if self.layernorm_embedding is not None: 389 | x = self.layernorm_embedding(x) 390 | 391 | x = self.dropout_module(x) 392 | 393 | return x, embed 394 | 395 | def forward( 396 | self, 397 | prev_output_tokens, 398 | self_attn_padding_mask=None, 399 | encoder_out=None, 400 | incremental_state=None, 401 | features_only=False, 402 | return_all_hiddens=False, 403 | token_embeddings=None, 404 | **kwargs 405 | ): 406 | if self.block_size > 0 and prev_output_tokens.shape[1] > self.block_size and incremental_state is None: # padding to complete block 407 | activate_block = True 408 | src_length = prev_output_tokens.shape[1] 409 | pad_length = (src_length + self.half_block_size - 1) // self.half_block_size * self.half_block_size 410 | align_pad_length = pad_length - src_length 411 | if self_attn_padding_mask is None: 412 | self_attn_padding_mask = torch.zeros_like(prev_output_tokens) 413 | 414 | prev_output_tokens = F.pad(prev_output_tokens, (0, align_pad_length)) 415 | self_attn_padding_mask = F.pad(self_attn_padding_mask, (self.half_block_size, align_pad_length), value=1).unfold(1, self.block_size, self.half_block_size).transpose(0, 1).reshape(-1, self.block_size) 416 | else: 417 | activate_block = False 418 | # embed tokens and positions 419 | x, _ = self.forward_embedding( 420 | prev_output_tokens, token_embeddings, incremental_state 421 | ) 422 | x = x.transpose(0, 1) 423 | 424 | # relative position 425 | self_attn_rel_pos_bias = None 426 | slen = prev_output_tokens.size(1) 427 | if self.self_attn_xpos is not None: 428 | if activate_block: 429 | self_attn_rel_pos_bias = self.self_attn_xpos(self.block_size) 430 | else: 431 | rel_pos_len = slen if incremental_state is None else (incremental_state[0]["prev_key"].shape[2] + 1) 432 | self_attn_rel_pos_bias = self.self_attn_xpos(rel_pos_len) 433 | elif self.self_attn_relative_position is not None: 434 | self_attn_rel_pos_bias = self.self_attn_relative_position( 435 | batch_size=x.size(1), qlen=slen, klen=slen 436 | ) 437 | if incremental_state is not None: 438 | self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :] 439 | 440 | cross_attn_rel_pos_bias = None 441 | if self.cross_attn_xpos is not None: 442 | cross_attn_rel_pos_bias = self.cross_attn_xpos(slen + encoder_out["encoder_out"].size(0)) 443 | elif self.cross_attn_relative_position is not None: 444 | cross_attn_rel_pos_bias = self.cross_attn_relative_position( 445 | batch_size=x.size(1), 446 | qlen=slen, 447 | klen=encoder_out["encoder_out"].size(0), 448 | ) 449 | if 
incremental_state is not None: 450 | cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :] 451 | 452 | # decoder layers 453 | inner_states = [x] 454 | 455 | if encoder_out is None: 456 | l_aux = [] 457 | else: 458 | l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else [] 459 | 460 | for idx, layer in enumerate(self.layers): 461 | if incremental_state is None: 462 | if activate_block: 463 | self_attn_mask = torch.triu( 464 | torch.zeros([self.half_block_size, self.block_size]) 465 | .float() 466 | .fill_(float("-inf")) 467 | .type_as(x), 468 | self.half_block_size + 1, 469 | ) 470 | else: 471 | self_attn_mask = torch.triu( 472 | torch.zeros([x.size(0), x.size(0)]) 473 | .float() 474 | .fill_(float("-inf")) 475 | .type_as(x), 476 | 1, 477 | ) 478 | else: 479 | self_attn_mask = None 480 | if idx not in incremental_state: 481 | incremental_state[idx] = {} 482 | 483 | x, layer_attn, _, l_aux_i = layer( 484 | x, 485 | encoder_out["encoder_out"] if encoder_out is not None else None, 486 | encoder_out["encoder_padding_mask"] 487 | if encoder_out is not None 488 | else None, 489 | incremental_state[idx] if incremental_state is not None else None, 490 | self_attn_mask=self_attn_mask, 491 | self_attn_padding_mask=self_attn_padding_mask, 492 | self_attn_rel_pos=self_attn_rel_pos_bias, 493 | cross_attn_rel_pos=cross_attn_rel_pos_bias, 494 | ) 495 | l_aux.append(l_aux_i) 496 | inner_states.append(x) 497 | if self.block_size > 0 and incremental_state is not None: 498 | if incremental_state[idx]["prev_key"].shape[2] > self.block_size: # Window Attention is implemented here 499 | incremental_state[idx]["prev_key"] = incremental_state[idx]["prev_key"][:, :, -self.block_size:] 500 | incremental_state[idx]["prev_value"] = incremental_state[idx]["prev_value"][:, :, -self.block_size:] 501 | 502 | 503 | if self.layer_norm is not None: 504 | x = self.layer_norm(x) 505 | 506 | x = x.transpose(0, 1) 507 | if activate_block: 508 | x = x[:, :src_length] 509 | 510 | if not features_only: 511 | x = self.output_layer(x) 512 | 513 | return x, { 514 | "inner_states": inner_states, 515 | "l_aux": l_aux, 516 | "attn": None, 517 | } 518 | 519 | def output_layer(self, features): 520 | return self.output_projection(features) 521 | -------------------------------------------------------------------------------- /torchscale/architecture/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from apex.normalization import FusedLayerNorm as LayerNorm 10 | from fairscale.nn import checkpoint_wrapper, wrap 11 | 12 | from torchscale.architecture.utils import init_bert_params 13 | from torchscale.component.droppath import DropPath 14 | from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts 15 | from torchscale.component.multihead_attention import MultiheadAttention 16 | from torchscale.component.multiway_network import MultiwayWrapper, set_split_position 17 | from torchscale.component.relative_position_bias import RelativePositionBias 18 | from torchscale.component.xmoe.moe_layer import MOELayer 19 | from torchscale.component.xmoe.routing import Top1Gate, Top2Gate 20 | 21 | 22 | class EncoderLayer(nn.Module): 23 | def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False): 24 | super().__init__() 25 | self.args = args 26 | self.embed_dim = 
args.encoder_embed_dim 27 | self.self_attn = self.build_self_attention(self.embed_dim, args) 28 | self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim)) 29 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True) 30 | 31 | if args.drop_path_rate > 0: 32 | drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[ 33 | depth 34 | ] 35 | self.drop_path = DropPath(drop_path_prob) 36 | else: 37 | self.drop_path = None 38 | 39 | self.normalize_before = args.encoder_normalize_before 40 | self.is_moe_layer = is_moe_layer 41 | self.ffn_dim = args.encoder_ffn_embed_dim 42 | 43 | if not self.is_moe_layer: 44 | self.ffn = MultiwayWrapper( 45 | args, 46 | self.build_ffn( 47 | self.embed_dim, 48 | self.args, 49 | ), 50 | ) 51 | else: 52 | assert not self.args.multiway 53 | if args.moe_top1_expert: 54 | gate = Top1Gate( 55 | self.embed_dim, 56 | args.moe_expert_count, 57 | use_fp32=args.moe_gating_use_fp32, 58 | moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction, 59 | use_xmoe=args.use_xmoe, 60 | ) 61 | else: 62 | gate = Top2Gate( 63 | self.embed_dim, 64 | args.moe_expert_count, 65 | args.moe_gating_use_fp32, 66 | args.moe_second_expert_policy, 67 | args.moe_normalize_gate_prob_before_dropping, 68 | args.moe_eval_capacity_token_fraction, 69 | use_xmoe=args.use_xmoe, 70 | ) 71 | experts = make_experts(args, self.embed_dim, self.ffn_dim) 72 | self.moe_layer = MOELayer(gate, experts, args) 73 | self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim)) 74 | 75 | if args.deepnorm: 76 | if is_encoder_decoder: 77 | self.alpha = ( 78 | math.pow( 79 | math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625 80 | ) 81 | * 0.81 82 | ) 83 | else: 84 | self.alpha = math.pow(2.0 * args.encoder_layers, 0.25) 85 | else: 86 | self.alpha = 1.0 87 | 88 | def build_ffn(self, embed_dim, args): 89 | return FeedForwardNetwork( 90 | embed_dim, 91 | self.ffn_dim, 92 | args.activation_fn, 93 | args.dropout, 94 | args.activation_dropout, 95 | args.subln, 96 | ) 97 | 98 | def build_self_attention(self, embed_dim, args): 99 | return MultiheadAttention( 100 | args, 101 | embed_dim, 102 | args.encoder_attention_heads, 103 | dropout=args.attention_dropout, 104 | self_attention=True, 105 | encoder_decoder_attention=False, 106 | subln=args.subln, 107 | ) 108 | 109 | def residual_connection(self, x, residual): 110 | return residual * self.alpha + x 111 | 112 | def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None): 113 | if attn_mask is not None: 114 | attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) 115 | 116 | residual = x 117 | if self.normalize_before: 118 | x = self.self_attn_layer_norm(x) 119 | x, _ = self.self_attn( 120 | query=x, 121 | key=x, 122 | value=x, 123 | key_padding_mask=encoder_padding_mask, 124 | attn_mask=attn_mask, 125 | rel_pos=rel_pos, 126 | ) 127 | x = self.dropout_module(x) 128 | 129 | if self.drop_path is not None: 130 | x = self.drop_path(x) 131 | 132 | x = self.residual_connection(x, residual) 133 | if not self.normalize_before: 134 | x = self.self_attn_layer_norm(x) 135 | 136 | residual = x 137 | if self.normalize_before: 138 | x = self.final_layer_norm(x) 139 | if not self.is_moe_layer: 140 | x = self.ffn(x) 141 | l_aux = None 142 | else: 143 | x = x.transpose(0, 1) 144 | x, l_aux = self.moe_layer(x) 145 | x = x.transpose(0, 1) 146 | 147 | if self.drop_path is not None: 148 | x = self.drop_path(x) 149 | 150 | x = self.residual_connection(x, residual) 151 | if not 
self.normalize_before: 152 | x = self.final_layer_norm(x) 153 | return x, l_aux 154 | 155 | 156 | class Encoder(nn.Module): 157 | def __init__( 158 | self, 159 | args, 160 | embed_tokens=None, 161 | embed_positions=None, 162 | output_projection=None, 163 | is_encoder_decoder=False, 164 | **kwargs 165 | ): 166 | self.args = args 167 | super().__init__(**kwargs) 168 | 169 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True) 170 | 171 | embed_dim = args.encoder_embed_dim 172 | self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) 173 | 174 | self.embed_tokens = embed_tokens 175 | self.embed_positions = embed_positions 176 | 177 | if ( 178 | output_projection is None 179 | and not is_encoder_decoder 180 | and not args.no_output_layer 181 | and args.vocab_size > 0 182 | ): 183 | self.output_projection = self.build_output_projection(args) 184 | else: 185 | self.output_projection = output_projection 186 | 187 | if args.layernorm_embedding: 188 | self.layernorm_embedding = MultiwayWrapper( 189 | args, LayerNorm(embed_dim), dim=1 190 | ) 191 | else: 192 | self.layernorm_embedding = None 193 | 194 | self.layers = nn.ModuleList([]) 195 | 196 | moe_freq = args.moe_freq 197 | for i in range(args.encoder_layers): 198 | is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 199 | self.layers.append( 200 | self.build_encoder_layer( 201 | args, 202 | depth=i, 203 | is_moe_layer=is_moe_layer, 204 | is_encoder_decoder=is_encoder_decoder, 205 | ) 206 | ) 207 | self.num_layers = len(self.layers) 208 | 209 | if args.encoder_normalize_before: 210 | self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim)) 211 | else: 212 | self.layer_norm = None 213 | 214 | if args.rel_pos_buckets > 0 and args.max_rel_pos > 0: 215 | self.relative_position = RelativePositionBias( 216 | num_buckets=args.rel_pos_buckets, 217 | max_distance=args.max_rel_pos, 218 | n_heads=args.encoder_attention_heads, 219 | ) 220 | else: 221 | self.relative_position = None 222 | 223 | if args.bert_init: 224 | self.apply(init_bert_params) 225 | 226 | if args.deepnorm: 227 | if is_encoder_decoder: 228 | init_scale = ( 229 | math.pow( 230 | math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625 231 | ) 232 | / 1.15 233 | ) 234 | else: 235 | init_scale = math.pow(8.0 * args.encoder_layers, 0.25) 236 | for name, p in self.named_parameters(): 237 | if ( 238 | "fc1" in name 239 | or "fc2" in name 240 | or "out_proj" in name 241 | or "v_proj" in name 242 | ): 243 | p.data.div_(init_scale) 244 | 245 | if args.subln: 246 | if is_encoder_decoder: 247 | init_scale = math.sqrt( 248 | math.log(3 * args.decoder_layers) 249 | * math.log(2 * args.encoder_layers) 250 | / 3 251 | ) 252 | else: 253 | init_scale = math.sqrt(math.log(args.encoder_layers * 2)) 254 | for name, p in self.named_parameters(): 255 | if ( 256 | "fc1" in name 257 | or "fc2" in name 258 | or "out_proj" in name 259 | or "v_proj" in name 260 | ): 261 | p.data.mul_(init_scale) 262 | 263 | def build_output_projection( 264 | self, 265 | args, 266 | ): 267 | if args.share_encoder_input_output_embed: 268 | assert args.encoder_embedding_type == "language" 269 | output_projection = torch.nn.Linear( 270 | self.embed_tokens.weight.shape[1], 271 | self.embed_tokens.weight.shape[0], 272 | bias=False, 273 | ) 274 | output_projection.weight = self.embed_tokens.weight 275 | else: 276 | output_projection = torch.nn.Linear( 277 | args.encoder_embed_dim, args.vocab_size, bias=False 278 | ) 279 | torch.nn.init.normal_( 280 | output_projection.weight, mean=0, 
std=args.encoder_embed_dim**-0.5 281 | ) 282 | return output_projection 283 | 284 | def build_encoder_layer( 285 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False 286 | ): 287 | layer = EncoderLayer( 288 | args, 289 | depth, 290 | is_moe_layer=is_moe_layer, 291 | is_encoder_decoder=is_encoder_decoder, 292 | ) 293 | if args.checkpoint_activations: 294 | layer = checkpoint_wrapper(layer) 295 | if args.fsdp: 296 | layer = wrap(layer) 297 | return layer 298 | 299 | def forward_embedding( 300 | self, 301 | src_tokens, 302 | token_embedding=None, 303 | ): 304 | if token_embedding is None: 305 | token_embedding = self.embed_tokens(src_tokens) 306 | x = embed = self.embed_scale * token_embedding 307 | if self.embed_positions is not None: 308 | if src_tokens is not None: 309 | x = embed + self.embed_positions(src_tokens) 310 | else: 311 | x = embed + self.embed_positions(x) 312 | if self.layernorm_embedding is not None: 313 | x = self.layernorm_embedding(x) 314 | x = self.dropout_module(x) 315 | return x, embed 316 | 317 | def forward( 318 | self, 319 | src_tokens, 320 | encoder_padding_mask=None, 321 | return_all_hiddens=False, 322 | token_embeddings=None, 323 | multiway_split_position=None, 324 | features_only=False, 325 | **kwargs 326 | ): 327 | assert src_tokens is not None or token_embeddings is not None 328 | 329 | if encoder_padding_mask is None: 330 | if src_tokens is not None: 331 | encoder_padding_mask = torch.zeros_like( 332 | src_tokens, device=src_tokens.device 333 | ).bool() 334 | else: 335 | encoder_padding_mask = torch.zeros( 336 | [token_embeddings.size(0), token_embeddings.size(1)], 337 | device=token_embeddings.device, 338 | ).bool() 339 | 340 | if multiway_split_position is not None: 341 | assert self.args.multiway 342 | self.apply(set_split_position(multiway_split_position)) 343 | 344 | x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) 345 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) 346 | 347 | x = x.transpose(0, 1) 348 | 349 | encoder_states = [] 350 | 351 | if return_all_hiddens: 352 | encoder_states.append(x) 353 | 354 | rel_pos_bias = None 355 | if self.relative_position is not None: 356 | rel_pos_bias = self.relative_position( 357 | batch_size=x.size(1), qlen=x.size(0), klen=x.size(0) 358 | ) 359 | 360 | l_aux = [] 361 | for layer in self.layers: 362 | x, l_aux_i = layer( 363 | x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias 364 | ) 365 | if return_all_hiddens: 366 | assert encoder_states is not None 367 | encoder_states.append(x) 368 | l_aux.append(l_aux_i) 369 | 370 | if self.layer_norm is not None: 371 | x = self.layer_norm(x) 372 | 373 | if not features_only and self.output_projection is not None: 374 | x = self.output_projection(x) 375 | 376 | return { 377 | "encoder_out": x, 378 | "encoder_embedding": encoder_embedding, 379 | "encoder_padding_mask": encoder_padding_mask, 380 | "encoder_states": encoder_states, 381 | "l_aux": l_aux, 382 | } 383 | -------------------------------------------------------------------------------- /torchscale/architecture/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch.nn as nn 5 | 6 | from torchscale.architecture.decoder import Decoder 7 | from torchscale.architecture.encoder import Encoder 8 | 9 | 10 | class EncoderDecoder(nn.Module): 11 | def __init__( 12 | self, 13 | args, 14 | 
encoder_embed_tokens=None, 15 | encoder_embed_positions=None, 16 | decoder_embed_tokens=None, 17 | decoder_embed_positions=None, 18 | output_projection=None, 19 | **kwargs 20 | ): 21 | super().__init__() 22 | self.args = args 23 | if args.share_all_embeddings: 24 | args.share_decoder_input_output_embed = True 25 | 26 | self.encoder = Encoder( 27 | args, 28 | encoder_embed_tokens, 29 | encoder_embed_positions, 30 | is_encoder_decoder=True, 31 | **kwargs 32 | ) 33 | 34 | if args.share_all_embeddings and decoder_embed_tokens is None: 35 | decoder_embed_tokens = self.encoder.embed_tokens 36 | 37 | self.decoder = Decoder( 38 | args, 39 | decoder_embed_tokens, 40 | decoder_embed_positions, 41 | output_projection, 42 | is_encoder_decoder=True, 43 | **kwargs 44 | ) 45 | 46 | def forward( 47 | self, 48 | src_tokens, 49 | prev_output_tokens, 50 | return_all_hiddens=False, 51 | features_only=False, 52 | **kwargs 53 | ): 54 | encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens) 55 | decoder_out = self.decoder( 56 | prev_output_tokens, 57 | encoder_out=encoder_out, 58 | features_only=features_only, 59 | return_all_hiddens=return_all_hiddens, 60 | ) 61 | return decoder_out 62 | -------------------------------------------------------------------------------- /torchscale/architecture/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch.nn as nn 5 | 6 | from torchscale.component.multihead_attention import MultiheadAttention 7 | from torchscale.component.multiway_network import MultiwayNetwork 8 | 9 | 10 | def init_bert_params(module): 11 | def normal_(data): 12 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 13 | 14 | if isinstance(module, nn.Linear): 15 | normal_(module.weight.data) 16 | if module.bias is not None: 17 | module.bias.data.zero_() 18 | if isinstance(module, nn.Embedding): 19 | normal_(module.weight.data) 20 | if module.padding_idx is not None: 21 | module.weight.data[module.padding_idx].zero_() 22 | if isinstance(module, MultiheadAttention): 23 | if isinstance(module.q_proj, MultiwayNetwork): 24 | normal_(module.q_proj.A.weight.data) 25 | normal_(module.q_proj.B.weight.data) 26 | normal_(module.k_proj.A.weight.data) 27 | normal_(module.k_proj.B.weight.data) 28 | normal_(module.v_proj.A.weight.data) 29 | normal_(module.v_proj.B.weight.data) 30 | else: 31 | normal_(module.q_proj.weight.data) 32 | normal_(module.k_proj.weight.data) 33 | normal_(module.v_proj.weight.data) 34 | -------------------------------------------------------------------------------- /torchscale/component/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/component/droppath.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch.nn as nn 5 | from timm.models.layers import drop_path 6 | 7 | 8 | class DropPath(nn.Module): 9 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 10 | 11 | def __init__(self, drop_prob=None): 12 | super(DropPath, self).__init__() 13 | self.drop_prob = drop_prob 14 | 15 | def 
forward(self, x): 16 | return drop_path(x, self.drop_prob, self.training) 17 | 18 | def extra_repr(self): 19 | return "p={}".format(self.drop_prob) 20 | -------------------------------------------------------------------------------- /torchscale/component/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class VisionLanguageEmbedding(nn.Module): 10 | def __init__(self, text_embed, vision_embed): 11 | super().__init__() 12 | self.text_embed = text_embed 13 | self.vision_embed = vision_embed 14 | 15 | def forward(self, textual_tokens, visual_tokens, **kwargs): 16 | if textual_tokens is None: 17 | return self.vision_embed(visual_tokens) 18 | 19 | if visual_tokens is None: 20 | return self.text_embed(textual_tokens) 21 | 22 | x1 = self.vision_embed(visual_tokens) 23 | x2 = self.text_embed(textual_tokens) 24 | 25 | return torch.cat([x1, x2], dim=1) 26 | 27 | 28 | class VisionEmbedding(nn.Module): 29 | """Image to Patch Embedding""" 30 | 31 | def __init__( 32 | self, 33 | img_size=224, 34 | patch_size=16, 35 | in_chans=3, 36 | embed_dim=768, 37 | contain_mask_token=False, 38 | prepend_cls_token=False, 39 | ): 40 | super().__init__() 41 | img_size = (img_size, img_size) 42 | patch_size = (patch_size, patch_size) 43 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 44 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 45 | self.img_size = img_size 46 | self.patch_size = patch_size 47 | self.num_patches = num_patches 48 | 49 | self.proj = nn.Conv2d( 50 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 51 | ) 52 | 53 | if contain_mask_token: 54 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 55 | else: 56 | self.mask_token = None 57 | 58 | if prepend_cls_token: 59 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 60 | else: 61 | self.cls_token = None 62 | 63 | def forward(self, x, masked_position=None, **kwargs): 64 | B, C, H, W = x.shape 65 | assert ( 66 | H == self.img_size[0] and W == self.img_size[1] 67 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 68 | x = self.proj(x).flatten(2).transpose(1, 2) 69 | 70 | batch_size, seq_len, _ = x.size() 71 | 72 | if masked_position is not None: 73 | assert self.mask_token is not None 74 | mask_token = self.mask_token.expand(batch_size, seq_len, -1) 75 | w = masked_position.unsqueeze(-1).type_as(mask_token) 76 | x = x * (1 - w) + mask_token * w 77 | 78 | if self.cls_token is not None: 79 | cls_tokens = self.cls_token.expand( 80 | batch_size, -1, -1 81 | ) # stole cls_tokens impl from Phil Wang, thanks 82 | x = torch.cat((cls_tokens, x), dim=1) 83 | 84 | return x 85 | 86 | 87 | class TextEmbedding(nn.Embedding): 88 | def reset_parameters(self): 89 | nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5) 90 | self._fill_padding_idx_with_zero() 91 | 92 | 93 | class PositionalEmbedding(nn.Embedding): 94 | def forward( 95 | self, 96 | x, 97 | positions=None, 98 | **kwargs, 99 | ): 100 | if positions is None: 101 | # being consistent with Fairseq, which starts from 2. 
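# (editorial comment, hedged: fairseq dictionaries typically reserve low indices for special
# symbols, with padding_idx = 1, so the first real token gets position padding_idx + 1 = 2 —
# which is why the arange below starts at 2 rather than 0.)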
102 | positions = ( 103 | torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0) 104 | ) 105 | return F.embedding( 106 | positions, 107 | self.weight, 108 | self.padding_idx, 109 | self.max_norm, 110 | self.norm_type, 111 | self.scale_grad_by_freq, 112 | self.sparse, 113 | ) 114 | -------------------------------------------------------------------------------- /torchscale/component/feedforward_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from apex.normalization import FusedLayerNorm as LayerNorm 8 | 9 | 10 | class set_torch_seed(object): 11 | def __init__(self, seed): 12 | assert isinstance(seed, int) 13 | self.rng_state = self.get_rng_state() 14 | 15 | torch.manual_seed(seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed(seed) 18 | 19 | def get_rng_state(self): 20 | state = {"torch_rng_state": torch.get_rng_state()} 21 | if torch.cuda.is_available(): 22 | state["cuda_rng_state"] = torch.cuda.get_rng_state() 23 | return state 24 | 25 | def set_rng_state(self, state): 26 | torch.set_rng_state(state["torch_rng_state"]) 27 | if torch.cuda.is_available(): 28 | torch.cuda.set_rng_state(state["cuda_rng_state"]) 29 | 30 | def __enter__(self): 31 | return self 32 | 33 | def __exit__(self, *exc): 34 | self.set_rng_state(self.rng_state) 35 | 36 | 37 | def make_experts(args, embed_dim, expert_ffn_dim): 38 | world_size = ( 39 | 1 40 | if not torch.distributed.is_initialized() 41 | else torch.distributed.get_world_size() 42 | ) 43 | expert_list = [] 44 | ddp_rank = args.ddp_rank 45 | start_seed = torch.randint(1000000, (1,)).item() 46 | # at least as many experts than gpus 47 | if args.moe_expert_count >= world_size: 48 | assert ( 49 | args.moe_expert_count % world_size == 0 50 | ), f"{args.moe_expert_count}, {world_size}" 51 | local_moe_expert_count = args.moe_expert_count // world_size 52 | for i in range(local_moe_expert_count): 53 | with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i): 54 | expert_list.append( 55 | FeedForwardNetwork( 56 | embed_dim, 57 | expert_ffn_dim, 58 | args.activation_fn, 59 | args.dropout, 60 | args.activation_dropout, 61 | args.subln, 62 | ) 63 | ) 64 | else: 65 | assert ( 66 | world_size % args.moe_expert_count == 0 67 | ), f"{world_size}, {args.moe_expert_count}" 68 | 69 | with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count): 70 | expert_list.append( 71 | FeedForwardNetwork( 72 | embed_dim, 73 | expert_ffn_dim, 74 | args.activation_fn, 75 | args.dropout, 76 | args.activation_dropout, 77 | args.subln, 78 | ) 79 | ) 80 | experts = nn.ModuleList(expert_list) 81 | return experts 82 | 83 | 84 | def get_activation_fn(activation): 85 | if activation == "relu": 86 | return F.relu 87 | elif activation == "gelu": 88 | return F.gelu 89 | else: 90 | raise NotImplementedError 91 | 92 | 93 | class FeedForwardNetwork(nn.Module): 94 | def __init__( 95 | self, 96 | embed_dim, 97 | ffn_dim, 98 | activation_fn, 99 | dropout, 100 | activation_dropout, 101 | subln=False, 102 | ): 103 | super().__init__() 104 | self.embed_dim = embed_dim 105 | self.activation_fn = get_activation_fn(activation=str(activation_fn)) 106 | self.activation_dropout_module = torch.nn.Dropout( 107 | activation_dropout, inplace=True 108 | ) 109 | self.dropout_module = torch.nn.Dropout(dropout, inplace=True) 110 | self.fc1 = 
nn.Linear(self.embed_dim, ffn_dim) 111 | self.fc2 = nn.Linear(ffn_dim, self.embed_dim) 112 | self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None 113 | 114 | def reset_parameters(self): 115 | self.fc1.reset_parameters() 116 | self.fc2.reset_parameters() 117 | if self.ffn_layernorm is not None: 118 | self.ffn_layernorm.reset_parameters() 119 | 120 | def forward(self, x): 121 | x_shape = x.shape 122 | x = x.reshape(-1, x.size(-1)) 123 | x = self.fc1(x) 124 | x = self.activation_fn(x.float()).type_as(x) 125 | x = self.activation_dropout_module(x) 126 | if self.ffn_layernorm is not None: 127 | x = self.ffn_layernorm(x) 128 | x = self.fc2(x) 129 | x = x.view(x_shape) 130 | x = self.dropout_module(x) 131 | return x 132 | -------------------------------------------------------------------------------- /torchscale/component/multihead_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from apex.normalization import FusedLayerNorm as LayerNorm 9 | from torch import nn 10 | 11 | from .multiway_network import MultiwayWrapper 12 | 13 | def rotate_every_two(x): 14 | x1 = x[:, :, ::2] 15 | x2 = x[:, :, 1::2] 16 | x = torch.stack((-x2, x1), dim=-1) 17 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ 18 | 19 | def duplicate_interleave(m): 20 | """ 21 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 22 | """ 23 | dim0 = m.shape[0] 24 | m = m.view(-1, 1) # flatten the matrix 25 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension 26 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy 27 | return m 28 | 29 | def apply_rotary_pos_emb(x, sin, cos, scale=1): 30 | sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) 31 | # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) 32 | return (x * cos) + (rotate_every_two(x) * sin) 33 | 34 | class MultiheadAttention(nn.Module): 35 | def __init__( 36 | self, 37 | args, 38 | embed_dim, 39 | num_heads, 40 | dropout=0.0, 41 | self_attention=False, 42 | encoder_decoder_attention=False, 43 | subln=False, 44 | ): 45 | super().__init__() 46 | self.embed_dim = embed_dim 47 | self.num_heads = num_heads 48 | self.head_dim = embed_dim // num_heads 49 | self.scaling = self.head_dim**-0.5 50 | self.block_size = args.block_size 51 | self.half_block_size = self.block_size // 2 52 | 53 | self.self_attention = self_attention 54 | self.encoder_decoder_attention = encoder_decoder_attention 55 | assert self.self_attention ^ self.encoder_decoder_attention 56 | 57 | self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) 58 | self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) 59 | self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) 60 | self.out_proj = MultiwayWrapper( 61 | args, nn.Linear(embed_dim, embed_dim, bias=True) 62 | ) 63 | self.inner_attn_ln = ( 64 | MultiwayWrapper(args, LayerNorm(self.embed_dim)) 65 | if subln and self.self_attention 66 | else None 67 | ) 68 | self.dropout_module = torch.nn.Dropout(dropout, inplace=True) 69 | 70 | def reset_parameters(self): 71 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) 72 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) 
73 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) 74 | nn.init.xavier_uniform_(self.out_proj.weight) 75 | nn.init.constant_(self.out_proj.bias, 0.0) 76 | 77 | def forward( 78 | self, 79 | query, 80 | key, 81 | value, 82 | incremental_state=None, 83 | key_padding_mask=None, 84 | attn_mask=None, 85 | rel_pos=None, 86 | ): 87 | tgt_len, bsz, embed_dim = query.size() 88 | src_len = tgt_len 89 | assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" 90 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 91 | 92 | src_len, key_bsz, _ = key.size() 93 | assert key_bsz == bsz, f"{query.size(), key.size()}" 94 | assert value is not None 95 | assert src_len, bsz == value.shape[:2] 96 | 97 | q = self.q_proj(query) # tgt_len, bsz, dim 98 | k = self.k_proj(key) 99 | v = self.v_proj(value) 100 | q *= self.scaling 101 | if self.block_size > 0 and tgt_len > self.block_size: # divide block 102 | assert tgt_len % self.half_block_size == 0 103 | if incremental_state is not None: 104 | incremental_state["prev_key"] = k.view( 105 | bsz, self.num_heads, -1, self.head_dim 106 | ) 107 | incremental_state["prev_value"] = v.view( 108 | bsz, self.num_heads, -1, self.head_dim 109 | ) 110 | 111 | q = q.view(-1, self.half_block_size, bsz * self.num_heads, self.head_dim).transpose(1, 2).reshape(-1, self.half_block_size, self.head_dim) 112 | k = F.pad(k, (0, 0, 0, 0, self.half_block_size, 0)).unfold(0, self.block_size, self.half_block_size).reshape(-1, self.head_dim, self.block_size).transpose(1, 2) 113 | v = F.pad(v, (0, 0, 0, 0, self.half_block_size, 0)).unfold(0, self.block_size, self.half_block_size).reshape(-1, self.head_dim, self.block_size).transpose(1, 2) 114 | bsz *= tgt_len // self.half_block_size 115 | tgt_len = self.half_block_size 116 | src_len = self.block_size 117 | 118 | else: 119 | q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 120 | k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 121 | v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 122 | if incremental_state is not None: 123 | if "prev_key" in incremental_state: 124 | prev_key = incremental_state["prev_key"].view( 125 | bsz * self.num_heads, -1, self.head_dim 126 | ) 127 | prev_value = incremental_state["prev_value"].view( 128 | bsz * self.num_heads, -1, self.head_dim 129 | ) 130 | k = torch.cat([prev_key, k], dim=1) 131 | v = torch.cat([prev_value, v], dim=1) 132 | incremental_state["prev_key"] = k.view( 133 | bsz, self.num_heads, -1, self.head_dim 134 | ) 135 | incremental_state["prev_value"] = v.view( 136 | bsz, self.num_heads, -1, self.head_dim 137 | ) 138 | src_len = k.size(1) 139 | 140 | if isinstance(rel_pos, tuple): # XPos implementation 141 | sin, cos, scale = rel_pos 142 | if self.self_attention: 143 | k = apply_rotary_pos_emb(k, sin, cos, scale = 1 / scale) 144 | q = apply_rotary_pos_emb(q, sin[-q.shape[1]:], cos[-q.shape[1]:], scale = scale[-q.shape[1]:]) 145 | else: 146 | k = apply_rotary_pos_emb(k, sin[:k.shape[1]], cos[:k.shape[1]], scale = 1 / scale[:k.shape[1]]) 147 | q = apply_rotary_pos_emb(q, sin[k.shape[1]:], cos[k.shape[1]:], scale = scale[k.shape[1]:]) 148 | 149 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 150 | if attn_mask is not None: 151 | attn_weights = torch.nan_to_num(attn_weights) 152 | attn_mask = attn_mask.unsqueeze(0) 153 | attn_weights += attn_mask 154 | 155 | if key_padding_mask is not None: 156 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 157 | attn_weights = 
attn_weights.masked_fill( 158 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 159 | float("-inf"), 160 | ) 161 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 162 | 163 | if isinstance(rel_pos, torch.Tensor): 164 | rel_pos = rel_pos.view(attn_weights.size()) 165 | attn_weights = attn_weights + rel_pos 166 | 167 | attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( 168 | attn_weights 169 | ) 170 | attn_probs = self.dropout_module(attn_weights) 171 | attn = torch.bmm(attn_probs, v) 172 | if bsz > key_bsz: # merge block 173 | attn = attn.view(-1, key_bsz * self.num_heads, self.half_block_size, self.head_dim).transpose(1, 2).reshape(-1, key_bsz, embed_dim) 174 | else: 175 | attn = attn.transpose(0, 1).reshape(-1, bsz, embed_dim) 176 | 177 | if self.inner_attn_ln is not None: 178 | attn = self.inner_attn_ln(attn) 179 | 180 | attn = self.out_proj(attn) 181 | attn_weights = attn_weights.view( 182 | bsz, self.num_heads, tgt_len, src_len 183 | ).transpose(1, 0) 184 | return attn, attn_weights 185 | -------------------------------------------------------------------------------- /torchscale/component/multiway_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import copy 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def MultiwayWrapper(args, module, dim=0): 11 | if args.multiway: 12 | return MultiwayNetwork(module, dim=dim) 13 | return module 14 | 15 | 16 | def set_split_position(position): 17 | def apply_fn(module): 18 | if hasattr(module, "split_position"): 19 | module.split_position = position 20 | 21 | return apply_fn 22 | 23 | 24 | class MultiwayNetwork(nn.Module): 25 | def __init__(self, module, dim=0): 26 | super().__init__() 27 | self.dim = dim 28 | self.A = module 29 | self.B = copy.deepcopy(module) 30 | self.B.reset_parameters() 31 | self.split_position = -1 32 | 33 | def forward(self, x, **kwargs): 34 | if self.split_position == -1: 35 | return self.A(x, **kwargs) 36 | if self.split_position == 0: 37 | return self.B(x, **kwargs) 38 | x1, x2 = torch.split( 39 | x, 40 | [self.split_position, x.size(self.dim) - self.split_position], 41 | dim=self.dim, 42 | ) 43 | # x1, x2 = x[:self.split_position], x[self.split_position:] 44 | y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs) 45 | return torch.cat([y1, y2], dim=self.dim) 46 | -------------------------------------------------------------------------------- /torchscale/component/relative_position_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class RelativePositionBias(nn.Module): 11 | def __init__( 12 | self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12 13 | ): 14 | super().__init__() 15 | self.bidirectional = bidirectional 16 | self.num_buckets = num_buckets 17 | self.max_distance = max_distance 18 | self.n_heads = n_heads 19 | self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads) 20 | 21 | @staticmethod 22 | def _relative_position_bucket( 23 | relative_position, bidirectional=True, num_buckets=32, max_distance=128 24 | ): 25 | ret = 0 26 | n = -relative_position 27 | if bidirectional: 28 | num_buckets //= 2 29 | ret += (n < 0).to(torch.long) * num_buckets 30 
| n = torch.abs(n) 31 | else: 32 | n = torch.max(n, torch.zeros_like(n)) 33 | 34 | max_exact = num_buckets // 2 35 | is_small = n < max_exact 36 | 37 | val_if_large = max_exact + ( 38 | torch.log(n.float() / max_exact) 39 | / math.log(max_distance / max_exact) 40 | * (num_buckets - max_exact) 41 | ).to(torch.long) 42 | val_if_large = torch.min( 43 | val_if_large, torch.full_like(val_if_large, num_buckets - 1) 44 | ) 45 | 46 | ret += torch.where(is_small, n, val_if_large) 47 | return ret 48 | 49 | def compute_bias(self, qlen, klen, step=None): 50 | step = 0 if step is None else step 51 | context_position = torch.arange( 52 | step, 53 | step + qlen, 54 | dtype=torch.long, 55 | device=self.relative_attention_bias.weight.device, 56 | )[:, None] 57 | memory_position = torch.arange( 58 | klen, dtype=torch.long, device=self.relative_attention_bias.weight.device 59 | )[None, :] 60 | relative_position = memory_position - context_position # shape (qlen, klen) 61 | 62 | rp_bucket = self._relative_position_bucket( 63 | relative_position, # shape (qlen, klen) 64 | bidirectional=self.bidirectional, 65 | num_buckets=self.num_buckets, 66 | ) 67 | rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) 68 | values = self.relative_attention_bias( 69 | rp_bucket 70 | ) # shape (qlen, klen, num_heads) 71 | values = values.permute([2, 0, 1]).unsqueeze( 72 | 0 73 | ) # shape (1, num_heads, qlen, klen) 74 | return values 75 | 76 | def forward(self, batch_size, qlen, klen, step=None): 77 | # shape (batch * num_heads, qlen, klen) 78 | return ( 79 | self.compute_bias(qlen, klen, step) 80 | .repeat(batch_size, 1, 1, 1) 81 | .view(-1, qlen, klen) 82 | ) 83 | -------------------------------------------------------------------------------- /torchscale/component/xmoe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/component/xmoe/moe_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 5 | # 6 | # This source code is licensed under the BSD license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | # NOTE: This is a mirror of the code in 10 | # https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe 11 | 12 | import logging 13 | import time 14 | from typing import Any, Tuple, cast 15 | 16 | import torch 17 | import torch.distributed as dist 18 | from torch import Tensor 19 | from torch.nn import Module, ModuleList 20 | 21 | try: 22 | from fairseq.modules.moe import MOELayer 23 | 24 | has_fairseq = True 25 | Base = MOELayer 26 | except ModuleNotFoundError: 27 | Base = Module 28 | has_fairseq = False 29 | 30 | try: 31 | # To enable Tutel MoE optimizations: 32 | # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x 33 | from tutel import moe as tutel_moe 34 | 35 | has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one 36 | except ModuleNotFoundError: 37 | has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1 38 | 39 | logger = logging.getLogger(__name__) 40 | 41 | 42 | # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity 43 | # See https://arxiv.org/pdf/2006.16668.pdf for details. 44 | 45 | # Based on https://github.com/pytorch/pytorch/pull/40762 46 | class _AllToAll(torch.autograd.Function): 47 | @staticmethod 48 | def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore 49 | ctx.group = group 50 | input = input.contiguous() 51 | output = torch.empty_like(input) 52 | if torch.distributed.is_initialized(): 53 | dist.all_to_all_single(output, input, group=group) 54 | else: 55 | assert group is None 56 | output = input 57 | return output 58 | 59 | @staticmethod 60 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: 61 | return (None, _AllToAll.apply(ctx.group, *grad_output)) 62 | 63 | 64 | def _find_my_group_index(grouped_ranks): 65 | my_rank = dist.get_rank() 66 | for i, group in enumerate(grouped_ranks): 67 | if my_rank in group: 68 | return i 69 | raise RuntimeError 70 | 71 | 72 | def get_moe_group(moe_expert_count): 73 | if dist.is_initialized(): 74 | if not hasattr(get_moe_group, "_moe_groups"): 75 | world_size = dist.get_world_size() 76 | 77 | if world_size <= moe_expert_count: 78 | assert moe_expert_count % world_size == 0 79 | moe_groups = [[i] for i in range(world_size)] 80 | 81 | else: 82 | assert world_size % moe_expert_count == 0 83 | ranks_per_group = world_size // moe_expert_count 84 | moe_groups = [ 85 | [i + j * moe_expert_count for j in range(ranks_per_group)] 86 | for i in range(moe_expert_count) 87 | ] 88 | 89 | get_moe_group._moe_group_idx = moe_groups 90 | get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups] 91 | 92 | my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx) 93 | return get_moe_group._moe_groups[my_group_idx] 94 | 95 | 96 | def get_all2all_group(moe_expert_count): 97 | if dist.is_initialized(): 98 | if not hasattr(get_all2all_group, "_all2all_groups"): 99 | world_size = dist.get_world_size() 100 | 101 | # more experts than world size 102 | if world_size <= moe_expert_count: 103 | assert moe_expert_count % world_size == 0 104 | all2all_groups = [[i for i in range(world_size)]] 105 | 106 | # larger world than num experts 107 | else: 108 | assert world_size % moe_expert_count == 0 109 | ranks_per_group = world_size // moe_expert_count 110 | all2all_groups = [ 111 | [i * moe_expert_count + j for j in range(moe_expert_count)] 112 | for i in range(ranks_per_group) 113 | ] 114 | 115 | get_all2all_group._all2all_group_idx = all2all_groups 116 | 
get_all2all_group._all2all_groups = [ 117 | dist.new_group(g) for g in all2all_groups 118 | ] 119 | 120 | my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx) 121 | return get_all2all_group._all2all_groups[my_group_idx] 122 | 123 | 124 | class MOELayer(Base): 125 | """MOELayer module which implements MixtureOfExperts as described in Gshard_. 126 | :: 127 | 128 | gate = Top2Gate(model_dim, num_experts) 129 | moe = MOELayer(gate, expert) 130 | output = moe(input) 131 | l_aux = moe.l_aux 132 | 133 | .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf 134 | 135 | Args: 136 | gate (torch.nn.Module): 137 | gate network 138 | expert (torch.nn.Module): 139 | expert network 140 | """ 141 | 142 | def __init__(self, gate, experts, args): 143 | if has_fairseq: 144 | super(Base, self).__init__() 145 | else: 146 | super().__init__() 147 | self.gate = gate 148 | if type(experts) == ModuleList: 149 | self.experts = cast(ModuleList, experts) 150 | else: 151 | self.experts = ModuleList([experts]) 152 | self.expert_group = get_moe_group(args.moe_expert_count) 153 | self.all2all_group = get_all2all_group(args.moe_expert_count) 154 | self.world_size = dist.get_world_size(group=self.expert_group) 155 | self.all2all_size = dist.get_world_size(group=self.all2all_group) 156 | for p in experts.parameters(): 157 | p.expert = True # type: ignore 158 | self.num_local_experts = len(self.experts) 159 | self.args = args 160 | self.in_generation = False 161 | self.a2a_cuda_event_intervals = [] 162 | self.a2a_cpu_time_ms = 0.0 163 | 164 | def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor: 165 | assert len(input) == 1, "only single input Tensor supported" 166 | input = input[0] 167 | assert ( 168 | len(input.shape) == 3 169 | ), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel" 170 | if input_padding_mask is not None: 171 | assert ( 172 | len(input_padding_mask.shape) == 2 173 | ), "input Tensor must have dimensions: (s)equence, (t)oken" 174 | assert input_padding_mask.shape[0] == input.shape[0] 175 | assert input_padding_mask.shape[1] == input.shape[1] 176 | # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts" 177 | 178 | # Implement Algorithm 2 from GShard paper. 179 | d_model = input.shape[2] 180 | # Pad to expected batch size 181 | input_shape = list(input.shape) 182 | expected_bsz = ( 183 | getattr(self.args, "batch_size", 0) 184 | if self.training 185 | else getattr(self.args, "batch_size_valid", 0) 186 | ) 187 | # This indicates that --batch-size or --max-sentences is not specified 188 | if expected_bsz is None: 189 | expected_bsz = 0 190 | # Note: Padding is not necessary at generation time at present 191 | # because all DDP workers process the same batch. 
Also, batch size at generation time 192 | # can be different from that present in the checkpoint state 193 | if ( 194 | not self.in_generation 195 | and expected_bsz != 0 196 | and input_shape[0] != expected_bsz 197 | ): 198 | logger.warning( 199 | f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})" 200 | ) 201 | assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}" 202 | padded_input = torch.zeros( 203 | (expected_bsz, input_shape[1], input_shape[2]), 204 | dtype=input.dtype, 205 | layout=input.layout, 206 | device=input.device, 207 | ) 208 | padded_input[: input_shape[0], :, :] = input 209 | input = padded_input 210 | 211 | padded_input_padding_mask = torch.ones( 212 | ( 213 | expected_bsz, 214 | input_shape[1], 215 | ), 216 | dtype=torch.bool, 217 | device=input.device, 218 | ) 219 | if input_padding_mask is not None: 220 | padded_input_padding_mask[: input_shape[0], :] = input_padding_mask 221 | else: 222 | padded_input_padding_mask[: input_shape[0], :] = False 223 | input_padding_mask = padded_input_padding_mask 224 | 225 | # Reshape into S tokens by dropping sequence dimension. 226 | reshaped_input = input.reshape(-1, d_model) 227 | reshaped_input_shape = reshaped_input.shape 228 | reshaped_input_padding_mask = ( 229 | input_padding_mask.reshape(-1) if input_padding_mask is not None else None 230 | ) 231 | 232 | # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences 233 | # Pro of --max-tokens: more flexible for MT variable sequence lengths 234 | # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM 235 | if expected_bsz == 0: 236 | expected_dim = reshaped_input_shape[0] * torch.ones( 237 | (1,), dtype=torch.long, device=input.device 238 | ) 239 | dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX) 240 | expected_dim = int(expected_dim.item()) 241 | padded_input = torch.zeros( 242 | (expected_dim, reshaped_input_shape[1]), 243 | dtype=input.dtype, 244 | layout=input.layout, 245 | device=input.device, 246 | ) 247 | padded_input[: reshaped_input_shape[0], :] = reshaped_input 248 | reshaped_input = padded_input 249 | 250 | padded_input_padding_mask = torch.ones( 251 | (expected_dim,), dtype=torch.bool, device=padded_input.device 252 | ) 253 | if reshaped_input_padding_mask is not None: 254 | padded_input_padding_mask[ 255 | : reshaped_input_shape[0] 256 | ] = reshaped_input_padding_mask 257 | else: 258 | padded_input_padding_mask[: reshaped_input_shape[0]] = False 259 | reshaped_input_padding_mask = padded_input_padding_mask 260 | 261 | if has_tutel: 262 | l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate( 263 | reshaped_input, reshaped_input_padding_mask 264 | ) 265 | S, M = reshaped_input.size(0), reshaped_input.size(1) 266 | 267 | if not hasattr(self, "_tutel_dispatcher"): 268 | self._tutel_dispatcher = tutel_moe.fast_dispatcher( 269 | E, C, M, dispatch_dtype=reshaped_input.dtype 270 | ) 271 | self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C) 272 | dispatched_input = self._tutel_dispatcher.encode(reshaped_input) 273 | else: 274 | l_aux, combine_weights, dispatch_mask, self.metadata = self.gate( 275 | reshaped_input, reshaped_input_padding_mask 276 | ) 277 | 278 | dispatch_mask = dispatch_mask.to(input.dtype).permute( 279 | 1, 2, 0 280 | ) # S,E,C -> E,C,S 281 | E, C, S = dispatch_mask.size() 282 | M = reshaped_input.size(1) 283 | assert reshaped_input.size() == (S, M) 284 | # 
einsum("sec,sm->ecm") 285 | dispatched_input = torch.mm( 286 | dispatch_mask.view(E * C, S), reshaped_input 287 | ) # -> (E*C),M 288 | 289 | if self.all2all_size > 1: 290 | dispatched_input = self.all_to_all_wrapper(dispatched_input) 291 | 292 | # Re-shape after all-to-all: ecm -> gecm 293 | dispatched_input = dispatched_input.reshape( 294 | self.all2all_size, self.num_local_experts, -1, d_model 295 | ) 296 | chunks = dispatched_input.chunk(self.num_local_experts, dim=1) 297 | expert_outputs = [] 298 | for chunk, expert in zip(chunks, self.experts): 299 | expert_outputs += [expert(chunk)] 300 | expert_output = torch.cat(expert_outputs, dim=1) 301 | 302 | if self.all2all_size > 1: 303 | expert_output = self.all_to_all_wrapper(expert_output) 304 | 305 | # Re-shape back: gecm -> ecm 306 | expert_output = expert_output.reshape( 307 | self.all2all_size * self.num_local_experts, -1, d_model 308 | ) 309 | 310 | if has_tutel: 311 | combined_output = self._tutel_dispatcher.decode( 312 | expert_output.view(E * C, M) 313 | ) 314 | else: 315 | # einsum("sec,ecm->sm") 316 | combined_output = combine_weights.view(S, E * C).mm( 317 | expert_output.view(E * C, M) 318 | ) 319 | 320 | # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences 321 | combined_output = combined_output[: reshaped_input_shape[0], :] 322 | combined_output = combined_output.reshape(input.shape) 323 | combined_output = combined_output[: input_shape[0], :, :] 324 | 325 | self.record_all_to_all_stats() 326 | 327 | return combined_output, l_aux 328 | 329 | def prepare_for_inference_(self): 330 | self.in_generation = True 331 | 332 | def all_to_all_wrapper(self, input: Tensor): 333 | dummy_a2a = getattr(self.args, "dummy_a2a", False) 334 | if dummy_a2a: 335 | input = input.contiguous() 336 | output = input.detach().clone() 337 | return input 338 | # always record times, since it is not a lot of overhead 339 | # if we do not log it we simply clear it off in record_all_to_all_stats 340 | cuda_start = torch.cuda.Event(enable_timing=True) 341 | cuda_end = torch.cuda.Event(enable_timing=True) 342 | cpu_start = time.time() * 1000 343 | cuda_start.record() 344 | output = _AllToAll.apply(self.all2all_group, input) 345 | cuda_end.record() 346 | cpu_end = time.time() * 1000 347 | self.a2a_cpu_time_ms += cpu_end - cpu_start 348 | self.a2a_cuda_event_intervals.append((cuda_start, cuda_end)) 349 | return output 350 | 351 | def record_all_to_all_stats(self): 352 | # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize() 353 | record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False) 354 | if record_a2a_perf_stats: 355 | torch.cuda.synchronize() 356 | self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms 357 | a2a_cuda_time_ms = 0.0 358 | for ev_start, ev_end in self.a2a_cuda_event_intervals: 359 | a2a_cuda_time_ms += ev_start.elapsed_time(ev_end) 360 | self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms 361 | # reset stats 362 | self.a2a_cpu_time_ms = 0.0 363 | self.a2a_cuda_event_intervals = [] 364 | -------------------------------------------------------------------------------- /torchscale/component/xmoe/routing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
5 | # 6 | # This source code is licensed under the BSD license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf 10 | # Code is inspired by Top2GatingOnLogits from lingvo: 11 | # https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477 12 | 13 | # NOTE: This is a mirror of the code in 14 | # https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe 15 | 16 | import math 17 | from typing import Callable, Dict, Optional, Tuple 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | from torch import Tensor 22 | 23 | from .moe_layer import fused_cumsum_sub_one, has_tutel 24 | 25 | # use a fixed temperature to compute balance loss 26 | TEMPERATURE_FOR_L_UAX = 0.07 27 | 28 | # maximum capacity of 1 expert as a fraction of number of tokens in the batch 29 | # Note: setting this to 1.0 causes inference to significantly slow down 30 | EVAL_CAPACITY_TOKEN_FRACTION = 0.25 31 | 32 | # logging 33 | SAMPLE_FRACTION = 0.2 34 | 35 | 36 | def top1gating( 37 | logits: torch.Tensor, 38 | input_mask: Optional[torch.Tensor] = None, 39 | use_fp32=False, 40 | capacity_factor=1.0, 41 | eval_mode=False, 42 | moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION, 43 | use_xmoe=False, 44 | gate_obj=None, 45 | ) -> Tuple[Tensor, Tensor, Tensor, Dict]: 46 | """Implements Top2Gating on logits.""" 47 | metadata = {} 48 | if use_fp32: 49 | orig_dtype = logits.dtype 50 | logits = logits.float() 51 | 52 | gates = F.softmax(logits, dim=1) 53 | metadata["entropy_gating"] = entropy(probs=gates).mean().detach() 54 | 55 | # gates has shape of SE 56 | num_tokens = gates.shape[0] 57 | num_experts = gates.shape[1] 58 | if moe_eval_capacity_token_fraction > 0.0 and eval_mode: 59 | capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens) 60 | else: 61 | # capacity = capacity_factor * S/E 62 | capacity = int(capacity_factor * math.ceil(num_tokens / num_experts)) 63 | 64 | # Create a mask for 1st's expert per token 65 | indices1_s = torch.argmax(gates, dim=1) 66 | mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True) 67 | if input_mask is not None and input_mask.any(): 68 | nonpadding = ~input_mask 69 | mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype) 70 | 71 | # for logging (percent of tokens routed to each expert) 72 | expert1_hist = ( 73 | 100 74 | * torch.histc( 75 | (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts 76 | ) 77 | / num_tokens 78 | ) 79 | metadata["unused_expert1_count"] = (expert1_hist == 0).sum() 80 | expert1_hist = ( 81 | torch.sort(expert1_hist, dim=0, descending=True).values 82 | + torch.finfo(torch.float32).tiny 83 | ) 84 | 85 | sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1) 86 | metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum() 87 | metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum() 88 | 89 | gates1_s = (gates * mask1).sum(dim=1) 90 | 91 | # Compute locations in capacity buffer 92 | locations1 = fused_cumsum_sub_one(mask1) 93 | 94 | # Compute l_aux 95 | me = torch.mean(gates, dim=0) 96 | ce = torch.mean(mask1.to(gates.dtype), dim=0) 97 | 98 | l_aux = torch.mean(me * ce) 99 | l_aux = l_aux * num_experts * num_experts 100 | 101 | if has_tutel: 102 | locations1_s = torch.sum(locations1 * mask1, dim=1) 103 | return ( 104 | l_aux, 105 | metadata, 106 | capacity, 107 | num_experts, 108 | [ 109 | 
indices1_s, 110 | ], 111 | [ 112 | locations1_s, 113 | ], 114 | [ 115 | gates1_s, 116 | ], 117 | ) 118 | 119 | # Remove locations outside capacity from mask 120 | mask1 = mask1 * torch.lt(locations1, capacity) 121 | # Store the capacity location for each token 122 | locations1_s = torch.sum(locations1 * mask1, dim=1) 123 | 124 | # Calculate combine_weights and dispatch_mask 125 | gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se") 126 | # locations1_sc = num_tokens * capacity 127 | locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True) 128 | combine1_sec = torch.bmm( 129 | # einsum("se,sc->sec") 130 | gates1.unsqueeze(-1), 131 | locations1_sc.to(gates1.dtype).unsqueeze(1), 132 | ) 133 | dispatch_mask = combine1_sec.bool() 134 | if use_fp32: 135 | return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata 136 | else: 137 | return l_aux, combine1_sec, dispatch_mask, metadata 138 | 139 | 140 | class Top1Gate(torch.nn.Module): 141 | """Gate module which implements Top2Gating as described in Gshard_. 142 | :: 143 | 144 | gate = Top2Gate(model_dim, num_experts) 145 | l_aux, combine_weights, dispatch_mask = gate(input) 146 | 147 | .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf 148 | 149 | Args: 150 | model_dim (int): 151 | size of model embedding dimension 152 | num_experts (ints): 153 | number of experts in model 154 | """ 155 | 156 | wg: torch.nn.Linear 157 | 158 | def __init__( 159 | self, 160 | model_dim: int, 161 | num_experts: int, 162 | use_fp32=False, 163 | input_noise_type=None, 164 | capacity_factor=1.0, 165 | moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION, 166 | use_xmoe=False, 167 | ) -> None: 168 | # TODO: merge this to top2gate.py 169 | # 170 | super().__init__() 171 | 172 | if not use_xmoe: 173 | self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) 174 | else: 175 | self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False) 176 | wg = torch.empty(num_experts, 16) 177 | torch.nn.init.orthogonal_(wg, gain=0.32) 178 | self.register_parameter("wg", torch.nn.Parameter(wg)) 179 | 180 | self.use_xmoe = use_xmoe 181 | self.use_fp32 = use_fp32 182 | self.input_noise_type = input_noise_type 183 | self.capacity_factor = capacity_factor 184 | self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction 185 | 186 | def forward(self, input, mask=None): # type: ignore 187 | if self.use_xmoe: 188 | input = self.wg_reduction(input) 189 | with torch.no_grad(): 190 | wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True) 191 | self.wg.mul_(1.5 / wg_norm) 192 | logits = self._cosine(input, self.wg) 193 | logits = self._make_finite(logits) 194 | else: 195 | logits = self.wg(input) 196 | 197 | return top1gating( 198 | logits, 199 | mask, 200 | use_fp32=self.use_fp32, 201 | capacity_factor=self.capacity_factor, 202 | eval_mode=not self.training, 203 | moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction, 204 | use_xmoe=self.use_xmoe, 205 | gate_obj=self, 206 | ) 207 | 208 | def _make_finite(self, scores): 209 | ok = scores.isfinite() 210 | if not ok.all(): 211 | # NaNs here can break the assignment algorithm 212 | scores[~ok] = scores[ok].min() 213 | return scores 214 | 215 | def _get_gating_temperature(self, eps=1e-4): 216 | if self.gating_t.data.item() < eps: 217 | return eps 218 | return self.gating_t 219 | 220 | def _cosine(self, mat1, mat2, eps=1e-4): 221 | assert mat1.dim() == 2 222 | assert mat2.dim() == 2 223 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps) 224 | mat2 = 
F.normalize(mat2.float(), p=2.0, dim=1, eps=eps) 225 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1) 226 | 227 | 228 | gumbel_map: Dict[torch.device, Callable] = {} 229 | 230 | 231 | def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: 232 | gumbel = gumbel_map.get(device) 233 | if gumbel is None: 234 | one = torch.tensor(1.0, device=device) 235 | zero = torch.tensor(0.0, device=device) 236 | gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore 237 | gumbel_map[device] = gumbel 238 | return gumbel(shape) 239 | 240 | 241 | def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor: 242 | if unsqueeze_indices: 243 | indices = indices.unsqueeze(-1) 244 | assert indices.shape[-1] == 1, "last dimension of indices must be have size 1" 245 | output = torch.zeros( 246 | indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype 247 | ) 248 | output.scatter_(len(output.shape) - 1, indices, 1) 249 | return output 250 | 251 | 252 | def entropy(probs): 253 | logits = torch.distributions.utils.probs_to_logits(probs) 254 | p_log_p = probs * logits 255 | return -p_log_p.sum(-1) 256 | 257 | 258 | def top2gating( 259 | logits: torch.Tensor, 260 | input_mask: Optional[torch.Tensor] = None, 261 | use_fp32=False, 262 | second_expert_policy="sampling", 263 | normalize_gate_prob_before_dropping=False, 264 | eval_mode=False, 265 | moe_eval_capacity_token_fraction=0.25, 266 | batch_prioritized_routing=False, 267 | ) -> Tuple[Tensor, Tensor, Tensor]: 268 | """Implements Top2Gating on logits.""" 269 | metadata = {} 270 | if use_fp32: 271 | orig_dtype = logits.dtype 272 | logits = logits.float() 273 | gates = F.softmax(logits, dim=1) 274 | metadata["entropy_gating"] = entropy(probs=gates).mean().detach() 275 | # gates has shape of SE 276 | num_tokens = gates.shape[0] 277 | num_experts = gates.shape[1] 278 | if moe_eval_capacity_token_fraction > 0.0 and eval_mode: 279 | capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens) 280 | else: 281 | # capacity = 2S/E 282 | capacity = 2 * math.ceil(num_tokens / num_experts) 283 | 284 | # Create a mask for 1st's expert per token 285 | indices1_s = torch.argmax(gates, dim=1, keepdim=True) 286 | mask1 = one_hot(indices1_s, num_experts) 287 | if second_expert_policy == "sampling": 288 | # Create a mask for 2nd's expert per token using Gumbel-max trick 289 | # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ 290 | logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) 291 | else: 292 | logits_w_noise = logits 293 | # Replace top-expert with min value 294 | logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf")) 295 | indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True) 296 | mask2 = one_hot(indices2_s, num_experts) 297 | gates1_s = (gates * mask1).sum(dim=1) 298 | gates2_s = (gates * mask2).sum(dim=1) 299 | 300 | if normalize_gate_prob_before_dropping: 301 | # Normalize gate probabilities 302 | denom_s = gates1_s + gates2_s 303 | # Avoid divide-by-zero 304 | denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) 305 | gates1_s = gates1_s / denom_s 306 | gates2_s = gates2_s / denom_s 307 | 308 | if second_expert_policy == "random": 309 | sampled = (2 * gates2_s) > torch.rand_like(gates2_s) 310 | mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0) 311 | 312 | # Compute locations in capacity buffer 313 | if input_mask is not None and input_mask.any(): 314 | nonpadding = 
~input_mask 315 | mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype) 316 | mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype) 317 | 318 | if batch_prioritized_routing: 319 | # if batch_prioritized_routing: 320 | importance_scores = -1 * gates.max(dim=1)[0] 321 | sorted_mask1 = mask1[importance_scores.argsort(dim=0)] 322 | sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1 323 | importance_sorted_locations1 = sorted_cumsum1[ 324 | importance_scores.argsort(dim=0).argsort(dim=0) 325 | ] 326 | 327 | sorted_mask2 = mask2[importance_scores.argsort(dim=0)] 328 | sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2 329 | importance_sorted_locations2 = sorted_cumsum2[ 330 | importance_scores.argsort(dim=0).argsort(dim=0) 331 | ] 332 | 333 | importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True) 334 | 335 | locations1, locations2 = ( 336 | importance_sorted_locations1, 337 | importance_sorted_locations2, 338 | ) 339 | else: 340 | locations1 = fused_cumsum_sub_one(mask1) 341 | locations2 = fused_cumsum_sub_one(mask2) 342 | # Update 2nd's location by accounting for locations of 1st 343 | locations2 += torch.sum(mask1, dim=0, keepdim=True) 344 | 345 | # Compute l_aux 346 | me = torch.mean(gates, dim=0) 347 | ce = torch.mean(mask1.to(gates.dtype), dim=0) 348 | l_aux = torch.mean(me * ce) 349 | l_aux = l_aux * num_experts * num_experts 350 | 351 | # for logging purposes 352 | metadata["overflow_expert1"] = ( 353 | 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1) 354 | ) 355 | metadata["overflow_expert2"] = ( 356 | 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2) 357 | ) 358 | 359 | # Remove locations outside capacity from mask 360 | mask1_, mask2_ = mask1, mask2 361 | mask1 = mask1 * torch.lt(locations1, capacity) 362 | mask2 = mask2 * torch.lt(locations2, capacity) 363 | 364 | # for logging (percent of tokens routed to each expert) 365 | expert1_hist = ( 366 | 100 367 | * torch.histc( 368 | (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts 369 | ) 370 | / num_tokens 371 | ) 372 | metadata["unused_expert1_count"] = (expert1_hist == 0).sum() 373 | expert1_hist = ( 374 | torch.sort(expert1_hist, dim=0, descending=True).values 375 | + torch.finfo(torch.float32).tiny 376 | ) 377 | 378 | expert2_hist = ( 379 | 100 380 | * torch.histc( 381 | (indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts 382 | ) 383 | / num_tokens 384 | ) 385 | metadata["unused_expert2_count"] = (expert2_hist == 0).sum() 386 | expert2_hist = ( 387 | torch.sort(expert2_hist, dim=0, descending=True).values 388 | + torch.finfo(torch.float32).tiny 389 | ) 390 | 391 | sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1) 392 | metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum() 393 | metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum() 394 | 395 | metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum() 396 | metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum() 397 | 398 | if not normalize_gate_prob_before_dropping: 399 | # Normalize gate probabilities 400 | gates1_s = (gates * mask1).sum(dim=1) 401 | gates2_s = (gates * mask2).sum(dim=1) 402 | denom_s = gates1_s + gates2_s 403 | # Avoid divide-by-zero 404 | denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) 405 | gates1_s /= denom_s 406 | gates2_s /= denom_s 407 | 408 | if has_tutel: 409 | locations1_s = torch.sum(locations1 * mask1_, dim=1) 410 | locations2_s 
= torch.sum(locations2 * mask2_, dim=1) 411 | return ( 412 | l_aux, 413 | metadata, 414 | capacity, 415 | num_experts, 416 | [indices1_s, indices2_s], 417 | [locations1_s, locations2_s], 418 | [gates1_s, gates2_s], 419 | ) 420 | 421 | # Store the capacity location for each token 422 | locations1_s = torch.sum(locations1 * mask1, dim=1) 423 | locations2_s = torch.sum(locations2 * mask2, dim=1) 424 | 425 | # Calculate combine_weights and dispatch_mask 426 | gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se") 427 | gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se") 428 | locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True) 429 | locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True) 430 | combine1_sec = torch.bmm( 431 | # einsum("se,sc->sec") 432 | gates1.unsqueeze(-1), 433 | locations1_sc.to(gates1.dtype).unsqueeze(1), 434 | ) 435 | combine2_sec = torch.bmm( 436 | # einsum("se,sc->sec") 437 | gates2.unsqueeze(-1), 438 | locations2_sc.to(gates2.dtype).unsqueeze(1), 439 | ) 440 | combine_weights = combine1_sec + combine2_sec 441 | dispatch_mask = combine_weights.bool() 442 | if use_fp32: 443 | return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata 444 | else: 445 | return l_aux, combine_weights, dispatch_mask, metadata 446 | 447 | 448 | class Top2Gate(torch.nn.Module): 449 | """Gate module which implements Top2Gating as described in Gshard_. 450 | :: 451 | 452 | gate = Top2Gate(model_dim, num_experts) 453 | l_aux, combine_weights, dispatch_mask = gate(input) 454 | 455 | .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf 456 | 457 | Args: 458 | model_dim (int): 459 | size of model embedding dimension 460 | num_experts (ints): 461 | number of experts in model 462 | """ 463 | 464 | wg: torch.nn.Linear 465 | 466 | def __init__( 467 | self, 468 | model_dim: int, 469 | num_experts: int, 470 | use_fp32=False, 471 | second_expert_policy="sampling", 472 | normalize_gate_prob_before_dropping=False, 473 | moe_eval_capacity_token_fraction=0.25, 474 | batch_prioritized_routing=False, 475 | use_xmoe=False, 476 | ) -> None: 477 | super().__init__() 478 | if not use_xmoe: 479 | self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) 480 | else: 481 | self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False) 482 | wg = torch.empty(num_experts, 16) 483 | torch.nn.init.orthogonal_(wg, gain=0.32) 484 | self.register_parameter("wg", torch.nn.Parameter(wg)) 485 | self.use_fp32 = use_fp32 486 | self.second_expert_policy = second_expert_policy 487 | self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping 488 | self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction 489 | self.batch_prioritized_routing = batch_prioritized_routing 490 | self.use_xmoe = use_xmoe 491 | 492 | def forward(self, input, mask=None): # type: ignore 493 | if self.use_xmoe: 494 | input = self.wg_reduction(input) 495 | with torch.no_grad(): 496 | wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True) 497 | self.wg.mul_(1.5 / wg_norm) 498 | logits = self._cosine(input, self.wg) 499 | logits = self._make_finite(logits) 500 | else: 501 | logits = self.wg(input) 502 | return top2gating( 503 | logits, 504 | mask, 505 | use_fp32=self.use_fp32, 506 | second_expert_policy=self.second_expert_policy, 507 | normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping, 508 | eval_mode=not self.training, 509 | 
moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction, 510 | batch_prioritized_routing=self.batch_prioritized_routing, 511 | ) 512 | 513 | def _cosine(self, mat1, mat2, eps=1e-4): 514 | assert mat1.dim() == 2 515 | assert mat2.dim() == 2 516 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps) 517 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps) 518 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1) 519 | 520 | def _make_finite(self, scores): 521 | ok = scores.isfinite() 522 | if not ok.all(): 523 | # NaNs here can break the assignment algorithm 524 | scores[~ok] = scores[ok].min() 525 | return scores 526 | -------------------------------------------------------------------------------- /torchscale/component/xpos_relative_position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | def fixed_pos_embedding(x): 8 | seq_len, dim = x.shape 9 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) 10 | sinusoid_inp = ( 11 | torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) 12 | ) 13 | return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) 14 | 15 | 16 | class XPos(nn.Module): 17 | def __init__( 18 | self, head_dim, scale_base = 512 19 | ): 20 | super().__init__() 21 | self.head_dim = head_dim 22 | self.scale_base = scale_base 23 | self.register_buffer( 24 | "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) 25 | ) 26 | 27 | def forward(self, len): 28 | scale = self.scale ** (torch.arange(0, len, 1) - len // 2).to(self.scale).div(self.scale_base)[:, None] 29 | sin, cos = fixed_pos_embedding(scale) 30 | return (sin, cos, scale) 31 | -------------------------------------------------------------------------------- /torchscale/model/BEiT3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchscale.architecture.encoder import Encoder 8 | from torchscale.component.embedding import ( 9 | PositionalEmbedding, 10 | TextEmbedding, 11 | VisionEmbedding, 12 | ) 13 | from torchscale.component.multiway_network import MultiwayWrapper 14 | 15 | 16 | class BEiT3(nn.Module): 17 | def __init__(self, args, **kwargs): 18 | super().__init__() 19 | self.args = args 20 | assert args.multiway 21 | assert args.vocab_size > 0 22 | assert not args.share_encoder_input_output_embed 23 | self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim) 24 | self.vision_embed = VisionEmbedding( 25 | args.img_size, 26 | args.patch_size, 27 | args.in_chans, 28 | args.encoder_embed_dim, 29 | contain_mask_token=True, 30 | prepend_cls_token=True, 31 | ) 32 | embed_positions = MultiwayWrapper( 33 | args, 34 | PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim), 35 | dim=1, 36 | ) 37 | self.encoder = Encoder( 38 | args, 39 | embed_tokens=None, 40 | embed_positions=embed_positions, 41 | output_projection=None, 42 | is_encoder_decoder=False, 43 | ) 44 | 45 | def forward( 46 | self, 47 | textual_tokens=None, 48 | visual_tokens=None, 49 | text_padding_position=None, 50 | vision_masked_position=None, 51 | ): 52 | assert textual_tokens is not None or visual_tokens is not None 53 | 54 | if textual_tokens is None: 55 | x = 
self.vision_embed(visual_tokens, vision_masked_position) 56 | encoder_padding_mask = None 57 | multiway_split_position = -1 58 | elif visual_tokens is None: 59 | x = self.text_embed(textual_tokens) 60 | encoder_padding_mask = text_padding_position 61 | multiway_split_position = 0 62 | else: 63 | x1 = self.vision_embed(visual_tokens, vision_masked_position) 64 | multiway_split_position = x1.size(1) 65 | x2 = self.text_embed(textual_tokens) 66 | x = torch.cat([x1, x2], dim=1) 67 | 68 | if text_padding_position is not None: 69 | encoder_padding_mask = torch.cat( 70 | [ 71 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(), 72 | text_padding_position, 73 | ], 74 | dim=1, 75 | ) 76 | else: 77 | encoder_padding_mask = None 78 | 79 | encoder_out = self.encoder( 80 | src_tokens=None, 81 | encoder_padding_mask=encoder_padding_mask, 82 | token_embeddings=x, 83 | multiway_split_position=multiway_split_position, 84 | ) 85 | 86 | return encoder_out 87 | -------------------------------------------------------------------------------- /torchscale/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | --------------------------------------------------------------------------------
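As a quick sanity check of the XPos module listed in torchscale/component/xpos_relative_position.py above, the following minimal sketch instantiates it and inspects the returned shapes. It assumes only that the torchscale package is importable as laid out in the tree; the head_dim, scale_base, and sequence-length values are illustrative, and how the sin/cos/scale tensors are ultimately applied to queries and keys is left to the attention implementation, which is not reproduced here.

    import torch

    from torchscale.component.xpos_relative_position import XPos

    # XPos keeps one decay scale per pair of channels, so head_dim=64 yields 32 scales.
    xpos = XPos(head_dim=64, scale_base=512)

    # forward(len) builds per-position sin/cos tables plus the position-dependent
    # decay scale; each tensor has shape (len, head_dim // 2).
    sin, cos, scale = xpos(128)
    print(sin.shape, cos.shape, scale.shape)  # torch.Size([128, 32]) for all three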