├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── examples
│   ├── __init__.py
│   └── fairseq
│       ├── README.md
│       ├── __init__.py
│       ├── generate.py
│       ├── interactive.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── bert.py
│       │   ├── language_modeling.py
│       │   └── machine_translation.py
│       ├── tasks
│       │   ├── __init__.py
│       │   ├── data
│       │   │   ├── __init__.py
│       │   │   ├── basic_loader.py
│       │   │   ├── mlm_loader.py
│       │   │   └── utils.py
│       │   └── pretraining.py
│       ├── train.py
│       └── utils
│           ├── __init__.py
│           └── sparse_clip.py
├── setup.py
├── tests
│   ├── __init__.py
│   ├── test_decoder.py
│   ├── test_encoder.py
│   └── test_encoder_decoder.py
└── torchscale
    ├── __init__.py
    ├── architecture
    │   ├── __init__.py
    │   ├── config.py
    │   ├── decoder.py
    │   ├── encoder.py
    │   ├── encoder_decoder.py
    │   └── utils.py
    ├── component
    │   ├── __init__.py
    │   ├── droppath.py
    │   ├── embedding.py
    │   ├── feedforward_network.py
    │   ├── multihead_attention.py
    │   ├── multiway_network.py
    │   ├── relative_position_bias.py
    │   ├── xmoe
    │   │   ├── __init__.py
    │   │   ├── moe_layer.py
    │   │   └── routing.py
    │   └── xpos_relative_position.py
    └── model
        ├── BEiT3.py
        └── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Aa][Rr][Mm]/
27 | [Aa][Rr][Mm]64/
28 | bld/
29 | [Bb]in/
30 | [Oo]bj/
31 | [Ll]og/
32 | [Ll]ogs/
33 |
34 | # Visual Studio 2015/2017 cache/options directory
35 | .vs/
36 | # Uncomment if you have tasks that create the project's static files in wwwroot
37 | #wwwroot/
38 |
39 | # Visual Studio 2017 auto generated files
40 | Generated\ Files/
41 |
42 | # MSTest test Results
43 | [Tt]est[Rr]esult*/
44 | [Bb]uild[Ll]og.*
45 |
46 | # NUnit
47 | *.VisualState.xml
48 | TestResult.xml
49 | nunit-*.xml
50 |
51 | # Build Results of an ATL Project
52 | [Dd]ebugPS/
53 | [Rr]eleasePS/
54 | dlldata.c
55 |
56 | # Benchmark Results
57 | BenchmarkDotNet.Artifacts/
58 |
59 | # .NET Core
60 | project.lock.json
61 | project.fragment.lock.json
62 | artifacts/
63 |
64 | # StyleCop
65 | StyleCopReport.xml
66 |
67 | # Files built by Visual Studio
68 | *_i.c
69 | *_p.c
70 | *_h.h
71 | *.ilk
72 | *.meta
73 | *.obj
74 | *.iobj
75 | *.pch
76 | *.pdb
77 | *.ipdb
78 | *.pgc
79 | *.pgd
80 | *.rsp
81 | *.sbr
82 | *.tlb
83 | *.tli
84 | *.tlh
85 | *.tmp
86 | *.tmp_proj
87 | *_wpftmp.csproj
88 | *.log
89 | *.vspscc
90 | *.vssscc
91 | .builds
92 | *.pidb
93 | *.svclog
94 | *.scc
95 |
96 | # Chutzpah Test files
97 | _Chutzpah*
98 |
99 | # Visual C++ cache files
100 | ipch/
101 | *.aps
102 | *.ncb
103 | *.opendb
104 | *.opensdf
105 | *.sdf
106 | *.cachefile
107 | *.VC.db
108 | *.VC.VC.opendb
109 |
110 | # Visual Studio profiler
111 | *.psess
112 | *.vsp
113 | *.vspx
114 | *.sap
115 |
116 | # Visual Studio Trace Files
117 | *.e2e
118 |
119 | # TFS 2012 Local Workspace
120 | $tf/
121 |
122 | # Guidance Automation Toolkit
123 | *.gpState
124 |
125 | # ReSharper is a .NET coding add-in
126 | _ReSharper*/
127 | *.[Rr]e[Ss]harper
128 | *.DotSettings.user
129 |
130 | # TeamCity is a build add-in
131 | _TeamCity*
132 |
133 | # DotCover is a Code Coverage Tool
134 | *.dotCover
135 |
136 | # AxoCover is a Code Coverage Tool
137 | .axoCover/*
138 | !.axoCover/settings.json
139 |
140 | # Visual Studio code coverage results
141 | *.coverage
142 | *.coveragexml
143 |
144 | # NCrunch
145 | _NCrunch_*
146 | .*crunch*.local.xml
147 | nCrunchTemp_*
148 |
149 | # MightyMoose
150 | *.mm.*
151 | AutoTest.Net/
152 |
153 | # Web workbench (sass)
154 | .sass-cache/
155 |
156 | # Installshield output folder
157 | [Ee]xpress/
158 |
159 | # DocProject is a documentation generator add-in
160 | DocProject/buildhelp/
161 | DocProject/Help/*.HxT
162 | DocProject/Help/*.HxC
163 | DocProject/Help/*.hhc
164 | DocProject/Help/*.hhk
165 | DocProject/Help/*.hhp
166 | DocProject/Help/Html2
167 | DocProject/Help/html
168 |
169 | # Click-Once directory
170 | publish/
171 |
172 | # Publish Web Output
173 | *.[Pp]ublish.xml
174 | *.azurePubxml
175 | # Note: Comment the next line if you want to checkin your web deploy settings,
176 | # but database connection strings (with potential passwords) will be unencrypted
177 | *.pubxml
178 | *.publishproj
179 |
180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
181 | # checkin your Azure Web App publish settings, but sensitive information contained
182 | # in these scripts will be unencrypted
183 | PublishScripts/
184 |
185 | # NuGet Packages
186 | *.nupkg
187 | # NuGet Symbol Packages
188 | *.snupkg
189 | # The packages folder can be ignored because of Package Restore
190 | **/[Pp]ackages/*
191 | # except build/, which is used as an MSBuild target.
192 | !**/[Pp]ackages/build/
193 | # Uncomment if necessary however generally it will be regenerated when needed
194 | #!**/[Pp]ackages/repositories.config
195 | # NuGet v3's project.json files produces more ignorable files
196 | *.nuget.props
197 | *.nuget.targets
198 |
199 | # Microsoft Azure Build Output
200 | csx/
201 | *.build.csdef
202 |
203 | # Microsoft Azure Emulator
204 | ecf/
205 | rcf/
206 |
207 | # Windows Store app package directories and files
208 | AppPackages/
209 | BundleArtifacts/
210 | Package.StoreAssociation.xml
211 | _pkginfo.txt
212 | *.appx
213 | *.appxbundle
214 | *.appxupload
215 |
216 | # Visual Studio cache files
217 | # files ending in .cache can be ignored
218 | *.[Cc]ache
219 | # but keep track of directories ending in .cache
220 | !?*.[Cc]ache/
221 |
222 | # Others
223 | ClientBin/
224 | ~$*
225 | *~
226 | *.dbmdl
227 | *.dbproj.schemaview
228 | *.jfm
229 | *.pfx
230 | *.publishsettings
231 | orleans.codegen.cs
232 |
233 | # Including strong name files can present a security risk
234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
235 | #*.snk
236 |
237 | # Since there are multiple workflows, uncomment next line to ignore bower_components
238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
239 | #bower_components/
240 |
241 | # RIA/Silverlight projects
242 | Generated_Code/
243 |
244 | # Backup & report files from converting an old project file
245 | # to a newer Visual Studio version. Backup files are not needed,
246 | # because we have git ;-)
247 | _UpgradeReport_Files/
248 | Backup*/
249 | UpgradeLog*.XML
250 | UpgradeLog*.htm
251 | ServiceFabricBackup/
252 | *.rptproj.bak
253 |
254 | # SQL Server files
255 | *.mdf
256 | *.ldf
257 | *.ndf
258 |
259 | # Business Intelligence projects
260 | *.rdl.data
261 | *.bim.layout
262 | *.bim_*.settings
263 | *.rptproj.rsuser
264 | *- [Bb]ackup.rdl
265 | *- [Bb]ackup ([0-9]).rdl
266 | *- [Bb]ackup ([0-9][0-9]).rdl
267 |
268 | # Microsoft Fakes
269 | FakesAssemblies/
270 |
271 | # GhostDoc plugin setting file
272 | *.GhostDoc.xml
273 |
274 | # Node.js Tools for Visual Studio
275 | .ntvs_analysis.dat
276 | node_modules/
277 |
278 | # Visual Studio 6 build log
279 | *.plg
280 |
281 | # Visual Studio 6 workspace options file
282 | *.opt
283 |
284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
285 | *.vbw
286 |
287 | # Visual Studio LightSwitch build output
288 | **/*.HTMLClient/GeneratedArtifacts
289 | **/*.DesktopClient/GeneratedArtifacts
290 | **/*.DesktopClient/ModelManifest.xml
291 | **/*.Server/GeneratedArtifacts
292 | **/*.Server/ModelManifest.xml
293 | _Pvt_Extensions
294 |
295 | # Paket dependency manager
296 | .paket/paket.exe
297 | paket-files/
298 |
299 | # FAKE - F# Make
300 | .fake/
301 |
302 | # CodeRush personal settings
303 | .cr/personal
304 |
305 | # Python Tools for Visual Studio (PTVS)
306 | __pycache__/
307 | *.pyc
308 |
309 | # Cake - Uncomment if you are using it
310 | # tools/**
311 | # !tools/packages.config
312 |
313 | # Tabs Studio
314 | *.tss
315 |
316 | # Telerik's JustMock configuration file
317 | *.jmconfig
318 |
319 | # BizTalk build output
320 | *.btp.cs
321 | *.btm.cs
322 | *.odx.cs
323 | *.xsd.cs
324 |
325 | # OpenCover UI analysis results
326 | OpenCover/
327 |
328 | # Azure Stream Analytics local run output
329 | ASALocalRun/
330 |
331 | # MSBuild Binary and Structured Log
332 | *.binlog
333 |
334 | # NVidia Nsight GPU debugger configuration file
335 | *.nvuser
336 |
337 | # MFractors (Xamarin productivity tool) working folder
338 | .mfractor/
339 |
340 | # Local History for Visual Studio
341 | .localhistory/
342 |
343 | # BeatPulse healthcheck temp database
344 | healthchecksdb
345 |
346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
347 | MigrationBackup/
348 |
349 | # Ionide (cross platform F# VS Code tools) working folder
350 | .ionide/
351 |
352 |
353 | # Byte-compiled / optimized / DLL files
354 | __pycache__/
355 | *.py[cod]
356 | *$py.class
357 |
358 | # C extensions
359 | *.so
360 |
361 | # Distribution / packaging
362 | .Python
363 | build/
364 | develop-eggs/
365 | dist/
366 | downloads/
367 | eggs/
368 | .eggs/
369 | lib/
370 | lib64/
371 | parts/
372 | sdist/
373 | var/
374 | wheels/
375 | share/python-wheels/
376 | *.egg-info/
377 | .installed.cfg
378 | *.egg
379 | MANIFEST
380 |
381 | # PyInstaller
382 | # Usually these files are written by a python script from a template
383 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
384 | *.manifest
385 | *.spec
386 |
387 | # Installer logs
388 | pip-log.txt
389 | pip-delete-this-directory.txt
390 |
391 | # Unit test / coverage reports
392 | htmlcov/
393 | .tox/
394 | .nox/
395 | .coverage
396 | .coverage.*
397 | .cache
398 | nosetests.xml
399 | coverage.xml
400 | *.cover
401 | *.py,cover
402 | .hypothesis/
403 | .pytest_cache/
404 | cover/
405 |
406 | # Translations
407 | *.mo
408 | *.pot
409 |
410 | # Django stuff:
411 | *.log
412 | local_settings.py
413 | db.sqlite3
414 | db.sqlite3-journal
415 |
416 | # Flask stuff:
417 | instance/
418 | .webassets-cache
419 |
420 | # Scrapy stuff:
421 | .scrapy
422 |
423 | # Sphinx documentation
424 | docs/_build/
425 |
426 | # PyBuilder
427 | .pybuilder/
428 | target/
429 |
430 | # Jupyter Notebook
431 | .ipynb_checkpoints
432 |
433 | # IPython
434 | profile_default/
435 | ipython_config.py
436 |
437 | # pyenv
438 | # For a library or package, you might want to ignore these files since the code is
439 | # intended to run in multiple environments; otherwise, check them in:
440 | # .python-version
441 |
442 | # pipenv
443 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
444 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
445 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
446 | # install all needed dependencies.
447 | #Pipfile.lock
448 |
449 | # poetry
450 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
451 | # This is especially recommended for binary packages to ensure reproducibility, and is more
452 | # commonly ignored for libraries.
453 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
454 | #poetry.lock
455 |
456 | # pdm
457 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
458 | #pdm.lock
459 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
460 | # in version control.
461 | # https://pdm.fming.dev/#use-with-ide
462 | .pdm.toml
463 |
464 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
465 | __pypackages__/
466 |
467 | # Celery stuff
468 | celerybeat-schedule
469 | celerybeat.pid
470 |
471 | # SageMath parsed files
472 | *.sage.py
473 |
474 | # Environments
475 | .env
476 | .venv
477 | env/
478 | venv/
479 | ENV/
480 | env.bak/
481 | venv.bak/
482 |
483 | # Spyder project settings
484 | .spyderproject
485 | .spyproject
486 |
487 | # Rope project settings
488 | .ropeproject
489 |
490 | # mkdocs documentation
491 | /site
492 |
493 | # mypy
494 | .mypy_cache/
495 | .dmypy.json
496 | dmypy.json
497 |
498 | # Pyre type checker
499 | .pyre/
500 |
501 | # pytype static type analyzer
502 | .pytype/
503 |
504 | # Cython debug symbols
505 | cython_debug/
506 |
507 | # PyCharm
508 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
509 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
510 | # and can be added to the global gitignore or merged into this file. For a more nuclear
511 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
512 | #.idea/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **LEX**: A Length-Extrapolatable Transformer
2 |
3 | ## Key Features
4 | - [**XPos**](https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py): An Extrapolatable Position Embedding for the Transformer decoder.
5 | - [**BCA**](https://github.com/sunyt32/torchscale/blob/main/torchscale/component/multihead_attention.py#L101): An efficient implementation of Block Causal Attention.
6 |
7 | ## Third-Party Implementation
8 | - XPos: [**Flash-Attention**](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py)
9 |
10 |
11 |
12 | # TorchScale - A Library for Transformers at (Any) Scale
13 |
14 |
15 |
16 |
17 |
18 |
19 | TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively.
20 | It implements fundamental research aimed at improving the modeling generality and capability, as well as the training stability and efficiency, of scaling up Transformers.
21 |
22 | - Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
23 | - Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal)
24 | - Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
25 | - Extrapolatability - [**LEX**](https://arxiv.org/abs/2212.10554): A Length-Extrapolatable Transformer
26 |
27 | ## News
28 |
29 | - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
30 |
31 | ## Installation
32 |
33 | To install:
34 | ```
35 | pip install torchscale
36 | ```
37 |
38 | Alternatively, you can develop it locally:
39 | ```
40 | git clone https://github.com/microsoft/torchscale.git
41 | cd torchscale
42 | pip install -e .
43 | ```
44 |
45 | ## Getting Started
46 |
47 | It takes only a few lines of code to create a model with the fundamental research features above enabled. Here is how to quickly obtain a BERT-like encoder:
48 |
49 | ```python
50 | >>> from torchscale.architecture.config import EncoderConfig
51 | >>> from torchscale.architecture.encoder import Encoder
52 |
53 | >>> config = EncoderConfig(vocab_size=64000)
54 | >>> model = Encoder(config)
55 |
56 | >>> print(model)
57 | ```
58 |
59 | We also support the `Decoder` architecture and the `EncoderDecoder` architecture:
60 |
61 | ```python
62 | # Creating a decoder model
63 | >>> from torchscale.architecture.config import DecoderConfig
64 | >>> from torchscale.architecture.decoder import Decoder
65 |
66 | >>> config = DecoderConfig(vocab_size=64000)
67 | >>> decoder = Decoder(config)
68 | >>> print(decoder)
69 |
70 | # Creating an encoder-decoder model
71 | >>> from torchscale.architecture.config import EncoderDecoderConfig
72 | >>> from torchscale.architecture.encoder_decoder import EncoderDecoder
73 |
74 | >>> config = EncoderDecoderConfig(vocab_size=64000)
75 | >>> encdec = EncoderDecoder(config)
76 | >>> print(encdec)
77 | ```
78 |
79 | ## Key Features
80 |
81 | - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
82 | * enabled by setting *deepnorm=True* in the `Config` class.
83 | * It adjusts both the residual connection and the initialization method according to the model architecture (i.e., encoder, decoder, or encoder-decoder).
84 |
85 | - [SubLN for the model generality and the training stability](https://arxiv.org/abs/2210.06423)
86 |   * enabled by *subln=True* (the default).
87 |   * It introduces an additional LayerNorm in each sublayer and adjusts the initialization according to the model architecture.
88 |   * Note that SubLN and DeepNorm cannot be used in the same model.
89 |
90 | - [X-MoE: efficient and finetunable sparse MoE modeling](https://arxiv.org/abs/2204.09179)
91 | * enabled by *use_xmoe=True*.
92 |   * It replaces a `FeedForwardNetwork` layer with an X-MoE layer once every *moe_freq* layers.
93 |
94 | - [Multiway architecture for multimodality](https://arxiv.org/abs/2208.10442)
95 | * enabled by *multiway=True*.
96 |   * It provides a pool of Transformer parameters for handling different modalities.
97 |
98 | - [Extrapolatable position embedding (XPos)](https://arxiv.org/abs/2212.10554)
99 | * enabled by *xpos_rel_pos=True*.
100 |
101 | - [Blockwise Causal Attention (BCA)](https://arxiv.org/abs/2212.10554)
102 |   * enabled by adjusting *block_size*. If *block_size=-1*, BCA is disabled.
103 |   * Setting *block_size* to the pre-training length is recommended (see the decoder sketch below).
104 |
105 | - [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184)
106 |   * We provide [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq (or other) repo.
107 |
108 | Most of the features above can be used by simply passing the corresponding parameters to the config. For example:
109 |
110 | ```python
111 | >>> from torchscale.architecture.config import EncoderConfig
112 | >>> from torchscale.architecture.encoder import Encoder
113 |
114 | >>> config = EncoderConfig(vocab_size=64000, deepnorm=True, multiway=True)
115 | >>> model = Encoder(config)
116 |
117 | >>> print(model)
118 | ```
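
A decoder with the length-extrapolation features above can be configured the same way. The sketch below is illustrative only: the field names *xpos_rel_pos* and *block_size* are taken from the feature list above, and the *block_size* value stands in for an assumed pre-training length.

```python
# Illustrative sketch: a decoder with XPos and Blockwise Causal Attention,
# using the config fields named in the feature list above (values are assumptions).
>>> from torchscale.architecture.config import DecoderConfig
>>> from torchscale.architecture.decoder import Decoder

>>> config = DecoderConfig(vocab_size=64000, xpos_rel_pos=True, block_size=1024)
>>> decoder = Decoder(config)
>>> print(decoder)
```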
119 |
120 | ## Examples
121 |
122 | We provide examples of how to use TorchScale in the following scenarios/tasks:
123 |
124 | - Language
125 |
126 | * [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining)
127 |
128 | * [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation)
129 |
130 | * [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining)
131 |
132 | - Vision
133 |
134 | * ViT/BEiT [In progress]
135 |
136 | - Speech
137 |
138 | - Multimodal
139 |
140 | * [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]
141 |
142 | We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
143 |
144 | ## Results
145 |
146 | ### Stability Evaluation
147 |
148 |
149 |
150 |
151 |
152 | With TorchScale, the training curve is smooth, while the baseline Transformer fails to converge.
153 |
154 | ### Scaling-up Experiments
155 |
156 |
157 |
158 |
159 |
160 | TorchScale supports arbitrary depths and widths, scaling up models smoothly and reliably.
161 |
162 | ## Acknowledgments
163 |
164 | Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository.
165 |
166 | ## Citations
167 |
168 | If you find this repository useful, please consider citing our work:
169 |
170 | ```
171 | @article{torchscale,
172 | author = {Shuming Ma and Hongyu Wang and Shaohan Huang and Wenhui Wang and Zewen Chi and Li Dong and Alon Benhaim and Barun Patra and Vishrav Chaudhary and Xia Song and Furu Wei},
173 | title = {{TorchScale}: {Transformers} at Scale},
174 | journal = {CoRR},
175 | volume = {abs/2211.13184},
176 | year = {2022}
177 | }
178 | ```
179 |
180 | ```
181 | @article{deepnet,
182 | author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
183 | title = {{DeepNet}: Scaling {Transformers} to 1,000 Layers},
184 | journal = {CoRR},
185 | volume = {abs/2203.00555},
186 | year = {2022},
187 | }
188 | ```
189 |
190 | ```
191 | @article{magneto,
192 | author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
193 | title = {Foundation {Transformers}},
194 | journal = {CoRR},
195 | volume = {abs/2210.06423},
196 | year = {2022}
197 | }
198 | ```
199 |
200 | ```
201 | @inproceedings{xmoe,
202 | title={On the Representation Collapse of Sparse Mixture of Experts},
203 | author={Zewen Chi and Li Dong and Shaohan Huang and Damai Dai and Shuming Ma and Barun Patra and Saksham Singhal and Payal Bajaj and Xia Song and Xian-Ling Mao and Heyan Huang and Furu Wei},
204 | booktitle={Advances in Neural Information Processing Systems},
205 | year={2022},
206 | url={https://openreview.net/forum?id=mWaYC6CZf5}
207 | }
208 | ```
209 |
210 | ## Contributing
211 |
212 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
213 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
214 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
215 |
216 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
217 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
218 | provided by the bot. You will only need to do this once across all repos using our CLA.
219 |
220 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
221 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
222 | contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
223 |
224 | ## Trademarks
225 |
226 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
227 | trademarks or logos is subject to and must follow
228 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
229 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
230 | Any use of third-party trademarks or logos is subject to those third parties' policies.
231 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/README.md:
--------------------------------------------------------------------------------
1 | # Example: Integration with FairSeq
2 |
3 | ## Setup
4 |
5 | ```bash
6 | # Install the repo as a package:
7 | git clone https://github.com/msranlp/torchscale.git
8 | cd torchscale
9 | pip install -e .
10 | pip install git+https://github.com/shumingma/fairseq.git@moe
11 | pip install git+https://github.com/shumingma/infinibatch.git
12 | pip install iopath
13 | pip install --upgrade numpy
14 | ```
15 |
16 | ## Example: BERT Pretraining
17 |
18 | ### Data Format
19 |
20 | We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on the fly from disk. It requires the data to be sharded into multiple small files (e.g. 10K lines per file), as well as a JSON file that contains some metadata and the paths to these files.
21 |
22 | The overall data directory should be organized as follows:
23 | ```
24 | Data/
25 | ├── json/
26 | │ ├── train.json
27 | │ └── valid.json
28 | ├── shard/
29 | │ ├── train/
30 | │ │ ├── 00000.txt
31 | │ │ ├── 00001.txt
32 | │ │ └── ...
33 | │ └── valid/
34 | │ ├── 00000.txt
35 | │ ├── 00001.txt
36 | │ └── ...
37 | ├── dict.txt
38 | └── sentencepiece.bpe.model
39 | ```
40 |
41 | We recommend that each sharded data file contain no more than 10K lines, with one sentence per line; two documents should be separated by an empty line, as in the example below (a short sharding sketch follows it).
42 | ```
43 | Document 1 Line 1
44 | Document 1 Line 2
45 | Document 1 Line 3
46 |
47 | Document 2 Line 1
48 | Document 2 Line 2
49 |
50 | ...
51 | ```
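
If you need to shard a raw corpus yourself, a script along these lines works. This is a minimal sketch under stated assumptions: a single input file `corpus.txt` with one sentence per line and an empty line between documents; the paths and the 10K-line limit are illustrative.

```python
# Minimal sketch: split corpus.txt into shards of roughly 10K lines each,
# cutting only at document boundaries (empty lines).
import os

MAX_LINES = 10000
os.makedirs("Data/shard/train", exist_ok=True)

def dump(shard_id, lines):
    path = f"Data/shard/train/{shard_id:05d}.txt"
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(lines)

shard_id, lines = 0, []
with open("corpus.txt", encoding="utf-8") as f:
    for line in f:
        lines.append(line)
        # only start a new shard at an empty line (document boundary)
        if len(lines) >= MAX_LINES and line.strip() == "":
            dump(shard_id, lines)
            shard_id, lines = shard_id + 1, []

if lines:
    dump(shard_id, lines)
```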
52 |
53 | The JSON file should be in the following format; a short script for generating it follows the example:
54 | ```
55 | [
56 | {
57 | "source": [
58 | "shard/train/00000.txt",
59 | "shard/train/00001.txt",
60 | ...
61 | ],
62 | "source_lang": "en",
63 | "weight": 1.0
64 | }
65 | ]
66 | ```
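
The JSON metadata file can be generated with a script like the one below. It is a minimal sketch assuming the `Data/` layout shown above; the output path and the `weight` value are illustrative.

```python
# Minimal sketch: build Data/json/train.json from the shard layout shown above.
import glob
import json
import os

data_dir = "Data"
shards = sorted(glob.glob(os.path.join(data_dir, "shard/train/*.txt")))

meta = [
    {
        # shard paths are stored relative to the data directory
        "source": [os.path.relpath(p, data_dir) for p in shards],
        "source_lang": "en",
        "weight": 1.0,
    }
]

os.makedirs(os.path.join(data_dir, "json"), exist_ok=True)
with open(os.path.join(data_dir, "json", "train.json"), "w") as f:
    json.dump(meta, f, indent=2)
```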
67 |
68 | You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt). Note that this vocabulary is English-only, with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the SentencePiece model and the `sentencepiece` library installed, you can extract the `dict.txt` file from it with:
69 | ```
70 | spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
71 | ```
72 |
73 | ### Training Command
74 | ```bash
75 | cd examples/fairseq/
76 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
77 | --task pretraining \
78 | --tokens-per-sample 512 \
79 | --mask-prob 0.15 \
80 | --span-length 3.0 \
81 | --leave-unmasked-prob 0.0 \
82 | --random-token-prob 0.0 \
83 | --criterion masked_lm \
84 | --arch mlm_base \
85 | --share-encoder-input-output-embed \
86 | --required-batch-size-multiple 8 \
87 | --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
88 | --dict-file ${PATH_TO_DATA}/dict.txt \
89 | --optimizer adam \
90 | --adam-betas '(0.9,0.98)' \
91 | --adam-eps 1e-6 \
92 | --clip-norm 2.0 \
93 | --lr-scheduler polynomial_decay \
94 | --lr 0.0005 \
95 | --warmup-updates 10000 \
96 | --total-num-update 125000 \
97 | --max-update 125000 \
98 | --max-sentences 32 \
99 | --update-freq 1 \
100 | --log-format simple \
101 | --log-interval 100 \
102 | --disable-validation \
103 | --save-interval-updates 5000 \
104 | --no-epoch-checkpoints \
105 | --fp16 \
106 | --fp16-init-scale 4 \
107 | --fp16-scale-window 256 \
108 | --min-loss-scale 0.0001 \
109 | --seed 1 \
110 | --save-dir ${PATH_TO_CKPT} \
111 | --ddp-backend=no_c10d \
112 | --distributed-no-spawn \
113 | --reset-dataloader \
114 | --batch-read-ahead 10000 \
115 | --rel-pos-buckets 32 \
116 | --max-rel-pos 128 \
117 | --deepnorm
118 | ```
119 |
120 | ## Example: GPT Pretraining
121 |
122 | ### Data Format
123 |
124 | We use the same data format as FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data).
125 |
126 | ### Dense Model
127 |
128 | ```bash
129 | cd examples/fairseq/
130 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
131 | ${PATH_TO_DATA} \
132 | --num-workers 2 \
133 | --activation-fn gelu \
134 | --share-decoder-input-output-embed \
135 | --validate-interval-updates 1000 \
136 | --save-interval-updates 1000 \
137 | --no-epoch-checkpoints \
138 | --memory-efficient-fp16 \
139 | --fp16-init-scale 4 \
140 | --arch lm_base \
141 | --task language_modeling \
142 | --sample-break-mode none \
143 | --tokens-per-sample 128 \
144 | --optimizer adam --adam-betas "(0.9, 0.98)" \
145 | --adam-eps 1e-08 \
146 | --clip-norm 0.0 \
147 | --lr 5e-4 \
148 | --lr-scheduler polynomial_decay \
149 | --warmup-updates 750 \
150 | --dropout 0.1 \
151 | --attention-dropout 0.1 \
152 | --weight-decay 0.01 \
153 | --batch-size 4 \
154 | --update-freq 1 \
155 | --required-batch-size-multiple 1 \
156 | --total-num-update 50000 \
157 | --max-update 50000 \
158 | --seed 1 \
159 | --ddp-backend=c10d
160 | ```
161 |
162 | ### Sparse (MoE) Model
163 |
164 | ```bash
165 | cd examples/fairseq/
166 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
167 | ${PATH_TO_DATA} \
168 | --num-workers 2 \
169 | --activation-fn gelu \
170 | --share-decoder-input-output-embed \
171 | --validate-interval-updates 1000 \
172 | --save-interval-updates 1000 \
173 | --no-epoch-checkpoints \
174 | --memory-efficient-fp16 \
175 | --fp16-init-scale 4 \
176 | --arch lm_base \
177 | --task language_modeling \
178 | --sample-break-mode none \
179 | --tokens-per-sample 128 \
180 | --optimizer adam --adam-betas "(0.9, 0.98)" \
181 | --adam-eps 1e-08 \
182 | --clip-norm 0.0 \
183 | --lr 5e-4 \
184 | --lr-scheduler polynomial_decay \
185 | --warmup-updates 750 \
186 | --dropout 0.1 \
187 | --attention-dropout 0.1 \
188 | --weight-decay 0.01 \
189 | --batch-size 4 \
190 | --update-freq 1 \
191 | --required-batch-size-multiple 1 \
192 | --total-num-update 50000 \
193 | --max-update 50000 \
194 | --seed 1 \
195 | --ddp-backend=no_c10d \
196 | --moe-expert-count 2 --moe-freq 2 \
197 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
198 | --moe-eval-capacity-token-fraction -1.0 \
199 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
200 | --use-xmoe
201 | ```
202 |
203 | ## Example: Machine Translation
204 |
205 | ### Data Format
206 |
207 | We follow FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data.
208 |
209 | ### Dense Model
210 |
211 | ```bash
212 | cd examples/fairseq/
213 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
214 | ${PATH_TO_DATA} \
215 | --arch mt_base --share-decoder-input-output-embed \
216 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
217 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
218 | --dropout 0.3 --weight-decay 0.0001 \
219 | --max-tokens 4096 --fp16
220 | ```
221 |
222 | ### Sparse (MoE) Model
223 |
224 | ```bash
225 | cd examples/fairseq/
226 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
227 | ${PATH_TO_DATA} \
228 | --arch mt_base --share-decoder-input-output-embed \
229 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
230 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
231 | --dropout 0.3 --weight-decay 0.0001 \
232 | --moe-expert-count 2 --moe-freq 2 \
233 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
234 | --moe-eval-capacity-token-fraction -1.0 \
235 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
236 | --use-xmoe \
237 | --max-tokens 4096 --fp16
238 | ```
239 |
--------------------------------------------------------------------------------
/examples/fairseq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # flake8: noqa
5 | import models
6 | import tasks
7 | from fairseq_cli.generate import cli_main
8 |
9 | if __name__ == "__main__":
10 | cli_main()
11 |
--------------------------------------------------------------------------------
/examples/fairseq/interactive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # flake8: noqa
5 | import models
6 | import tasks
7 | from fairseq_cli.interactive import cli_main
8 |
9 | if __name__ == "__main__":
10 | cli_main()
11 |
--------------------------------------------------------------------------------
/examples/fairseq/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import argparse
5 | import importlib
6 | import os
7 |
8 | MODEL_REGISTRY = {}
9 | MODEL_DATACLASS_REGISTRY = {}
10 | ARCH_MODEL_REGISTRY = {}
11 | ARCH_MODEL_NAME_REGISTRY = {}
12 | ARCH_MODEL_INV_REGISTRY = {}
13 | ARCH_CONFIG_REGISTRY = {}
14 |
15 | # automatically import any Python files in the models/ directory
16 | models_dir = os.path.dirname(__file__)
17 | for file in os.listdir(models_dir):
18 | path = os.path.join(models_dir, file)
19 | if (
20 | not file.startswith("_")
21 | and not file.startswith(".")
22 | and (file.endswith(".py") or os.path.isdir(path))
23 | ):
24 | model_name = file[: file.find(".py")] if file.endswith(".py") else file
25 | module = importlib.import_module("models." + model_name)
26 |
27 | # extra `model_parser` for sphinx
28 | if model_name in MODEL_REGISTRY:
29 | parser = argparse.ArgumentParser(add_help=False)
30 | group_archs = parser.add_argument_group("Named architectures")
31 | group_archs.add_argument(
32 | "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
33 | )
34 | group_args = parser.add_argument_group("Additional command-line arguments")
35 | MODEL_REGISTRY[model_name].add_args(group_args)
36 | globals()[model_name + "_parser"] = parser
37 |
--------------------------------------------------------------------------------
/examples/fairseq/models/bert.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import logging
5 | from dataclasses import dataclass, field
6 | from typing import Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from apex.normalization import FusedLayerNorm as LayerNorm
12 | from fairseq import utils
13 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass
14 | from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
15 | from fairseq.models.squad import SQuADHead
16 | from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
17 | from fairseq.modules import PositionalEmbedding
18 | from omegaconf import II
19 |
20 | from torchscale.architecture.config import EncoderConfig
21 |
22 | from .machine_translation import MTEncoder as Encoder
23 |
24 | DEFAULT_MAX_SOURCE_POSITIONS = 1024
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | @dataclass
30 | class BertConfig(FairseqDataclass):
31 | activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
32 | default="relu", metadata={"help": "activation function to use"}
33 | )
34 | dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
35 | attention_dropout: float = field(
36 | default=0.0, metadata={"help": "dropout probability for attention weights"}
37 | )
38 | activation_dropout: float = field(
39 | default=0.0, metadata={"help": "dropout probability after activation in FFN."}
40 | )
41 | encoder_embed_dim: int = field(
42 | default=512, metadata={"help": "encoder embedding dimension"}
43 | )
44 | encoder_output_dim: int = field(
45 | default=512, metadata={"help": "encoder output dimension"}
46 | )
47 | encoder_input_dim: int = field(
48 | default=512, metadata={"help": "encoder input dimension"}
49 | )
50 | encoder_ffn_embed_dim: int = field(
51 | default=2048, metadata={"help": "encoder embedding dimension for FFN"}
52 | )
53 | encoder_layers: int = field(default=6, metadata={"help": "num encoder layers"})
54 | encoder_attention_heads: int = field(
55 | default=8, metadata={"help": "num encoder attention heads"}
56 | )
57 | encoder_normalize_before: bool = field(
58 | default=False, metadata={"help": "apply layernorm before each encoder block"}
59 | )
60 | no_encoder_final_norm: bool = field(
61 | default=False,
62 | metadata={"help": "don't add an extra layernorm after the last encoder block"},
63 | )
64 | no_token_positional_embeddings: bool = field(
65 | default=False,
66 | metadata={
67 | "help": "if set, disables positional embeddings (outside self attention)"
68 | },
69 | )
70 | share_encoder_input_output_embed: bool = field(
71 | default=False, metadata={"help": "share encoder input and output embeddings"}
72 | )
73 | encoder_learned_pos: bool = field(
74 | default=False,
75 | metadata={"help": "use learned positional embeddings in the encoder"},
76 | )
77 | layernorm_embedding: bool = field(
78 | default=False, metadata={"help": "add layernorm to embedding"}
79 | )
80 | no_scale_embedding: bool = field(
81 | default=False, metadata={"help": "if True, dont scale embeddings"}
82 | )
83 | checkpoint_activations: bool = field(
84 | default=False, metadata={"help": "checkpoint activations at each layer"}
85 | )
86 | offload_activations: bool = field(
87 | default=False,
88 | metadata={"help": "move checkpointed activations to CPU after they are used."},
89 | )
90 | # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
91 | encoder_layerdrop: float = field(
92 | default=0.0, metadata={"help": "LayerDrop probability for encoder"}
93 | )
94 | encoder_layers_to_keep: Optional[str] = field(
95 | default=None,
96 | metadata={
97 | "help": "which layers to *keep* when pruning as a comma-separated list"
98 | },
99 | )
100 | # config for Fully Sharded Data Parallel (FSDP) training
101 | min_params_to_wrap: int = field(
102 | default=DEFAULT_MIN_PARAMS_TO_WRAP,
103 | metadata={
104 | "help": (
105 | "minimum number of params for a layer to be wrapped with FSDP() when "
106 | "training with --ddp-backend=fully_sharded. Smaller values will "
107 | "improve memory efficiency, but may make torch.distributed "
108 | "communication less efficient due to smaller input sizes. This option "
109 | "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
110 | "--offload-activations are passed."
111 | )
112 | },
113 | )
114 | max_source_positions: int = field(
115 | default=1024, metadata={"help": "max source positions"}
116 | )
117 | pooler_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
118 | default="relu", metadata={"help": "activation function to use for pooler layer"}
119 | )
120 | pooler_dropout: float = field(
121 | default=0.0,
122 | metadata={"help": "dropout probability in the masked_lm pooler layers"},
123 | )
124 | # options from other parts of the config
125 | # add_bos_token: bool = II("task.add_bos_token")
126 | # tokens_per_sample: int = II("task.tokens_per_sample")
127 | tpu: bool = II("common.tpu")
128 | rel_pos_buckets: int = field(default=0, metadata={"help": ""})
129 | max_rel_pos: int = field(default=0, metadata={"help": ""})
130 | moe_freq: int = field(
131 | default=0,
132 | metadata={"help": "Frequency at which we insert MoE Transformer layers"},
133 | )
134 | moe_expert_count: int = field(
135 | default=0, metadata={"help": "Number of experts in each MoE Layer"}
136 | )
137 | moe_gating_use_fp32: bool = field(
138 | default=False,
139 | metadata={"help": "Use FP32 computations in MoE top2 gating function"},
140 | )
141 | moe_second_expert_policy: str = field(
142 | default="sampling",
143 | metadata={"help": "policy for second expert, options: all/sampling/random"},
144 | )
145 | moe_normalize_gate_prob_before_dropping: bool = field(
146 | default=False,
147 | metadata={
148 | "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
149 | },
150 | )
151 | moe_expert_ffn_dim: Optional[int] = field(
152 | default=None, metadata={"help": "MoE expert FFN dimension"}
153 | )
154 | moe_top1_expert: Optional[bool] = field(
155 | default=False, metadata={"help": "Use top1 gate instead of top2"}
156 | )
157 | moe_eval_capacity_token_fraction: Optional[float] = field(
158 | default=0.25,
159 | metadata={
160 | "help": (
161 | "Default: 0.25, Fraction of tokens as capacity during validation, "
162 | "if set to negative, use same as training. range: (0.0, 1.0]."
163 | )
164 | },
165 | )
166 | moe_normalize_expert_grad: Optional[str] = field(
167 | default="world_size",
168 | metadata={
169 | "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
170 | },
171 | )
172 | record_a2a_perf_stats: Optional[bool] = field(
173 | default=False,
174 | metadata={"help": "records all to all perf stats during distributed training"},
175 | )
176 | dummy_a2a: Optional[bool] = field(
177 | default=False,
178 | metadata={
179 | "help": "By passes all to all during distributed training by returning the input buffer as output"
180 | },
181 | )
182 | moe_batch_prioritized_routing: Optional[bool] = field(
183 | default=False,
184 | metadata={
185 | "help": "if true orders token by the gate prob before capacity dropping."
186 | },
187 | )
188 | ddp_rank: int = II("distributed_training.distributed_rank")
189 | deepnorm: Optional[bool] = field(
190 | default=False,
191 | )
192 | subln: Optional[bool] = field(
193 | default=False,
194 | )
195 |
196 |
197 | @register_model("mlm", dataclass=BertConfig)
198 | class BertModel(BaseFairseqModel):
199 | def __init__(self, args, encoder):
200 | super().__init__()
201 | self.args = args
202 | self.encoder = encoder
203 | self.padding_idx = self.encoder.embed_tokens.padding_idx
204 | self.classification_heads = nn.ModuleDict()
205 |
206 | @classmethod
207 | def build_model(cls, args, task):
208 | """Build a new model instance."""
209 |
210 | args.max_source_positions = getattr(
211 | args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
212 | )
213 |
214 | embed_tokens = cls.build_embedding(
215 | args, task.dictionary, args.encoder_embed_dim
216 | )
217 |
218 | embed_positions = (
219 | PositionalEmbedding(
220 | args.max_source_positions,
221 | args.encoder_embed_dim,
222 | task.dictionary.pad(),
223 | learned=args.encoder_learned_pos,
224 | )
225 | if not args.no_token_positional_embeddings
226 | else None
227 | )
228 |
229 | lm_head = cls.build_lm_head(
230 | args,
231 | args.encoder_embed_dim,
232 | len(task.dictionary),
233 | args.activation_fn,
234 | weight=embed_tokens.weight,
235 | )
236 |
237 | config = EncoderConfig()
238 | config.override(args)
239 |
240 | encoder = Encoder(
241 | config,
242 | embed_tokens=embed_tokens,
243 | embed_positions=embed_positions,
244 | output_projection=lm_head,
245 | is_encoder_decoder=False,
246 | dictionary=task.dictionary,
247 | )
248 |
249 | return cls(args, encoder)
250 |
251 | @classmethod
252 | def build_embedding(cls, args, dictionary, embed_dim, path=None):
253 | embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
254 | return embed_tokens
255 |
256 | @classmethod
257 | def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
258 | return LMHead(embed_dim, output_dim, activation_fn, weight)
259 |
260 | def output_layer(self, features, masked_tokens=None):
261 | return self.encoder.output_projection(features, masked_tokens=masked_tokens)
262 |
263 | def register_classification_head(
264 | self, name, num_classes=None, inner_dim=None, **kwargs
265 | ):
266 | """Register a classification head."""
267 | if name in self.classification_heads:
268 | prev_num_classes = self.classification_heads[name].out_proj.out_features
269 | prev_inner_dim = self.classification_heads[name].dense.out_features
270 | if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
271 | logger.warning(
272 | 're-registering head "{}" with num_classes {} (prev: {}) '
273 | "and inner_dim {} (prev: {})".format(
274 | name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
275 | )
276 | )
277 | self.classification_heads[name] = ClassificationHead(
278 | self.args.encoder_embed_dim,
279 | inner_dim or self.args.encoder_embed_dim,
280 | num_classes,
281 | self.args.pooler_activation_fn,
282 | self.args.pooler_dropout,
283 | )
284 |
285 | def register_question_answering_head(self, name, num_classes=None):
286 | self.classification_heads[name] = SQuADHead(
287 | self.args.encoder_embed_dim,
288 | )
289 |
290 | def upgrade_state_dict_named(self, state_dict, name):
291 | prefix = name + "." if name != "" else ""
292 |
293 | # upgrade children modules
294 | super().upgrade_state_dict_named(state_dict, name)
295 |
296 | # Handle new classification heads present in the state dict.
297 | current_head_names = (
298 | []
299 | if not hasattr(self, "classification_heads")
300 | else self.classification_heads.keys()
301 | )
302 | keys_to_delete = []
303 | for k in state_dict.keys():
304 | if not k.startswith(prefix + "classification_heads."):
305 | continue
306 |
307 | head_name = k[len(prefix + "classification_heads.") :].split(".")[0] # noqa: E203
308 | num_classes = state_dict[
309 | prefix + "classification_heads." + head_name + ".out_proj.weight"
310 | ].size(0)
311 | inner_dim = state_dict[
312 | prefix + "classification_heads." + head_name + ".dense.weight"
313 | ].size(0)
314 |
315 | if getattr(self.args, "load_checkpoint_heads", False):
316 | if head_name not in current_head_names:
317 | self.register_classification_head(head_name, num_classes, inner_dim)
318 | else:
319 | if head_name not in current_head_names:
320 | logger.warning(
321 | "deleting classification head ({}) from checkpoint "
322 | "not present in current model: {}".format(head_name, k)
323 | )
324 | keys_to_delete.append(k)
325 | elif (
326 | num_classes
327 | != self.classification_heads[head_name].out_proj.out_features
328 | or inner_dim
329 | != self.classification_heads[head_name].dense.out_features
330 | ):
331 | logger.warning(
332 | "deleting classification head ({}) from checkpoint "
333 | "with different dimensions than current model: {}".format(
334 | head_name, k
335 | )
336 | )
337 | keys_to_delete.append(k)
338 | for k in keys_to_delete:
339 | del state_dict[k]
340 |
341 | # Copy any newly-added classification heads into the state dict
342 | # with their current weights.
343 | if hasattr(self, "classification_heads"):
344 | cur_state = self.classification_heads.state_dict()
345 | for k, v in cur_state.items():
346 | if prefix + "classification_heads." + k not in state_dict:
347 | logger.info("Overwriting " + prefix + "classification_heads." + k)
348 | state_dict[prefix + "classification_heads." + k] = v
349 |
350 | def forward(
351 | self,
352 | src_tokens=None,
353 | features_only=False,
354 | return_all_hiddens=False,
355 | classification_head_name=None,
356 | masked_tokens=None,
357 | **kwargs
358 | ):
359 | encoder_out = self.encoder(
360 | src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
361 | )
362 | x, extra = encoder_out["encoder_out"], encoder_out
363 | x = x.transpose(0, 1)
364 |
365 | if classification_head_name is not None:
366 | x = self.classification_heads[classification_head_name](x)
367 | elif not features_only:
368 | x = self.output_layer(x, masked_tokens=masked_tokens)
369 |
370 | return x, extra
371 |
372 |
373 | class ClassificationHead(nn.Module):
374 | """Head for sentence-level classification tasks."""
375 |
376 | def __init__(
377 | self,
378 | input_dim,
379 | inner_dim,
380 | num_classes,
381 | activation_fn,
382 | pooler_dropout,
383 | ):
384 | super().__init__()
385 | self.dense = nn.Linear(input_dim, inner_dim)
386 | self.activation_fn = utils.get_activation_fn(activation_fn)
387 | self.dropout = nn.Dropout(p=pooler_dropout)
388 | self.out_proj = nn.Linear(inner_dim, num_classes)
389 |
390 | def forward(self, features, **kwargs):
391 |         x = features[:, 0, :]  # take the <s> token (equiv. to [CLS])
392 | x = self.dropout(x)
393 | x = self.dense(x)
394 | x = self.activation_fn(x.float()).type_as(x)
395 | x = self.dropout(x)
396 | x = self.out_proj(x)
397 | return x
398 |
399 |
400 | class LMHead(nn.Module):
401 | """Head for masked language modeling."""
402 |
403 | def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
404 | super().__init__()
405 | self.dense = nn.Linear(embed_dim, embed_dim)
406 | self.activation_fn = utils.get_activation_fn(activation_fn)
407 | self.layer_norm = LayerNorm(embed_dim)
408 |
409 | if weight is None:
410 | weight = nn.Linear(embed_dim, output_dim, bias=False).weight
411 | self.weight = weight
412 | self.bias = nn.Parameter(torch.zeros(output_dim))
413 |
414 | def forward(self, features, masked_tokens=None, **kwargs):
415 | # Only project the masked tokens while training,
416 | # saves both memory and computation
417 | if masked_tokens is not None:
418 | features = features[masked_tokens, :]
419 |
420 | x = self.dense(features)
421 | x = self.activation_fn(x.float()).type_as(x)
422 | x = self.layer_norm(x)
423 | # project back to size of vocabulary with bias
424 | x = F.linear(x, self.weight) + self.bias
425 | return x
426 |
427 |
428 | @register_model_architecture("mlm", "mlm_base")
429 | def base_unilm_architecture(args):
430 | if hasattr(args, "encoder_final_norm"):
431 | args.no_encoder_final_norm = not args.encoder_final_norm
432 |
433 | args.dropout = getattr(args, "dropout", 0.1)
434 | args.attention_dropout = getattr(args, "attention_dropout", 0.0)
435 | args.activation_dropout = getattr(args, "activation_dropout", 0.0)
436 | args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
437 |
438 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
439 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
440 | args.encoder_layers = getattr(args, "encoder_layers", 12)
441 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
442 | args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
443 | args.activation_fn = getattr(args, "activation_fn", "gelu")
444 | args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
445 |
446 | args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
447 | args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
448 |
449 | # args.add_bos_token = getattr(args, "add_bos_token", False)
450 | args.no_token_positional_embeddings = getattr(
451 | args, "no_token_positional_embeddings", False
452 | )
453 | args.share_encoder_input_output_embed = getattr(
454 | args, "share_encoder_input_output_embed", True
455 | )
456 | args.encoder_output_dim = getattr(
457 | args, "encoder_output_dim", args.encoder_embed_dim
458 | )
459 | args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
460 |
461 | # Model training is not stable without this
462 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
463 | args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
464 |
465 | args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
466 | args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
467 | args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
468 | args.offload_activations = getattr(args, "offload_activations", False)
469 | if args.offload_activations:
470 | args.checkpoint_activations = True
471 |
--------------------------------------------------------------------------------
/examples/fairseq/models/language_modeling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates.
5 | #
6 | # This source code is licensed under the MIT license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import logging
10 | from dataclasses import dataclass, field
11 | from typing import Optional
12 |
13 | import torch
14 | from fairseq import distributed_utils, utils
15 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass
16 | from fairseq.models import (
17 | FairseqIncrementalDecoder,
18 | FairseqLanguageModel,
19 | register_model,
20 | register_model_architecture,
21 | )
22 | from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
23 | from fairseq.modules import PositionalEmbedding
24 | from omegaconf import II
25 |
26 | from torchscale.architecture.config import DecoderConfig
27 | from torchscale.architecture.decoder import Decoder
28 |
29 | DEFAULT_MAX_TARGET_POSITIONS = 4096
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | @dataclass
34 | class LanguageConfig(FairseqDataclass):
35 | activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
36 | default="relu", metadata={"help": "activation function to use"}
37 | )
38 | dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
39 | attention_dropout: float = field(
40 | default=0.0, metadata={"help": "dropout probability for attention weights"}
41 | )
42 | activation_dropout: float = field(
43 | default=0.0, metadata={"help": "dropout probability after activation in FFN."}
44 | )
45 | relu_dropout: float = field(
46 | default=0.0, metadata={"help": "dropout probability after activation in FFN."}
47 | )
48 | decoder_embed_dim: int = field(
49 | default=512, metadata={"help": "decoder embedding dimension"}
50 | )
51 | decoder_output_dim: int = field(
52 | default=512, metadata={"help": "decoder output dimension"}
53 | )
54 | decoder_input_dim: int = field(
55 | default=512, metadata={"help": "decoder input dimension"}
56 | )
57 | decoder_ffn_embed_dim: int = field(
58 | default=2048, metadata={"help": "decoder embedding dimension for FFN"}
59 | )
60 | decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
61 | decoder_attention_heads: int = field(
62 | default=8, metadata={"help": "num decoder attention heads"}
63 | )
64 | decoder_normalize_before: bool = field(
65 | default=False, metadata={"help": "apply layernorm before each decoder block"}
66 | )
67 | no_token_positional_embeddings: bool = field(
68 | default=False,
69 | metadata={
70 | "help": "if set, disables positional embeddings (outside self attention)"
71 | },
72 | )
73 | share_decoder_input_output_embed: bool = field(
74 | default=False, metadata={"help": "share decoder input and output embeddings"}
75 | )
76 | decoder_learned_pos: bool = field(
77 | default=False,
78 | metadata={"help": "use learned positional embeddings in the decoder"},
79 | )
80 | layernorm_embedding: bool = field(
81 | default=False, metadata={"help": "add layernorm to embedding"}
82 | )
83 | no_scale_embedding: bool = field(
84 | default=False, metadata={"help": "if True, dont scale embeddings"}
85 | )
86 | checkpoint_activations: bool = field(
87 | default=False, metadata={"help": "checkpoint activations at each layer"}
88 | )
89 | offload_activations: bool = field(
90 | default=False,
91 | metadata={"help": "move checkpointed activations to CPU after they are used."},
92 | )
93 | # config for Fully Sharded Data Parallel (FSDP) training
94 | min_params_to_wrap: int = field(
95 | default=DEFAULT_MIN_PARAMS_TO_WRAP,
96 | metadata={
97 | "help": (
98 | "minimum number of params for a layer to be wrapped with FSDP() when "
99 | "training with --ddp-backend=fully_sharded. Smaller values will "
100 | "improve memory efficiency, but may make torch.distributed "
101 | "communication less efficient due to smaller input sizes. This option "
102 | "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
103 | "--offload-activations are passed."
104 | )
105 | },
106 | )
107 | moe_freq: int = field(
108 | default=0,
109 | metadata={"help": "Frequency at which we insert MoE Transformer layers"},
110 | )
111 | moe_expert_count: int = field(
112 | default=0, metadata={"help": "Number of experts in each MoE Layer"}
113 | )
114 | moe_gating_use_fp32: bool = field(
115 | default=False,
116 | metadata={"help": "Use FP32 computations in MoE top2 gating function"},
117 | )
118 | moe_second_expert_policy: str = field(
119 | default="sampling",
120 | metadata={"help": "policy for second expert, options: all/sampling/random"},
121 | )
122 | moe_normalize_gate_prob_before_dropping: bool = field(
123 | default=False,
124 | metadata={
125 | "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
126 | },
127 | )
128 | moe_expert_ffn_dim: Optional[int] = field(
129 | default=None, metadata={"help": "MoE expert FFN dimension"}
130 | )
131 | moe_top1_expert: Optional[bool] = field(
132 | default=False, metadata={"help": "Use top1 gate instead of top2"}
133 | )
134 | moe_eval_capacity_token_fraction: Optional[float] = field(
135 | default=0.25,
136 | metadata={
137 | "help": (
138 | "Default: 0.25, Fraction of tokens as capacity during validation, "
139 | "if set to negative, use same as training. range: (0.0, 1.0]."
140 | )
141 | },
142 | )
143 | moe_normalize_expert_grad: Optional[str] = field(
144 | default="world_size",
145 | metadata={
146 | "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
147 | },
148 | )
149 | record_a2a_perf_stats: Optional[bool] = field(
150 | default=False,
151 | metadata={"help": "records all to all perf stats during distributed training"},
152 | )
153 | dummy_a2a: Optional[bool] = field(
154 | default=False,
155 | metadata={
156 |             "help": "Bypasses all-to-all during distributed training by returning the input buffer as output"
157 | },
158 | )
159 | moe_batch_prioritized_routing: Optional[bool] = field(
160 | default=False,
161 | metadata={
162 |             "help": "if true, orders tokens by the gate prob before capacity dropping."
163 | },
164 | )
165 | use_xmoe: Optional[bool] = field(
166 | default=False,
167 | )
168 |
169 | # options from other parts of the config
170 | add_bos_token: bool = II("task.add_bos_token")
171 | tokens_per_sample: int = II("task.tokens_per_sample")
172 | max_target_positions: Optional[int] = II("task.max_target_positions")
173 | tpu: bool = II("common.tpu")
174 | memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
175 | fp16: bool = II("common.fp16")
176 | fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
177 | ddp_backend: str = II("distributed_training.ddp_backend")
178 | world_size: int = II("distributed_training.distributed_world_size")
179 | distributed_rank: int = II("distributed_training.distributed_rank")
180 | ddp_rank: int = II("distributed_training.distributed_rank")
181 | deepnorm: Optional[bool] = field(
182 | default=False,
183 | )
184 | subln: Optional[bool] = field(
185 | default=False,
186 | )
187 | xpos_rel_pos: Optional[bool] = field(
188 | default=False,
189 |         metadata={"help": "use XPos as the relative position embedding"},
190 | )
191 | block_size: Optional[int] = field(
192 | default=2048,
193 | )
194 | rel_pos_buckets: Optional[int] = field(
195 | default=0,
196 | )
197 | max_rel_pos: Optional[int] = field(
198 | default=0,
199 | )
200 |
201 |
202 | @register_model("lm", dataclass=LanguageConfig)
203 | class LanguageModel(FairseqLanguageModel):
204 | def __init__(self, args, decoder):
205 | self.args = args
206 | super().__init__(decoder)
207 |
208 | @classmethod
209 | def build_model(cls, args, task):
210 |
211 | if getattr(args, "max_target_positions", None) is None:
212 | args.max_target_positions = getattr(
213 | args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
214 | )
215 |
216 | embed_tokens = cls.build_embedding(
217 | args, task.source_dictionary, args.decoder_embed_dim
218 | )
219 |
220 | embed_positions = (
221 | PositionalEmbedding(
222 | args.max_target_positions,
223 | args.decoder_embed_dim,
224 | task.dictionary.pad(),
225 | learned=args.decoder_learned_pos,
226 | )
227 | if not args.no_token_positional_embeddings
228 | else None
229 | )
230 |
231 | if args.share_decoder_input_output_embed:
232 | output_projection = torch.nn.Linear(
233 | embed_tokens.weight.shape[1],
234 | embed_tokens.weight.shape[0],
235 | bias=False,
236 | )
237 | output_projection.weight = embed_tokens.weight
238 | else:
239 | output_projection = torch.nn.Linear(
240 | args.decoder_embed_dim, len(task.dictionary), bias=False
241 | )
242 | torch.nn.init.normal_(
243 | output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
244 | )
245 |
246 | if getattr(args, "moe_freq", 0) > 0 and (
247 | getattr(args, "fp16", False)
248 | and not getattr(args, "memory_efficient_fp16", False)
249 | and getattr(args, "ddp_backend", None) != "fully_sharded"
250 | ):
251 | assert (
252 | args.fp16_no_flatten_grads
253 | ), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
254 |
255 | args.ddp_rank = distributed_utils.get_data_parallel_rank()
256 |
257 | config = DecoderConfig()
258 | config.override(args)
259 |
260 | decoder = LMDecoder(
261 | config,
262 | embed_tokens,
263 | embed_positions,
264 | output_projection,
265 | is_encoder_decoder=False,
266 | dictionary=task.dictionary,
267 | )
268 |
269 | return cls(args, decoder)
270 |
271 | @classmethod
272 | def build_embedding(cls, args, dictionary, embed_dim, path=None):
273 | return Embedding(len(dictionary), embed_dim, dictionary.pad())
274 |
275 |
276 | class LMDecoder(Decoder, FairseqIncrementalDecoder):
277 | def forward(self, src_tokens, **kwargs):
278 | self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
279 | return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
280 |
281 | def max_positions(self):
282 | if self.embed_positions is not None:
283 | return self.embed_positions.max_positions
284 | else:
285 | return DEFAULT_MAX_TARGET_POSITIONS
286 |
287 | def reorder_incremental_state_scripting(
288 | self,
289 | incremental_state,
290 | new_order,
291 | ):
292 | for module in incremental_state:
293 | for key in incremental_state[module]:
294 | result = incremental_state[module][key].index_select(0, new_order)
295 | incremental_state[module][key] = result
296 |
297 |
298 | @register_model_architecture("lm", "lm_base")
299 | def base_lm_architecture(args):
300 | # backward compatibility for older model checkpoints
301 | if hasattr(args, "no_tie_adaptive_proj"):
302 | # previous models defined --no-tie-adaptive-proj, so use the existence of
303 | # that option to determine if this is an "old" model checkpoint
304 | args.no_decoder_final_norm = True # old models always set this to True
305 | if args.no_tie_adaptive_proj is False:
306 | args.tie_adaptive_proj = True
307 | if hasattr(args, "decoder_final_norm"):
308 | args.no_decoder_final_norm = not args.decoder_final_norm
309 |
310 | args.dropout = getattr(args, "dropout", 0.1)
311 | args.attention_dropout = getattr(args, "attention_dropout", 0.0)
312 |
313 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
314 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
315 | args.decoder_layers = getattr(args, "decoder_layers", 6)
316 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
317 | args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
318 | args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
319 | args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
320 | args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
321 | args.activation_fn = getattr(args, "activation_fn", "relu")
322 |
323 | args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
324 | args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
325 |
326 | args.base_layers = getattr(args, "base_layers", 0)
327 | args.base_sublayers = getattr(args, "base_sublayers", 1)
328 | args.base_shuffle = getattr(args, "base_shuffle", False)
329 |
330 | args.add_bos_token = getattr(args, "add_bos_token", False)
331 | args.no_token_positional_embeddings = getattr(
332 | args, "no_token_positional_embeddings", True
333 | )
334 | args.xpos_rel_pos = getattr(
335 | args, "xpos_rel_pos", True
336 | )
337 | args.block_size = getattr(
338 | args, "block_size", 2048
339 | )
340 | args.share_decoder_input_output_embed = getattr(
341 | args, "share_decoder_input_output_embed", False
342 | )
343 | args.character_embeddings = getattr(args, "character_embeddings", False)
344 |
345 | args.decoder_output_dim = getattr(
346 | args, "decoder_output_dim", args.decoder_embed_dim
347 | )
348 | args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
349 |
350 | # Model training is not stable without this
351 | args.decoder_normalize_before = True
352 | args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
353 |
354 | args.adaptive_input = getattr(args, "adaptive_input", False)
355 | args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
356 | args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
357 |
358 | args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
359 | args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
360 |
361 | args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
362 | args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
363 | args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
364 | args.offload_activations = getattr(args, "offload_activations", False)
365 | if args.offload_activations:
366 | args.checkpoint_activations = True
367 |
368 | @register_model_architecture("lm", "lm_medium")
369 | def lm_medium(args):
370 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
371 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
372 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
373 | args.decoder_layers = getattr(args, "decoder_layers", 24)
374 | base_lm_architecture(args)
375 |
376 | @register_model_architecture("lm", "lm_large")
377 | def lm_large(args):
378 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
379 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
380 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
381 | args.decoder_layers = getattr(args, "decoder_layers", 36)
382 | base_lm_architecture(args)
383 |
384 |
385 | @register_model_architecture("lm", "lm_xl")
386 | def lm_xl(args):
387 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1600)
388 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6400)
389 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 25)
390 | args.decoder_layers = getattr(args, "decoder_layers", 48)
391 | base_lm_architecture(args)
392 |
393 |
394 | @register_model_architecture("lm", "lm_base_abs")
395 | def lm_base_abs(args):
396 | args.xpos_rel_pos = getattr(
397 | args, "xpos_rel_pos", False
398 | )
399 | args.no_token_positional_embeddings = getattr(
400 | args, "no_token_positional_embeddings", False
401 | )
402 | base_lm_architecture(args)
403 |
404 | @register_model_architecture("lm", "lm_base_bucket")
405 | def lm_base_bucket(args):
406 | args.xpos_rel_pos = getattr(
407 | args, "xpos_rel_pos", False
408 | )
409 | args.rel_pos_buckets = getattr(
410 | args, "rel_pos_buckets", 128
411 | )
412 | args.max_rel_pos = getattr(
413 | args, "max_rel_pos", 2048
414 | )
415 | base_lm_architecture(args)
416 |
417 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import argparse
5 | import importlib
6 | import os
7 |
8 | # register dataclass
9 | TASK_DATACLASS_REGISTRY = {}
10 | TASK_REGISTRY = {}
11 | TASK_CLASS_NAMES = set()
12 |
13 | # automatically import any Python files in the tasks/ directory
14 | tasks_dir = os.path.dirname(__file__)
15 | for file in os.listdir(tasks_dir):
16 | path = os.path.join(tasks_dir, file)
17 | if (
18 | not file.startswith("_")
19 | and not file.startswith(".")
20 | and (file.endswith(".py") or os.path.isdir(path))
21 | ):
22 | task_name = file[: file.find(".py")] if file.endswith(".py") else file
23 | module = importlib.import_module("tasks." + task_name)
24 |
25 | # expose `task_parser` for sphinx
26 | if task_name in TASK_REGISTRY:
27 | parser = argparse.ArgumentParser(add_help=False)
28 | group_task = parser.add_argument_group("Task name")
29 | # fmt: off
30 | group_task.add_argument('--task', metavar=task_name,
31 | help='Enable this task with: ``--task=' + task_name + '``')
32 | # fmt: on
33 | group_args = parser.add_argument_group("Additional command-line arguments")
34 | TASK_REGISTRY[task_name].add_args(group_args)
35 | globals()[task_name + "_parser"] = parser
36 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/basic_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | from infinibatch.iterators import CheckpointableIterator
6 |
7 | from . import utils
8 |
9 |
10 | class BaseBatchGen(CheckpointableIterator):
11 | """
12 | This is a base class for batch generators that use infinibatch
13 | """
14 |
15 | def __init__(self):
16 | self._iter = None
17 | self.epoch = 1
18 | self.next_epoch_idx = 1
19 | self.sharded_checkpoint = True
20 | self.should_close_after_finished = True
21 |
22 | def _build_iter(self):
23 | """
24 | Build infinibatch iterator and assign to self._iter
25 | """
26 | raise NotImplementedError()
27 |
28 | def _move_to_tensor(self, batch):
29 | def to_tensor(x):
30 | return torch.tensor(x)
31 |
32 | return utils.apply_to_sample(to_tensor, batch)
33 |
34 | @property
35 | def iterator(self):
36 | if self._iter is None:
37 |             raise NotImplementedError("_build_iter() must be called first")
38 | return self._iter
39 |
40 | def __iter__(self):
41 | if self._iter is None:
42 |             raise NotImplementedError("_build_iter() must be called first")
43 | return self._iter
44 |
45 | def __next__(self):
46 | return next(self._iter)
47 |
48 | def setstate(self, value):
49 | self._iter.setstate(value)
50 |
51 | def getstate(self):
52 | return self._iter.getstate()
53 |
54 | def close(self):
55 | self._iter.close()
56 |
57 | def __len__(self) -> int:
58 | return 819200000
59 |
60 | def next_epoch_itr(
61 | self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
62 | ):
63 | return self
64 |
65 | def end_of_epoch(self) -> bool:
66 | return False
67 |
68 | def state_dict(self):
69 | """Returns a dictionary containing a whole state of the iterator."""
70 | return self.getstate()
71 |
72 | def load_state_dict(self, state_dict):
73 | """Copies the state of the iterator from the given *state_dict*."""
74 | self.setstate(state_dict)
75 |
76 | @property
77 | def first_batch(self):
78 | return "DUMMY"
79 |
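# Editor's note: an illustrative (hypothetical) subclass showing the intended use of
# BaseBatchGen -- build an infinibatch pipeline in _build_iter() and assign it to
# self._iter; checkpointing, iteration and epoch handling are inherited. The name
# ToyBatchGen and the lambda transform below are assumptions, not part of this repo:
#
#   from infinibatch import iterators
#
#   class ToyBatchGen(BaseBatchGen):
#       def __init__(self, samples):
#           super().__init__()
#           self.samples = samples
#           self._build_iter()
#
#       def _build_iter(self):
#           source = iterators.ChunkedSourceIterator(self.samples)
#           self._iter = iterators.MapIterator(source, lambda x: {"net_input": x})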
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/mlm_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import copy
5 | import itertools
6 | import os
7 |
8 | import numpy as np
9 | from infinibatch import iterators
10 |
11 | from .basic_loader import BaseBatchGen
12 | from .utils import NativeCheckpointableIterator, WeightIterator
13 |
14 |
15 | class MLMLoader(BaseBatchGen):
16 | def __init__(
17 | self,
18 | args,
19 | dataset,
20 | dictionary,
21 | tokenizer,
22 | max_tokens=None,
23 | max_sentences=None,
24 | max_positions=None,
25 | ignore_invalid_inputs=False,
26 | required_batch_size_multiple=1,
27 | seed=1,
28 | num_shards=1,
29 | shard_id=0,
30 | ):
31 | super().__init__()
32 | self.args = args
33 | self.data = dataset.data
34 | self.data_dir = dataset.data_dir
35 | self.shuffle = dataset.shuffle
36 | self.dictionary = dictionary
37 | self.tokenizer = tokenizer
38 |
39 | self.max_tokens = max_tokens
40 | self.max_sentences = max_sentences
41 | self.max_positions = max_positions
42 | self.tokens_per_sample = args.tokens_per_sample
43 | self.sample_break_mode = args.sample_break_mode
44 | self.ignore_invalid_inputs = ignore_invalid_inputs
45 | self.required_batch_size_multiple = required_batch_size_multiple
46 | self.seed = str(seed)
47 | self.num_shards = num_shards
48 | self.shard_id = shard_id
49 |
50 | self.batch_read_ahead = args.batch_read_ahead
51 |
52 | self._build_iter()
53 |
54 | def _build_iter(self):
55 | tokenized_lines = self._multilingual_tokenize()
56 | self.padded_batches = self._batchify(tokenized_lines)
57 |
58 | prefetch_batches = iterators.PrefetchIterator(
59 | self.padded_batches,
60 | buffer_size=10000,
61 | buffer_in_main_process=True,
62 | log_empty_buffer_warning=True and self.shard_id == 0,
63 | )
64 |
65 | prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor)
66 |
67 | self._iter = prefetch_batches
68 |
69 | def _multilingual_tokenize(self):
70 | multilingual_iters = []
71 | weights = []
72 |
73 | for data in self.data:
74 | multilingual_iters.append(self._tokenize(data))
75 | if "weight" in data:
76 | weights.append(float(data["weight"]))
77 | else:
78 | weights.append(int(data["count"]))
79 |
80 | if len(multilingual_iters) == 1:
81 | return multilingual_iters[0]
82 |
83 |         sampling_iterator = WeightIterator(weights, self.seed)
84 | control_iterator = NativeCheckpointableIterator(sampling_iterator)
85 | tokenized_lines = iterators.MultiplexIterator(
86 | control_iterator, multilingual_iters
87 | )
88 |
89 | return tokenized_lines
90 |
91 | def _tokenize(self, data):
92 | """
93 | data:
94 | {
95 | 'source': list[Path],
96 | 'source_lang': str,
97 | 'count': int,
98 | 'weight': float,
99 | 'name': str,
100 | }
101 | """
102 | dataset = list(
103 | zip(
104 | data["source"],
105 | itertools.repeat(data["source_lang"]),
106 | )
107 | )
108 |
109 | if self.shuffle:
110 | chunk_files = iterators.InfinitePermutationSourceIterator(
111 | dataset,
112 | seed=self.seed,
113 | shuffle=self.shuffle,
114 | num_instances=self.num_shards,
115 | instance_rank=self.shard_id,
116 | )
117 | else:
118 | chunk_files = iterators.ChunkedSourceIterator(
119 | dataset,
120 | num_instances=self.num_shards,
121 | instance_rank=self.shard_id,
122 | )
123 |
124 | tokenized_lines = iterators.SelectManyIterator(
125 | chunk_files, lambda files: self._read_from_files(*files)
126 | )
127 | tokenized_lines = iterators.SamplingRandomMapIterator(
128 | tokenized_lines, self._prepare, self.seed
129 | )
130 |
131 | return tokenized_lines
132 |
133 | def _batchify(self, lines):
134 |
135 | if self.max_sentences is not None:
136 | if self.batch_read_ahead > 0:
137 | lines = iterators.BlockwiseShuffleIterator(
138 | lines, self.batch_read_ahead, self.seed
139 | )
140 | batches = iterators.FixedBatchIterator(lines, self.max_sentences)
141 | else:
142 |
143 | def dynamic_batch_size(sample):
144 | lengths = [len(x) for x in sample]
145 | batch_size = self.max_tokens // max(lengths)
146 | batch_size = (
147 | batch_size
148 | // self.required_batch_size_multiple
149 | * self.required_batch_size_multiple
150 | )
151 | return max(1, batch_size)
152 |
153 | batches = iterators.BucketedReadaheadBatchIterator(
154 | lines,
155 | read_ahead=self.batch_read_ahead,
156 | key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
157 | batch_size=dynamic_batch_size,
158 | shuffle=self.shuffle,
159 | seed=self.seed,
160 | )
161 |
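# Editor's note (illustrative arithmetic, values assumed): with max_tokens=4096,
# required_batch_size_multiple=8 and a read-ahead window whose longest sample has
# 100 tokens, dynamic_batch_size() above yields 4096 // 100 = 40, then
# 40 // 8 * 8 = 40 sentences per batch; the final max(1, ...) guards against a
# single sample that is longer than max_tokens.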
162 | def collate(batch):
163 | batch_size = len(batch)
164 |
165 | mlm_source_max_length = max([len(x[0]) for x in batch])
166 | mlm_target_max_length = max([len(x[1]) for x in batch])
167 | s2s_source_max_length = max([len(x[2]) for x in batch])
168 | s2s_target_max_length = max([len(x[3]) for x in batch])
169 |
170 | mlm_source_ids = np.full(
171 | shape=(batch_size, mlm_source_max_length),
172 | dtype=np.int32,
173 | fill_value=self.dictionary.pad(),
174 | )
175 | mlm_target_ids = np.full(
176 | shape=(batch_size, mlm_target_max_length),
177 | dtype=np.int32,
178 | fill_value=self.dictionary.pad(),
179 | )
180 | s2s_source_ids = np.full(
181 | shape=(batch_size, s2s_source_max_length),
182 | dtype=np.int32,
183 | fill_value=self.dictionary.pad(),
184 | )
185 | s2s_target_ids = np.full(
186 | shape=(batch_size, s2s_target_max_length - 1),
187 | dtype=np.int32,
188 | fill_value=self.dictionary.pad(),
189 | )
190 | s2s_prev_input_ids = np.full(
191 | shape=(batch_size, s2s_target_max_length - 1),
192 | dtype=np.int32,
193 | fill_value=self.dictionary.pad(),
194 | )
195 |
196 | for i, (
197 | mlm_input_ids,
198 | mlm_label_ids,
199 | s2s_input_ids,
200 | s2s_label_ids,
201 | ) in enumerate(batch):
202 | mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids
203 | mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids
204 | s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids
205 | s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:]
206 | s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1]
207 |
208 | ret_batch = {
209 | "net_input": {
210 | "src_tokens": mlm_source_ids.astype(np.int64),
211 | },
212 | "target": mlm_target_ids.astype(np.int64),
213 | "nsentences": batch_size,
214 | "ntokens": sum([len(x[0]) for x in batch]),
215 | }
216 |
217 | return ret_batch
218 |
219 | padded_batches = iterators.MapIterator(batches, collate)
220 |
221 | return padded_batches
222 |
223 | def _prepare(self, _random, doc):
224 | nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
225 | nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
226 | return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
227 |
228 | def _mask_lm(self, _random, doc):
229 | def mask_tokens():
230 |             return "<mask>"
231 |
232 | length = len(doc)
233 | mask_tokens_num = int(length * self.args.mask_prob)
234 | mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
235 | possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
236 | possible_mask_positions = sorted(possible_mask_positions)
237 |
238 | nonmasked_tokens = copy.deepcopy(doc)
239 | masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]
240 |
241 | for position in possible_mask_positions:
242 | # masked_tokens.append(nonmasked_tokens[position])
243 | masked_tokens[position] = nonmasked_tokens[position]
244 | nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
245 |
246 | return nonmasked_tokens, masked_tokens
247 |
248 | def _span_corruption(self, _random, doc):
249 | def mask_tokens(i):
250 |             return f"<mask_{i}>"
251 |
252 | length = len(doc)
253 | noise_tokens_num = int(length * self.args.mask_prob)
254 | noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
255 | noise_spans_num = int(noise_tokens_num / self.args.span_length)
256 | noise_spans_num = max(noise_spans_num, 1)
257 | nonnoise_tokens_num = length - noise_tokens_num
258 |
259 | if noise_spans_num == 1:
260 | noise_split_positions = [0, noise_tokens_num]
261 | else:
262 | possible_split_positions = list(range(1, noise_tokens_num))
263 | _random.shuffle(possible_split_positions)
264 | noise_split_positions = sorted(
265 | possible_split_positions[: noise_spans_num - 1]
266 | )
267 | noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
268 |
269 | possible_insert_positions = list(range(nonnoise_tokens_num))
270 | _random.shuffle(possible_insert_positions)
271 | noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])
272 |
273 | nonnoise_spans, noise_spans = [], []
274 | last_end = 0
275 | for i in range(noise_spans_num):
276 | start_pos = noise_insert_positions[i] + noise_split_positions[i]
277 | end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
278 | mask_id = self.dictionary.indices[mask_tokens(i)]
279 |
280 | if getattr(self.args, "remove_target_sentinel", False):
281 | noise_spans.append(doc[start_pos:end_pos])
282 | else:
283 | noise_spans.append([mask_id] + doc[start_pos:end_pos])
284 |
285 | if getattr(self.args, "remove_source_sentinel", False):
286 | nonnoise_spans.extend(doc[last_end:start_pos])
287 | else:
288 | nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
289 |
290 | last_end = end_pos
291 |
292 | nonnoise_spans.extend(doc[last_end:])
293 | noise_spans = sum(noise_spans, [])
294 |
295 | return nonnoise_spans, noise_spans
296 |
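# Editor's note: an illustrative walk-through of _span_corruption above (numbers
# assumed): for a 20-token doc with mask_prob=0.15 and span_length=3.0,
# noise_tokens_num = int(20 * 0.15) = 3 and noise_spans_num = int(3 / 3.0) = 1, so a
# single contiguous 3-token span is cut out of the document. The returned source
# keeps the remaining 17 tokens with one <mask_i> sentinel marking the gap, and the
# target is that sentinel followed by the 3 removed tokens (the sentinels are
# omitted when remove_source_sentinel / remove_target_sentinel are set).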
297 | def _read_from_files(self, source_file, source_lang):
298 | # data = []
299 | file_path = os.path.join(self.data_dir, source_file)
300 |
301 | if not os.path.exists(file_path):
302 |             print("| file {} does not exist".format(file_path), flush=True)
303 | return iter([]) # skip bad file
304 |
305 | with open(file_path, "r", encoding="utf8") as f:
306 | lines = f.read().strip().split("\n")
307 |
308 | doc = [self.dictionary.bos()]
309 | for line in lines:
310 | if line == "":
311 | if self.sample_break_mode == "complete_doc":
312 | # data.append(doc)
313 | yield doc
314 | doc = [self.dictionary.bos()]
315 | continue
316 |
317 | tokenized_line = self.tokenizer.EncodeAsPieces(line)
318 | tokenized_id = [
319 | self.dictionary.index(token) for token in tokenized_line
320 | ] + [self.dictionary.eos_index]
321 |
322 | if len(tokenized_id) > self.tokens_per_sample:
323 | continue
324 | if len(doc) + len(tokenized_id) > self.tokens_per_sample:
325 | # data.append(doc)
326 | yield doc
327 | doc = [self.dictionary.bos()]
328 | doc.extend(tokenized_id)
329 |
330 | if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
331 | # data.append(doc)
332 | yield doc
333 |
334 | # return data
335 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import collections
5 | from random import Random
6 | from typing import Dict, Iterable, Optional
7 |
8 | import numpy as np
9 | from infinibatch import iterators
10 |
11 |
12 | def apply_to_sample(f, sample):
13 | if hasattr(sample, "__len__") and len(sample) == 0:
14 | return {}
15 |
16 | def _apply(x):
17 | if isinstance(x, np.ndarray):
18 | return f(x)
19 | elif isinstance(x, collections.OrderedDict):
20 | # OrderedDict has attributes that needs to be preserved
21 | od = collections.OrderedDict(
22 | (key, _apply(value)) for key, value in x.items()
23 | )
24 | od.__dict__ = x.__dict__
25 | return od
26 | elif isinstance(x, dict):
27 | return {key: _apply(value) for key, value in x.items()}
28 | elif isinstance(x, list):
29 | return [_apply(x) for x in x]
30 | elif isinstance(x, tuple):
31 | return tuple(_apply(x) for x in x)
32 | elif isinstance(x, set):
33 | return {_apply(x) for x in x}
34 | else:
35 | return x
36 |
37 | return _apply(sample)
38 |
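# Editor's note: a minimal, self-contained usage sketch (toy data, names assumed) of
# apply_to_sample as used by BaseBatchGen._move_to_tensor -- it walks nested
# dicts/lists/tuples/sets and applies f to every numpy array it finds, leaving other
# leaves (ints, strings, ...) untouched:
#
#   import numpy as np
#   batch = {"net_input": {"src_tokens": np.zeros((2, 4), dtype=np.int64)}, "nsentences": 2}
#   out = apply_to_sample(lambda a: a + 1, batch)
#   # out["net_input"]["src_tokens"] is the incremented array; out["nsentences"] is still 2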
39 |
40 | class NativeCheckpointableIterator(iterators.CheckpointableIterator):
41 | def __init__(self, iterable: Iterable):
42 | self._input_iterable = iterable
43 | self.setstate(None)
44 |
45 | def getstate(self) -> Dict:
46 | return {"num_items_yielded": self._num_items_yielded}
47 |
48 | def setstate(self, checkpoint: Optional[Dict]):
49 | self._iterator = iter(self._input_iterable)
50 | self._num_items_yielded = (
51 | iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
52 | if checkpoint is not None
53 | else 0
54 | )
55 |
56 | def __next__(self):
57 | item = next(self._iterator)
58 | self._num_items_yielded += 1
59 | return item
60 |
61 | def close(self):
62 | pass
63 |
64 |
65 | class WeightIterator(object):
66 | def __init__(self, weights, seed):
67 | self.weights = weights
68 | self.seed = seed
69 | self.control_index = list(range(len(weights)))
70 | self.setstate(None)
71 |
72 | def __iter__(self):
73 | return self
74 |
75 | def getstate(self):
76 | return {"random_state": self._random_state}
77 |
78 | def setstate(self, checkpoint):
79 | self._random_state = checkpoint["random_state"] if checkpoint else None
80 | self._random = (
81 | None # this will trigger the lazy initialization in self.__next__
82 | )
83 |
84 | def __next__(self):
85 | if self._random is None:
86 | self._random = Random(self.seed)
87 | if self._random_state is not None:
88 | self._random.setstate(self._random_state)
89 | idx = self._random.choices(self.control_index, self.weights)[0]
90 | self._random_state = self._random.getstate()
91 | return idx
92 |
93 | def close(self):
94 | pass
95 |
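# Editor's note: a small, self-contained usage sketch (values assumed) of
# WeightIterator -- it endlessly yields dataset indices with probability
# proportional to the given weights, and its random state round-trips through
# getstate()/setstate(), which is what makes the sampling checkpointable:
#
#   it = WeightIterator([0.7, 0.3], seed=1)
#   first = [next(it) for _ in range(5)]     # indices 0/1, drawn roughly 70%/30%
#   state = it.getstate()
#   later = [next(it) for _ in range(5)]
#   it.setstate(state)
#   assert [next(it) for _ in range(5)] == later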
--------------------------------------------------------------------------------
/examples/fairseq/tasks/pretraining.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import json
5 | import logging
6 | import os
7 | from argparse import Namespace
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | #
11 | # This source code is licensed under the MIT license found in the
12 | # LICENSE file in the root directory of this source tree.
13 | from dataclasses import dataclass, field
14 |
15 | import sentencepiece as spm
16 | from fairseq import utils
17 | from fairseq.data import Dictionary
18 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass
19 | from fairseq.tasks import FairseqTask, register_task
20 | from omegaconf import II, MISSING
21 |
22 | from .data.mlm_loader import MLMLoader
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
27 | SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
28 |
29 |
30 | @dataclass
31 | class PretrainingConfig(FairseqDataclass):
32 | data: str = field(
33 | default=MISSING,
34 | metadata={
35 | "help": "colon separated path to data directories list, \
36 | will be iterated upon during epochs in round-robin manner"
37 | },
38 | )
39 | sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
40 | default="complete",
41 | metadata={
42 | "help": 'If omitted or "none", fills each sample with tokens-per-sample '
43 | 'tokens. If set to "complete", splits samples only at the end '
44 | "of sentence, but may include multiple sentences per sample. "
45 | '"complete_doc" is similar but respects doc boundaries. '
46 | 'If set to "eos", includes only one sentence per sample.'
47 | },
48 | )
49 | tokens_per_sample: int = field(
50 | default=1024,
51 | metadata={"help": "max number of tokens per sample for LM dataset"},
52 | )
53 | mask_prob: float = field(
54 | default=0.15,
55 | metadata={"help": "probability of replacing a token with mask"},
56 | )
57 | leave_unmasked_prob: float = field(
58 | default=0.1,
59 | metadata={"help": "probability that a masked token is unmasked"},
60 | )
61 | random_token_prob: float = field(
62 | default=0.1,
63 | metadata={"help": "probability of replacing a token with a random token"},
64 | )
65 | freq_weighted_replacement: bool = field(
66 | default=False,
67 | metadata={"help": "sample random replacement words based on word frequencies"},
68 | )
69 | mask_whole_words: bool = field(
70 | default=False,
71 | metadata={"help": "mask whole words; you may also want to set --bpe"},
72 | )
73 | mask_multiple_length: int = field(
74 | default=1,
75 | metadata={"help": "repeat the mask indices multiple times"},
76 | )
77 | mask_stdev: float = field(
78 | default=0.0,
79 | metadata={"help": "stdev of the mask length"},
80 | )
81 | shorten_method: SHORTEN_METHOD_CHOICES = field(
82 | default="none",
83 | metadata={
84 | "help": "if not none, shorten sequences that exceed --tokens-per-sample"
85 | },
86 | )
87 | shorten_data_split_list: str = field(
88 | default="",
89 | metadata={
90 | "help": "comma-separated list of dataset splits to apply shortening to, "
91 | 'e.g., "train,valid" (default: all dataset splits)'
92 | },
93 | )
94 | seed: int = II("common.seed")
95 | span_length: float = field(
96 | default=3.0,
97 | metadata={"help": "average span length for masking"},
98 | )
99 | remove_source_sentinel: bool = field(
100 | default=False,
101 | metadata={"help": "remove the source sentinel for the span corruption task"},
102 | )
103 | remove_target_sentinel: bool = field(
104 | default=False,
105 | metadata={"help": "remove the target sentinel for the span corruption task"},
106 | )
107 | batch_read_ahead: int = field(
108 | default=100000,
109 | metadata={"help": "batch read ahead size for infinibatch"},
110 | )
111 | required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
112 | spm_model: str = field(
113 | default="",
114 |         metadata={"help": "sentencepiece model to tokenize the data"},
115 | )
116 | dict_file: str = field(
117 | default="",
118 |         metadata={"help": "optional path to a pre-built dictionary file; if empty, dict.txt under the data directory is used"},
119 | )
120 |
121 |
122 | @register_task("pretraining", dataclass=PretrainingConfig)
123 | class PLMTask(FairseqTask):
124 | def __init__(self, cfg, dictionary, tokenizer):
125 | super().__init__(cfg)
126 | self.cfg = cfg
127 | self.dictionary = dictionary
128 | self.tokenizer = tokenizer
129 | self.seed = cfg.seed
130 |         self.mask_idx = dictionary.index("<mask>")
131 |
132 | @classmethod
133 | def setup_task(cls, cfg, **kwargs):
134 | paths = utils.split_paths(cfg.data)
135 | assert len(paths) > 0
136 | if cfg.dict_file != "":
137 | dictionary = Dictionary.load(cfg.dict_file)
138 | else:
139 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
140 |
141 | # add mask token
142 |         dictionary.add_symbol("<mask>")
143 | for i in range(100):
144 |             dictionary.add_symbol(f"<mask_{i}>")
145 |
146 | dictionary.pad_to_multiple_(cfg.required_batch_size_multiple)
147 | logger.info("dictionary: {} types".format(len(dictionary)))
148 |
149 | # tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=cfg.spm_model))
150 | tokenizer = spm.SentencePieceProcessor()
151 | tokenizer.Load(cfg.spm_model)
152 | return cls(cfg, dictionary, tokenizer)
153 |
154 | def load_dataset(self, split, epoch=1, combine=False, **kwargs):
155 | self.datasets[split] = {
156 | "data": json.load(open(f"{self.cfg.data}/json/{split}.json")),
157 | "data_dir": self.cfg.data,
158 | "shuffle": True if split == "train" else False,
159 | }
160 | self.datasets[split] = Namespace(**self.datasets[split])
161 |
162 | def dataset(self, split):
163 | if split not in self.datasets:
164 | raise KeyError("Dataset not loaded: " + split)
165 |
166 | return self.datasets[split]
167 |
168 | def get_batch_iterator(
169 | self,
170 | dataset,
171 | max_tokens=None,
172 | max_sentences=None,
173 | max_positions=None,
174 | ignore_invalid_inputs=False,
175 | required_batch_size_multiple=1,
176 | seed=1,
177 | num_shards=1,
178 | shard_id=0,
179 | num_workers=0,
180 | epoch=1,
181 | data_buffer_size=0,
182 | disable_iterator_cache=False,
183 | ):
184 | return MLMLoader(
185 | self.cfg,
186 | dataset,
187 | self.dictionary,
188 | self.tokenizer,
189 | max_tokens=max_tokens,
190 | max_sentences=max_sentences,
191 | max_positions=max_positions,
192 | ignore_invalid_inputs=ignore_invalid_inputs,
193 | required_batch_size_multiple=required_batch_size_multiple,
194 | seed=seed,
195 | num_shards=num_shards,
196 | shard_id=shard_id,
197 | )
198 |
199 | @property
200 | def source_dictionary(self):
201 | return self.dictionary
202 |
203 | @property
204 | def target_dictionary(self):
205 | return self.dictionary
206 |
--------------------------------------------------------------------------------
/examples/fairseq/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # flake8: noqa
5 | import models
6 | import tasks
7 | from fairseq_cli.train import cli_main
8 |
9 | if __name__ == "__main__":
10 | cli_main()
11 |
--------------------------------------------------------------------------------
/examples/fairseq/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/utils/sparse_clip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 | import warnings
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
10 |
11 |
12 | @torch.no_grad()
13 | def clip_grad_norm_(
14 | params, max_norm, moe_expert_count, aggregate_norm_fn=None
15 | ) -> torch.Tensor:
16 | def grad_exists(p):
17 | return p is not None and getattr(p, "grad", None) is not None
18 |
19 | if isinstance(params, torch.Tensor):
20 | params = [params]
21 | params = list(params)
22 | params = list(filter(grad_exists, params))
23 | grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
24 | denom = math.sqrt(max(dist.get_global_world_size(), moe_expert_count))
25 | for p in params:
26 | if hasattr(p, "expert"):
27 | expert_grads.append(p.grad.detach() / denom)
28 | elif hasattr(p, "base_expert"):
29 | base_expert_grads.append(p.grad.detach())
30 | elif hasattr(p, "_is_sharded"):
31 | sharded_grads.append(p.grad.detach())
32 | else:
33 | grads.append(p.grad.detach())
34 | if len(grads) == 0:
35 | if len(params) > 0:
36 | total_norm = params[0].new_tensor(0.0)
37 | else:
38 | total_norm = torch.tensor(0.0)
39 | elif len(grads) == 1:
40 | total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
41 | else:
42 | if multi_tensor_l2norm_available:
43 | total_norm = multi_tensor_total_norm(grads)
44 | else:
45 | if torch.cuda.is_available():
46 | warnings.warn(
47 | "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
48 | "you may get better performance by installing NVIDIA's apex library"
49 | )
50 | device = torch.cuda.current_device()
51 | elif grads[0].device.type == "xla":
52 | device = grads[0].device
53 | else:
54 | device = torch.device("cpu")
55 | total_norm = torch.norm(
56 | torch.stack(
57 | [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
58 | )
59 | )
60 |
61 | # calculate split_norm and all_reduce with other workers
62 | norms = [total_norm]
63 | for split_grads in [expert_grads, sharded_grads]:
64 | if len(split_grads) == 0:
65 | continue
66 | split_norm = torch.norm(
67 | torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
68 | )
69 | if dist.is_initialized():
70 | split_norm.pow_(2)
71 | dist.all_reduce(split_norm)
72 | split_norm.sqrt_()
73 | norms.append(split_norm)
74 | if len(norms) > 1:
75 | total_norm = torch.norm(torch.stack(norms))
76 |
77 | if aggregate_norm_fn is not None:
78 | total_norm = aggregate_norm_fn(total_norm)
79 |
80 | if max_norm > 0:
81 | max_norm = float(max_norm)
82 | clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
83 | for g in grads + expert_grads + sharded_grads + base_expert_grads:
84 | g.mul_(clip_coef)
85 | return total_norm
86 |
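# Editor's note: an illustrative check (toy numbers, not from this file) of how the
# per-group norms above are combined -- the overall norm is the L2 norm of the
# stacked group norms, i.e. the square root of the sum of squared norms, and
# clipping then scales every gradient by min(1, max_norm / (total_norm + 1e-6)):
#
#   import torch
#   dense_norm = torch.tensor(3.0)      # norm of the ordinary grads
#   expert_norm = torch.tensor(4.0)     # norm of the (all-reduced) expert grads
#   total = torch.norm(torch.stack([dense_norm, expert_norm]))   # sqrt(9 + 16) = 5.0
#   clip_coef = (2.0 / (total + 1e-6)).clamp(max=1)              # max_norm=2.0 -> 0.4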
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | from io import open
5 |
6 | from setuptools import find_packages, setup
7 |
8 | setup(
9 | name="torchscale",
10 | version="0.1.1",
11 | author="TorchScale Team",
12 | author_email="Shuming.Ma@microsoft.com",
13 | description="Transformers at any scale",
14 | long_description=open("README.md", "r", encoding="utf-8").read(),
15 | long_description_content_type="text/markdown",
16 | keywords="Transformers at any scale",
17 | license="MIT",
18 | url="https://github.com/msranlp/torchscale",
19 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
20 | install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
21 | python_requires=">=3.8.0",
22 | classifiers=[
23 | "Programming Language :: Python :: 3",
24 | ],
25 | )
26 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/tests/test_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import pytest
5 | import torch
6 |
7 | from torchscale.architecture.config import DecoderConfig
8 | from torchscale.architecture.decoder import Decoder
9 |
10 | testcases = [
11 | {},
12 | {"vocab_size": 64000},
13 | {"activation_fn": "relu"},
14 | {"drop_path_rate": 0.1},
15 | {"decoder_normalize_before": False},
16 | {"no_scale_embedding": False},
17 | {"layernorm_embedding": True},
18 | {"rel_pos_buckets": 32, "max_rel_pos": 256},
19 | {"deepnorm": True, "subln": False, "decoder_normalize_before": False},
20 | {"bert_init": True},
21 | {"multiway": True},
22 | {"share_decoder_input_output_embed": True},
23 | {"checkpoint_activations": True},
24 | {"fsdp": True},
25 | ]
26 |
27 |
28 | @pytest.mark.parametrize("args", testcases)
29 | def test_decoder(args):
30 | config = DecoderConfig(**args)
31 | model = Decoder(config)
32 | prev_output_tokens = torch.ones(2, 10)
33 | token_embeddings = torch.rand(2, 10, config.decoder_embed_dim)
34 | model(
35 | prev_output_tokens=prev_output_tokens,
36 | token_embeddings=token_embeddings,
37 | features_only=True,
38 | )
39 |
--------------------------------------------------------------------------------
/tests/test_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import pytest
5 | import torch
6 |
7 | from torchscale.architecture.config import EncoderConfig
8 | from torchscale.architecture.encoder import Encoder
9 |
10 | testcases = [
11 | {},
12 | {"vocab_size": 64000},
13 | {"activation_fn": "relu"},
14 | {"drop_path_rate": 0.1},
15 | {"encoder_normalize_before": False},
16 | {"no_scale_embedding": False},
17 | {"layernorm_embedding": True},
18 | {"rel_pos_buckets": 32, "max_rel_pos": 256},
19 | {"deepnorm": True, "subln": False, "encoder_normalize_before": False},
20 | {"bert_init": True},
21 | {"multiway": True},
22 | {"share_encoder_input_output_embed": True},
23 | {"checkpoint_activations": True},
24 | {"fsdp": True},
25 | ]
26 |
27 |
28 | @pytest.mark.parametrize("args", testcases)
29 | def test_encoder(args):
30 | config = EncoderConfig(**args)
31 | model = Encoder(config)
32 | token_embeddings = torch.rand(2, 10, config.encoder_embed_dim)
33 | model(src_tokens=None, token_embeddings=token_embeddings)
34 |
--------------------------------------------------------------------------------
/tests/test_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import pytest
5 | import torch
6 |
7 | from torchscale.architecture.config import EncoderDecoderConfig
8 | from torchscale.architecture.encoder_decoder import EncoderDecoder
9 | from torchscale.component.embedding import PositionalEmbedding, TextEmbedding
10 |
11 | testcases = [
12 | {},
13 | {"vocab_size": 64000},
14 | {"activation_fn": "relu"},
15 | {"drop_path_rate": 0.1},
16 | {"encoder_normalize_before": False, "decoder_normalize_before": False},
17 | {"no_scale_embedding": False},
18 | {"layernorm_embedding": True},
19 | {"rel_pos_buckets": 32, "max_rel_pos": 256},
20 | {
21 | "deepnorm": True,
22 | "subln": False,
23 | "encoder_normalize_before": False,
24 | "decoder_normalize_before": False,
25 | },
26 | {"bert_init": True},
27 | {"multiway": True},
28 | {"share_decoder_input_output_embed": True},
29 | {"share_all_embeddings": True},
30 | {"checkpoint_activations": True},
31 | {"fsdp": True},
32 | ]
33 |
34 |
35 | @pytest.mark.parametrize("args", testcases)
36 | def test_decoder(args):
37 | config = EncoderDecoderConfig(**args)
38 | model = EncoderDecoder(
39 | config,
40 | encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
41 | decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
42 | encoder_embed_positions=PositionalEmbedding(
43 | config.max_source_positions, config.encoder_embed_dim
44 | ),
45 | decoder_embed_positions=PositionalEmbedding(
46 | config.max_target_positions, config.decoder_embed_dim
47 | ),
48 | )
49 |
50 | src_tokens = torch.ones(2, 20).long()
51 | prev_output_tokens = torch.ones(2, 10).long()
52 |
53 | model(
54 | src_tokens=src_tokens,
55 | prev_output_tokens=prev_output_tokens,
56 | features_only=True,
57 | )
58 |
--------------------------------------------------------------------------------
/torchscale/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/architecture/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/architecture/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 |
5 | class EncoderConfig(object):
6 | def __init__(self, **kwargs):
7 | self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
8 | self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
9 | self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
10 | self.encoder_layers = kwargs.pop("encoder_layers", 12)
11 | self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
12 | self.activation_fn = kwargs.pop("activation_fn", "gelu")
13 | self.dropout = kwargs.pop("dropout", 0.0)
14 | self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
15 | self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
16 | self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
17 | self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
18 | self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
19 | self.moe_freq = kwargs.pop("moe_freq", 0)
20 | self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
21 | self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
22 | self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
23 | self.moe_eval_capacity_token_fraction = kwargs.pop(
24 | "moe_eval_capacity_token_fraction", 0.25
25 | )
26 | self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
27 | self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
28 | "moe_normalize_gate_prob_before_dropping", False
29 | )
30 | self.use_xmoe = kwargs.pop("use_xmoe", False)
31 | self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
32 | self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
33 | self.deepnorm = kwargs.pop("deepnorm", False)
34 | self.subln = kwargs.pop("subln", True)
35 | self.bert_init = kwargs.pop("bert_init", False)
36 | self.multiway = kwargs.pop("multiway", False)
37 | self.share_encoder_input_output_embed = kwargs.pop(
38 | "share_encoder_input_output_embed", False
39 | )
40 | self.max_source_positions = kwargs.pop("max_source_positions", 1024)
41 | self.no_output_layer = kwargs.pop("no_output_layer", False)
42 | # Text
43 | self.vocab_size = kwargs.pop("vocab_size", -1)
44 | # Vision
45 | self.img_size = kwargs.pop("img_size", 224)
46 | self.patch_size = kwargs.pop("patch_size", 16)
47 | self.in_chans = kwargs.pop("in_chans", 3)
48 | # Fairscale
49 | self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
50 | self.fsdp = kwargs.pop("fsdp", False)
51 | self.ddp_rank = kwargs.pop("ddp_rank", 0)
52 |
53 | if self.deepnorm:
54 | self.encoder_normalize_before = False
55 | self.subln = False
56 | if self.subln:
57 | self.encoder_normalize_before = True
58 | self.deepnorm = False
59 | if self.use_xmoe:
60 | self.moe_normalize_gate_prob_before_dropping = True
61 | self.moe_second_expert_policy = "random"
62 | assert self.moe_freq > 0 and self.moe_expert_count > 0
63 |
64 | def override(self, args):
65 | for hp in self.__dict__.keys():
66 | if getattr(args, hp, None) is not None:
67 | self.__dict__[hp] = getattr(args, hp, None)
68 |
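# Editor's note: a small usage sketch of the deepnorm/subln handling and of
# override() above (the asserted values follow directly from the constructor; the
# Namespace object is only an illustration of an args-like object):
#
#   from argparse import Namespace
#
#   cfg = EncoderConfig(deepnorm=True)
#   assert cfg.encoder_normalize_before is False and cfg.subln is False
#
#   cfg = EncoderConfig()          # subln defaults to True
#   assert cfg.encoder_normalize_before is True and cfg.deepnorm is False
#
#   cfg.override(Namespace(encoder_layers=24))
#   assert cfg.encoder_layers == 24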
69 |
70 | class DecoderConfig(object):
71 | def __init__(self, **kwargs):
72 | self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
73 | self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
74 | self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
75 | self.decoder_layers = kwargs.pop("decoder_layers", 12)
76 | self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
77 | self.activation_fn = kwargs.pop("activation_fn", "gelu")
78 | self.dropout = kwargs.pop("dropout", 0.0)
79 | self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
80 | self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
81 | self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
82 | self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
83 | self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
84 | self.moe_freq = kwargs.pop("moe_freq", 0)
85 | self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
86 | self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
87 | self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
88 | self.moe_eval_capacity_token_fraction = kwargs.pop(
89 | "moe_eval_capacity_token_fraction", 0.25
90 | )
91 | self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
92 | self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
93 | "moe_normalize_gate_prob_before_dropping", False
94 | )
95 | self.use_xmoe = kwargs.pop("use_xmoe", False)
96 | self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
97 | self.block_size = kwargs.pop("block_size", 2048)
98 | self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
99 | self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
100 | self.deepnorm = kwargs.pop("deepnorm", False)
101 | self.subln = kwargs.pop("subln", True)
102 | self.bert_init = kwargs.pop("bert_init", False)
103 | self.multiway = kwargs.pop("multiway", False)
104 | self.share_decoder_input_output_embed = kwargs.pop(
105 | "share_decoder_input_output_embed", False
106 | )
107 | self.max_target_positions = kwargs.pop("max_target_positions", 1024)
108 | self.no_output_layer = kwargs.pop("no_output_layer", False)
109 | # Text
110 | self.vocab_size = kwargs.pop("vocab_size", -1)
111 | # Fairscale
112 | self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
113 | self.fsdp = kwargs.pop("fsdp", False)
114 | self.ddp_rank = kwargs.pop("ddp_rank", 0)
115 |
116 | if self.deepnorm:
117 | self.decoder_normalize_before = False
118 | self.subln = False
119 | if self.subln:
120 | self.decoder_normalize_before = True
121 | self.deepnorm = False
122 | if self.use_xmoe:
123 | self.moe_normalize_gate_prob_before_dropping = True
124 | self.moe_second_expert_policy = "random"
125 | assert self.moe_freq > 0 and self.moe_expert_count > 0
126 |
127 | def override(self, args):
128 | for hp in self.__dict__.keys():
129 | if getattr(args, hp, None) is not None:
130 | self.__dict__[hp] = getattr(args, hp, None)
131 |
132 |
133 | class EncoderDecoderConfig(object):
134 | def __init__(self, **kwargs):
135 | self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
136 | self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
137 | self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
138 | self.encoder_layers = kwargs.pop("encoder_layers", 12)
139 | self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
140 | self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
141 | self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
142 | self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
143 | self.decoder_layers = kwargs.pop("decoder_layers", 12)
144 | self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
145 | self.activation_fn = kwargs.pop("activation_fn", "gelu")
146 | self.dropout = kwargs.pop("dropout", 0.0)
147 | self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
148 | self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
149 | self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
150 | self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
151 | self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
152 | self.moe_freq = kwargs.pop("moe_freq", 0)
153 | self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
154 | self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
155 | self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
156 | self.moe_eval_capacity_token_fraction = kwargs.pop(
157 | "moe_eval_capacity_token_fraction", 0.25
158 | )
159 | self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
160 | self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
161 | "moe_normalize_gate_prob_before_dropping", False
162 | )
163 | self.use_xmoe = kwargs.pop("use_xmoe", False)
164 | self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
165 | self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
166 | self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
167 | self.deepnorm = kwargs.pop("deepnorm", False)
168 | self.subln = kwargs.pop("subln", True)
169 | self.bert_init = kwargs.pop("bert_init", False)
170 | self.multiway = kwargs.pop("multiway", False)
171 | self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
172 | self.share_decoder_input_output_embed = kwargs.pop(
173 | "share_decoder_input_output_embed", False
174 | )
175 | self.max_source_positions = kwargs.pop("max_source_positions", 1024)
176 | self.max_target_positions = kwargs.pop("max_target_positions", 1024)
177 | self.no_output_layer = kwargs.pop("no_output_layer", False)
178 | # Text
179 | self.vocab_size = kwargs.pop("vocab_size", -1)
180 | # Fairscale
181 | self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
182 | self.fsdp = kwargs.pop("fsdp", False)
183 | self.ddp_rank = kwargs.pop("ddp_rank", 0)
184 |
185 | if self.deepnorm:
186 | self.encoder_normalize_before = False
187 | self.decoder_normalize_before = False
188 | self.subln = False
189 | if self.subln:
190 | self.encoder_normalize_before = True
191 | self.decoder_normalize_before = True
192 | self.deepnorm = False
193 | if self.use_xmoe:
194 | self.moe_normalize_gate_prob_before_dropping = True
195 | self.moe_second_expert_policy = "random"
196 | assert self.moe_freq > 0 and self.moe_expert_count > 0
197 |
198 | def override(self, args):
199 | for hp in self.__dict__.keys():
200 | if getattr(args, hp, None) is not None:
201 | self.__dict__[hp] = getattr(args, hp, None)
202 |
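A short usage sketch for the config classes above, assuming this repository is installed as the torchscale package: deepnorm and subln are mutually exclusive, and override() copies any non-None attribute from an argparse-style namespace.

# Usage sketch; field names and defaults come from config.py above.
from argparse import Namespace
from torchscale.architecture.config import DecoderConfig

cfg = DecoderConfig(vocab_size=32000, decoder_layers=6, deepnorm=True)
# deepnorm switches to a post-LayerNorm layout and disables subln.
assert cfg.decoder_normalize_before is False and cfg.subln is False

# override() copies every non-None attribute of an args-like object.
cfg.override(Namespace(dropout=0.1, decoder_ffn_embed_dim=2048))
assert cfg.dropout == 0.1 and cfg.decoder_ffn_embed_dim == 2048
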
--------------------------------------------------------------------------------
/torchscale/architecture/decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from apex.normalization import FusedLayerNorm as LayerNorm
11 | from fairscale.nn import checkpoint_wrapper, wrap
12 |
13 | from torchscale.architecture.utils import init_bert_params
14 | from torchscale.component.droppath import DropPath
15 | from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
16 | from torchscale.component.multihead_attention import MultiheadAttention
17 | from torchscale.component.xpos_relative_position import XPos
18 | from torchscale.component.relative_position_bias import RelativePositionBias
19 | from torchscale.component.xmoe.moe_layer import MOELayer
20 | from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
21 |
22 |
23 | class DecoderLayer(nn.Module):
24 | def __init__(
25 | self,
26 | args,
27 | depth,
28 | is_moe_layer=False,
29 | is_encoder_decoder=False,
30 | ):
31 | super().__init__()
32 | self.args = args
33 | self.embed_dim = args.decoder_embed_dim
34 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
35 |
36 | if args.drop_path_rate > 0:
37 | drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
38 | depth
39 | ]
40 | self.drop_path = DropPath(drop_path_prob)
41 | else:
42 | self.drop_path = None
43 |
44 | self.self_attn = self.build_self_attention(self.embed_dim, args)
45 |
46 | self.normalize_before = args.decoder_normalize_before
47 |
48 | self.self_attn_layer_norm = LayerNorm(self.embed_dim)
49 |
50 | if not is_encoder_decoder:
51 | self.encoder_attn = None
52 | self.encoder_attn_layer_norm = None
53 | else:
54 | self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
55 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
56 |
57 | self.is_moe_layer = is_moe_layer
58 | self.ffn_dim = args.decoder_ffn_embed_dim
59 |
60 | if not self.is_moe_layer:
61 | self.ffn = self.build_ffn(
62 | self.embed_dim,
63 | self.args,
64 | )
65 | else:
66 | if args.moe_top1_expert:
67 | gate = Top1Gate(
68 | self.embed_dim,
69 | args.moe_expert_count,
70 | use_fp32=args.moe_gating_use_fp32,
71 | moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
72 | use_xmoe=args.use_xmoe,
73 | )
74 | else:
75 | gate = Top2Gate(
76 | self.embed_dim,
77 | args.moe_expert_count,
78 | args.moe_gating_use_fp32,
79 | args.moe_second_expert_policy,
80 | args.moe_normalize_gate_prob_before_dropping,
81 | args.moe_eval_capacity_token_fraction,
82 | use_xmoe=args.use_xmoe,
83 | )
84 | experts = make_experts(args, self.embed_dim, self.ffn_dim)
85 | self.moe_layer = MOELayer(gate, experts, args)
86 |
87 | self.final_layer_norm = LayerNorm(self.embed_dim)
88 |
89 | if args.deepnorm:
90 | if is_encoder_decoder:
91 | self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
92 | else:
93 | self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
94 | else:
95 | self.alpha = 1.0
96 |
97 | def build_ffn(self, embed_dim, args):
98 | return FeedForwardNetwork(
99 | embed_dim,
100 | self.ffn_dim,
101 | args.activation_fn,
102 | args.dropout,
103 | args.activation_dropout,
104 | args.subln,
105 | )
106 |
107 | def build_self_attention(self, embed_dim, args):
108 | return MultiheadAttention(
109 | args,
110 | embed_dim,
111 | args.decoder_attention_heads,
112 | dropout=args.attention_dropout,
113 | self_attention=True,
114 | encoder_decoder_attention=False,
115 | subln=args.subln,
116 | )
117 |
118 | def build_encoder_attention(self, embed_dim, args):
119 | return MultiheadAttention(
120 | args,
121 | embed_dim,
122 | args.decoder_attention_heads,
123 | dropout=args.attention_dropout,
124 | self_attention=False,
125 | encoder_decoder_attention=True,
126 | subln=args.subln,
127 | )
128 |
129 | def residual_connection(self, x, residual):
130 | return residual * self.alpha + x
131 |
132 | def forward(
133 | self,
134 | x,
135 | encoder_out=None,
136 | encoder_padding_mask=None,
137 | incremental_state=None,
138 | self_attn_mask=None,
139 | self_attn_padding_mask=None,
140 | self_attn_rel_pos=None,
141 | cross_attn_rel_pos=None,
142 | ):
143 | residual = x
144 | if self.normalize_before:
145 | x = self.self_attn_layer_norm(x)
146 |
147 | x, attn = self.self_attn(
148 | query=x,
149 | key=x,
150 | value=x,
151 | key_padding_mask=self_attn_padding_mask,
152 | incremental_state=incremental_state,
153 | attn_mask=self_attn_mask,
154 | rel_pos=self_attn_rel_pos,
155 | )
156 | x = self.dropout_module(x)
157 |
158 | if self.drop_path is not None:
159 | x = self.drop_path(x)
160 |
161 | x = self.residual_connection(x, residual)
162 | if not self.normalize_before:
163 | x = self.self_attn_layer_norm(x)
164 |
165 | if self.encoder_attn is not None and encoder_out is not None:
166 | residual = x
167 | if self.normalize_before:
168 | x = self.encoder_attn_layer_norm(x)
169 |
170 | x, attn = self.encoder_attn(
171 | query=x,
172 | key=encoder_out,
173 | value=encoder_out,
174 | key_padding_mask=encoder_padding_mask,
175 | incremental_state=None,
176 | rel_pos=cross_attn_rel_pos,
177 | )
178 | x = self.dropout_module(x)
179 |
180 | if self.drop_path is not None:
181 | x = self.drop_path(x)
182 |
183 | x = self.residual_connection(x, residual)
184 | if not self.normalize_before:
185 | x = self.encoder_attn_layer_norm(x)
186 |
187 | residual = x
188 | if self.normalize_before:
189 | x = self.final_layer_norm(x)
190 | if not self.is_moe_layer:
191 | x = self.ffn(x)
192 | l_aux = None
193 | else:
194 | x = x.transpose(0, 1)
195 | x, l_aux = self.moe_layer(x)
196 | x = x.transpose(0, 1)
197 |
198 | if self.drop_path is not None:
199 | x = self.drop_path(x)
200 |
201 | x = self.residual_connection(x, residual)
202 | if not self.normalize_before:
203 | x = self.final_layer_norm(x)
204 |
205 | return x, attn, None, l_aux
206 |
207 |
208 | class Decoder(nn.Module):
209 | def __init__(
210 | self,
211 | args,
212 | embed_tokens=None,
213 | embed_positions=None,
214 | output_projection=None,
215 | is_encoder_decoder=False,
216 | **kwargs
217 | ):
218 | super().__init__(**kwargs)
219 | self.args = args
220 |
221 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
222 |
223 | embed_dim = args.decoder_embed_dim
224 | self.embed_dim = embed_dim
225 | self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
226 |
227 | self.embed_tokens = embed_tokens
228 | self.embed_positions = embed_positions
229 |
230 | if (
231 | output_projection is None
232 | and not args.no_output_layer
233 | and args.vocab_size > 0
234 | ):
235 | self.output_projection = self.build_output_projection(args)
236 | else:
237 | self.output_projection = output_projection
238 |
239 | if args.layernorm_embedding:
240 | self.layernorm_embedding = LayerNorm(embed_dim)
241 | else:
242 | self.layernorm_embedding = None
243 |
244 | self.layers = nn.ModuleList([])
245 |
246 | moe_freq = args.moe_freq
247 | for i in range(args.decoder_layers):
248 | is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
249 | self.layers.append(
250 | self.build_decoder_layer(
251 | args,
252 | depth=i,
253 | is_moe_layer=is_moe_layer,
254 | is_encoder_decoder=is_encoder_decoder,
255 | )
256 | )
257 |
258 | self.num_layers = len(self.layers)
259 |
260 | if args.decoder_normalize_before:
261 | self.layer_norm = LayerNorm(embed_dim)
262 | else:
263 | self.layer_norm = None
264 |
266 |
267 |         self.block_size = getattr(args, "block_size", 0)  # EncoderDecoderConfig does not define block_size
268 | self.half_block_size = self.block_size // 2
269 |
270 | self.self_attn_xpos = None
271 | self.cross_attn_xpos = None
272 | self.self_attn_relative_position = None
273 | self.cross_attn_relative_position = None
274 | if args.xpos_rel_pos:
275 | self.self_attn_xpos = XPos(
276 | args.decoder_embed_dim // args.decoder_attention_heads
277 | )
278 | if is_encoder_decoder:
279 | self.cross_attn_xpos = XPos(
280 | args.decoder_embed_dim // args.decoder_attention_heads
281 | )
282 | elif args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
283 | self.self_attn_relative_position = RelativePositionBias(
284 | num_buckets=args.rel_pos_buckets,
285 | max_distance=args.max_rel_pos,
286 | n_heads=args.decoder_attention_heads,
287 | )
288 | if is_encoder_decoder:
289 | self.cross_attn_relative_position = RelativePositionBias(
290 | num_buckets=args.rel_pos_buckets,
291 | max_distance=args.max_rel_pos,
292 | n_heads=args.decoder_attention_heads,
293 | )
294 |
295 | if args.bert_init:
296 | self.apply(init_bert_params)
297 |
298 | if args.deepnorm:
299 | if is_encoder_decoder:
300 | init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
301 | else:
302 | init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
303 | for name, p in self.named_parameters():
304 | if (
305 | "fc1" in name
306 | or "fc2" in name
307 | or "out_proj" in name
308 | or "v_proj" in name
309 | ):
310 | p.data.div_(init_scale)
311 |
312 | if args.subln:
313 | if is_encoder_decoder:
314 | init_scale = math.sqrt(math.log(args.decoder_layers * 3))
315 | else:
316 | init_scale = math.sqrt(math.log(args.decoder_layers * 2))
317 | for name, p in self.named_parameters():
318 | if "encoder_attn" in name:
319 | continue
320 | if (
321 | "fc1" in name
322 | or "fc2" in name
323 | or "out_proj" in name
324 | or "v_proj" in name
325 | ):
326 | p.data.mul_(init_scale)
327 |
328 | def build_output_projection(
329 | self,
330 | args,
331 | ):
332 | if args.share_decoder_input_output_embed:
333 | output_projection = torch.nn.Linear(
334 | self.embed_tokens.weight.shape[1],
335 | self.embed_tokens.weight.shape[0],
336 | bias=False,
337 | )
338 | output_projection.weight = self.embed_tokens.weight
339 | else:
340 | output_projection = torch.nn.Linear(
341 | args.decoder_embed_dim, args.vocab_size, bias=False
342 | )
343 | torch.nn.init.normal_(
344 | output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
345 | )
346 | return output_projection
347 |
348 | def build_decoder_layer(
349 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False
350 | ):
351 | layer = DecoderLayer(
352 | args,
353 | depth,
354 | is_moe_layer=is_moe_layer,
355 | is_encoder_decoder=is_encoder_decoder,
356 | )
357 | if args.checkpoint_activations:
358 | layer = checkpoint_wrapper(layer)
359 | if args.fsdp:
360 | layer = wrap(layer)
361 | return layer
362 |
363 | def forward_embedding(
364 | self,
365 | tokens,
366 | token_embedding=None,
367 | incremental_state=None,
368 | ):
369 | positions = None
370 | if self.embed_positions is not None:
371 | positions = self.embed_positions(
372 | tokens, incremental_state=incremental_state
373 | )
374 |
375 | if incremental_state is not None:
376 | tokens = tokens[:, -1:]
377 | if positions is not None:
378 | positions = positions[:, -1:]
379 |
380 | if token_embedding is None:
381 | token_embedding = self.embed_tokens(tokens)
382 |
383 | x = embed = self.embed_scale * token_embedding
384 |
385 | if positions is not None:
386 | x += positions
387 |
388 | if self.layernorm_embedding is not None:
389 | x = self.layernorm_embedding(x)
390 |
391 | x = self.dropout_module(x)
392 |
393 | return x, embed
394 |
395 | def forward(
396 | self,
397 | prev_output_tokens,
398 | self_attn_padding_mask=None,
399 | encoder_out=None,
400 | incremental_state=None,
401 | features_only=False,
402 | return_all_hiddens=False,
403 | token_embeddings=None,
404 | **kwargs
405 | ):
406 |         if self.block_size > 0 and prev_output_tokens.shape[1] > self.block_size and incremental_state is None:  # pad to a multiple of half the block size
407 | activate_block = True
408 | src_length = prev_output_tokens.shape[1]
409 | pad_length = (src_length + self.half_block_size - 1) // self.half_block_size * self.half_block_size
410 | align_pad_length = pad_length - src_length
411 | if self_attn_padding_mask is None:
412 | self_attn_padding_mask = torch.zeros_like(prev_output_tokens)
413 |
414 | prev_output_tokens = F.pad(prev_output_tokens, (0, align_pad_length))
415 | self_attn_padding_mask = F.pad(self_attn_padding_mask, (self.half_block_size, align_pad_length), value=1).unfold(1, self.block_size, self.half_block_size).transpose(0, 1).reshape(-1, self.block_size)
416 | else:
417 | activate_block = False
418 | # embed tokens and positions
419 | x, _ = self.forward_embedding(
420 | prev_output_tokens, token_embeddings, incremental_state
421 | )
422 | x = x.transpose(0, 1)
423 |
424 | # relative position
425 | self_attn_rel_pos_bias = None
426 | slen = prev_output_tokens.size(1)
427 | if self.self_attn_xpos is not None:
428 | if activate_block:
429 | self_attn_rel_pos_bias = self.self_attn_xpos(self.block_size)
430 | else:
431 | rel_pos_len = slen if incremental_state is None else (incremental_state[0]["prev_key"].shape[2] + 1)
432 | self_attn_rel_pos_bias = self.self_attn_xpos(rel_pos_len)
433 | elif self.self_attn_relative_position is not None:
434 | self_attn_rel_pos_bias = self.self_attn_relative_position(
435 | batch_size=x.size(1), qlen=slen, klen=slen
436 | )
437 | if incremental_state is not None:
438 | self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
439 |
440 | cross_attn_rel_pos_bias = None
441 | if self.cross_attn_xpos is not None:
442 | cross_attn_rel_pos_bias = self.cross_attn_xpos(slen + encoder_out["encoder_out"].size(0))
443 | elif self.cross_attn_relative_position is not None:
444 | cross_attn_rel_pos_bias = self.cross_attn_relative_position(
445 | batch_size=x.size(1),
446 | qlen=slen,
447 | klen=encoder_out["encoder_out"].size(0),
448 | )
449 | if incremental_state is not None:
450 | cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]
451 |
452 | # decoder layers
453 | inner_states = [x]
454 |
455 | if encoder_out is None:
456 | l_aux = []
457 | else:
458 | l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
459 |
460 | for idx, layer in enumerate(self.layers):
461 | if incremental_state is None:
462 | if activate_block:
463 | self_attn_mask = torch.triu(
464 | torch.zeros([self.half_block_size, self.block_size])
465 | .float()
466 | .fill_(float("-inf"))
467 | .type_as(x),
468 | self.half_block_size + 1,
469 | )
470 | else:
471 | self_attn_mask = torch.triu(
472 | torch.zeros([x.size(0), x.size(0)])
473 | .float()
474 | .fill_(float("-inf"))
475 | .type_as(x),
476 | 1,
477 | )
478 | else:
479 | self_attn_mask = None
480 | if idx not in incremental_state:
481 | incremental_state[idx] = {}
482 |
483 | x, layer_attn, _, l_aux_i = layer(
484 | x,
485 | encoder_out["encoder_out"] if encoder_out is not None else None,
486 | encoder_out["encoder_padding_mask"]
487 | if encoder_out is not None
488 | else None,
489 | incremental_state[idx] if incremental_state is not None else None,
490 | self_attn_mask=self_attn_mask,
491 | self_attn_padding_mask=self_attn_padding_mask,
492 | self_attn_rel_pos=self_attn_rel_pos_bias,
493 | cross_attn_rel_pos=cross_attn_rel_pos_bias,
494 | )
495 | l_aux.append(l_aux_i)
496 | inner_states.append(x)
497 | if self.block_size > 0 and incremental_state is not None:
498 |             if incremental_state[idx]["prev_key"].shape[2] > self.block_size:  # window attention: keep only the last block_size cached keys/values
499 | incremental_state[idx]["prev_key"] = incremental_state[idx]["prev_key"][:, :, -self.block_size:]
500 | incremental_state[idx]["prev_value"] = incremental_state[idx]["prev_value"][:, :, -self.block_size:]
501 |
502 |
503 | if self.layer_norm is not None:
504 | x = self.layer_norm(x)
505 |
506 | x = x.transpose(0, 1)
507 | if activate_block:
508 | x = x[:, :src_length]
509 |
510 | if not features_only:
511 | x = self.output_layer(x)
512 |
513 | return x, {
514 | "inner_states": inner_states,
515 | "l_aux": l_aux,
516 | "attn": None,
517 | }
518 |
519 | def output_layer(self, features):
520 | return self.output_projection(features)
521 |
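A minimal decoder-only forward pass for the classes above, assuming the apex (FusedLayerNorm) and fairscale dependencies imported by decoder.py are installed; sizes follow the DecoderConfig defaults (embed_dim 768, block_size 2048).

import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
from torchscale.component.embedding import TextEmbedding

cfg = DecoderConfig(vocab_size=64000, decoder_layers=2)
decoder = Decoder(cfg, embed_tokens=TextEmbedding(cfg.vocab_size, cfg.decoder_embed_dim))

tokens = torch.randint(0, cfg.vocab_size, (2, 16))    # (batch, seq_len)
hidden, extra = decoder(tokens, features_only=True)   # hidden states, no output projection
print(hidden.shape)                                    # torch.Size([2, 16, 768])
print(len(extra["inner_states"]))                      # embedding plus one state per layer = 3

Because the 16-token input is shorter than block_size, the blockwise branch stays inactive; longer sequences are padded to a multiple of half the block size and attended to in overlapping windows.
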
--------------------------------------------------------------------------------
/torchscale/architecture/encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from apex.normalization import FusedLayerNorm as LayerNorm
10 | from fairscale.nn import checkpoint_wrapper, wrap
11 |
12 | from torchscale.architecture.utils import init_bert_params
13 | from torchscale.component.droppath import DropPath
14 | from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
15 | from torchscale.component.multihead_attention import MultiheadAttention
16 | from torchscale.component.multiway_network import MultiwayWrapper, set_split_position
17 | from torchscale.component.relative_position_bias import RelativePositionBias
18 | from torchscale.component.xmoe.moe_layer import MOELayer
19 | from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
20 |
21 |
22 | class EncoderLayer(nn.Module):
23 | def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
24 | super().__init__()
25 | self.args = args
26 | self.embed_dim = args.encoder_embed_dim
27 | self.self_attn = self.build_self_attention(self.embed_dim, args)
28 | self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
29 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
30 |
31 | if args.drop_path_rate > 0:
32 | drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
33 | depth
34 | ]
35 | self.drop_path = DropPath(drop_path_prob)
36 | else:
37 | self.drop_path = None
38 |
39 | self.normalize_before = args.encoder_normalize_before
40 | self.is_moe_layer = is_moe_layer
41 | self.ffn_dim = args.encoder_ffn_embed_dim
42 |
43 | if not self.is_moe_layer:
44 | self.ffn = MultiwayWrapper(
45 | args,
46 | self.build_ffn(
47 | self.embed_dim,
48 | self.args,
49 | ),
50 | )
51 | else:
52 | assert not self.args.multiway
53 | if args.moe_top1_expert:
54 | gate = Top1Gate(
55 | self.embed_dim,
56 | args.moe_expert_count,
57 | use_fp32=args.moe_gating_use_fp32,
58 | moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
59 | use_xmoe=args.use_xmoe,
60 | )
61 | else:
62 | gate = Top2Gate(
63 | self.embed_dim,
64 | args.moe_expert_count,
65 | args.moe_gating_use_fp32,
66 | args.moe_second_expert_policy,
67 | args.moe_normalize_gate_prob_before_dropping,
68 | args.moe_eval_capacity_token_fraction,
69 | use_xmoe=args.use_xmoe,
70 | )
71 | experts = make_experts(args, self.embed_dim, self.ffn_dim)
72 | self.moe_layer = MOELayer(gate, experts, args)
73 | self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
74 |
75 | if args.deepnorm:
76 | if is_encoder_decoder:
77 | self.alpha = (
78 | math.pow(
79 | math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
80 | )
81 | * 0.81
82 | )
83 | else:
84 | self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
85 | else:
86 | self.alpha = 1.0
87 |
88 | def build_ffn(self, embed_dim, args):
89 | return FeedForwardNetwork(
90 | embed_dim,
91 | self.ffn_dim,
92 | args.activation_fn,
93 | args.dropout,
94 | args.activation_dropout,
95 | args.subln,
96 | )
97 |
98 | def build_self_attention(self, embed_dim, args):
99 | return MultiheadAttention(
100 | args,
101 | embed_dim,
102 | args.encoder_attention_heads,
103 | dropout=args.attention_dropout,
104 | self_attention=True,
105 | encoder_decoder_attention=False,
106 | subln=args.subln,
107 | )
108 |
109 | def residual_connection(self, x, residual):
110 | return residual * self.alpha + x
111 |
112 | def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
113 | if attn_mask is not None:
114 | attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
115 |
116 | residual = x
117 | if self.normalize_before:
118 | x = self.self_attn_layer_norm(x)
119 | x, _ = self.self_attn(
120 | query=x,
121 | key=x,
122 | value=x,
123 | key_padding_mask=encoder_padding_mask,
124 | attn_mask=attn_mask,
125 | rel_pos=rel_pos,
126 | )
127 | x = self.dropout_module(x)
128 |
129 | if self.drop_path is not None:
130 | x = self.drop_path(x)
131 |
132 | x = self.residual_connection(x, residual)
133 | if not self.normalize_before:
134 | x = self.self_attn_layer_norm(x)
135 |
136 | residual = x
137 | if self.normalize_before:
138 | x = self.final_layer_norm(x)
139 | if not self.is_moe_layer:
140 | x = self.ffn(x)
141 | l_aux = None
142 | else:
143 | x = x.transpose(0, 1)
144 | x, l_aux = self.moe_layer(x)
145 | x = x.transpose(0, 1)
146 |
147 | if self.drop_path is not None:
148 | x = self.drop_path(x)
149 |
150 | x = self.residual_connection(x, residual)
151 | if not self.normalize_before:
152 | x = self.final_layer_norm(x)
153 | return x, l_aux
154 |
155 |
156 | class Encoder(nn.Module):
157 | def __init__(
158 | self,
159 | args,
160 | embed_tokens=None,
161 | embed_positions=None,
162 | output_projection=None,
163 | is_encoder_decoder=False,
164 | **kwargs
165 | ):
166 | self.args = args
167 | super().__init__(**kwargs)
168 |
169 | self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
170 |
171 | embed_dim = args.encoder_embed_dim
172 | self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
173 |
174 | self.embed_tokens = embed_tokens
175 | self.embed_positions = embed_positions
176 |
177 | if (
178 | output_projection is None
179 | and not is_encoder_decoder
180 | and not args.no_output_layer
181 | and args.vocab_size > 0
182 | ):
183 | self.output_projection = self.build_output_projection(args)
184 | else:
185 | self.output_projection = output_projection
186 |
187 | if args.layernorm_embedding:
188 | self.layernorm_embedding = MultiwayWrapper(
189 | args, LayerNorm(embed_dim), dim=1
190 | )
191 | else:
192 | self.layernorm_embedding = None
193 |
194 | self.layers = nn.ModuleList([])
195 |
196 | moe_freq = args.moe_freq
197 | for i in range(args.encoder_layers):
198 | is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
199 | self.layers.append(
200 | self.build_encoder_layer(
201 | args,
202 | depth=i,
203 | is_moe_layer=is_moe_layer,
204 | is_encoder_decoder=is_encoder_decoder,
205 | )
206 | )
207 | self.num_layers = len(self.layers)
208 |
209 | if args.encoder_normalize_before:
210 | self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
211 | else:
212 | self.layer_norm = None
213 |
214 | if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
215 | self.relative_position = RelativePositionBias(
216 | num_buckets=args.rel_pos_buckets,
217 | max_distance=args.max_rel_pos,
218 | n_heads=args.encoder_attention_heads,
219 | )
220 | else:
221 | self.relative_position = None
222 |
223 | if args.bert_init:
224 | self.apply(init_bert_params)
225 |
226 | if args.deepnorm:
227 | if is_encoder_decoder:
228 | init_scale = (
229 | math.pow(
230 | math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
231 | )
232 | / 1.15
233 | )
234 | else:
235 | init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
236 | for name, p in self.named_parameters():
237 | if (
238 | "fc1" in name
239 | or "fc2" in name
240 | or "out_proj" in name
241 | or "v_proj" in name
242 | ):
243 | p.data.div_(init_scale)
244 |
245 | if args.subln:
246 | if is_encoder_decoder:
247 | init_scale = math.sqrt(
248 | math.log(3 * args.decoder_layers)
249 | * math.log(2 * args.encoder_layers)
250 | / 3
251 | )
252 | else:
253 | init_scale = math.sqrt(math.log(args.encoder_layers * 2))
254 | for name, p in self.named_parameters():
255 | if (
256 | "fc1" in name
257 | or "fc2" in name
258 | or "out_proj" in name
259 | or "v_proj" in name
260 | ):
261 | p.data.mul_(init_scale)
262 |
263 | def build_output_projection(
264 | self,
265 | args,
266 | ):
267 | if args.share_encoder_input_output_embed:
268 | assert args.encoder_embedding_type == "language"
269 | output_projection = torch.nn.Linear(
270 | self.embed_tokens.weight.shape[1],
271 | self.embed_tokens.weight.shape[0],
272 | bias=False,
273 | )
274 | output_projection.weight = self.embed_tokens.weight
275 | else:
276 | output_projection = torch.nn.Linear(
277 | args.encoder_embed_dim, args.vocab_size, bias=False
278 | )
279 | torch.nn.init.normal_(
280 | output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5
281 | )
282 | return output_projection
283 |
284 | def build_encoder_layer(
285 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False
286 | ):
287 | layer = EncoderLayer(
288 | args,
289 | depth,
290 | is_moe_layer=is_moe_layer,
291 | is_encoder_decoder=is_encoder_decoder,
292 | )
293 | if args.checkpoint_activations:
294 | layer = checkpoint_wrapper(layer)
295 | if args.fsdp:
296 | layer = wrap(layer)
297 | return layer
298 |
299 | def forward_embedding(
300 | self,
301 | src_tokens,
302 | token_embedding=None,
303 | ):
304 | if token_embedding is None:
305 | token_embedding = self.embed_tokens(src_tokens)
306 | x = embed = self.embed_scale * token_embedding
307 | if self.embed_positions is not None:
308 | if src_tokens is not None:
309 | x = embed + self.embed_positions(src_tokens)
310 | else:
311 | x = embed + self.embed_positions(x)
312 | if self.layernorm_embedding is not None:
313 | x = self.layernorm_embedding(x)
314 | x = self.dropout_module(x)
315 | return x, embed
316 |
317 | def forward(
318 | self,
319 | src_tokens,
320 | encoder_padding_mask=None,
321 | return_all_hiddens=False,
322 | token_embeddings=None,
323 | multiway_split_position=None,
324 | features_only=False,
325 | **kwargs
326 | ):
327 | assert src_tokens is not None or token_embeddings is not None
328 |
329 | if encoder_padding_mask is None:
330 | if src_tokens is not None:
331 | encoder_padding_mask = torch.zeros_like(
332 | src_tokens, device=src_tokens.device
333 | ).bool()
334 | else:
335 | encoder_padding_mask = torch.zeros(
336 | [token_embeddings.size(0), token_embeddings.size(1)],
337 | device=token_embeddings.device,
338 | ).bool()
339 |
340 | if multiway_split_position is not None:
341 | assert self.args.multiway
342 | self.apply(set_split_position(multiway_split_position))
343 |
344 | x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
345 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
346 |
347 | x = x.transpose(0, 1)
348 |
349 | encoder_states = []
350 |
351 | if return_all_hiddens:
352 | encoder_states.append(x)
353 |
354 | rel_pos_bias = None
355 | if self.relative_position is not None:
356 | rel_pos_bias = self.relative_position(
357 | batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
358 | )
359 |
360 | l_aux = []
361 | for layer in self.layers:
362 | x, l_aux_i = layer(
363 | x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
364 | )
365 | if return_all_hiddens:
366 | assert encoder_states is not None
367 | encoder_states.append(x)
368 | l_aux.append(l_aux_i)
369 |
370 | if self.layer_norm is not None:
371 | x = self.layer_norm(x)
372 |
373 | if not features_only and self.output_projection is not None:
374 | x = self.output_projection(x)
375 |
376 | return {
377 | "encoder_out": x,
378 | "encoder_embedding": encoder_embedding,
379 | "encoder_padding_mask": encoder_padding_mask,
380 | "encoder_states": encoder_states,
381 | "l_aux": l_aux,
382 | }
383 |
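The encoder can run directly on pre-computed token embeddings (e.g., vision patches). A minimal sketch, assuming apex and fairscale are installed; EncoderConfig is defined at the top of config.py, and block_size is set explicitly because the attention module in this fork reads args.block_size.

import torch
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder

cfg = EncoderConfig(encoder_layers=2, no_output_layer=True)
cfg.block_size = 0   # 0 disables the blockwise attention path

encoder = Encoder(cfg)
embeddings = torch.randn(2, 10, cfg.encoder_embed_dim)   # (batch, seq_len, dim)
out = encoder(src_tokens=None, token_embeddings=embeddings)
print(out["encoder_out"].shape)   # torch.Size([10, 2, 768]) -- time-first, as in fairseq
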
--------------------------------------------------------------------------------
/torchscale/architecture/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch.nn as nn
5 |
6 | from torchscale.architecture.decoder import Decoder
7 | from torchscale.architecture.encoder import Encoder
8 |
9 |
10 | class EncoderDecoder(nn.Module):
11 | def __init__(
12 | self,
13 | args,
14 | encoder_embed_tokens=None,
15 | encoder_embed_positions=None,
16 | decoder_embed_tokens=None,
17 | decoder_embed_positions=None,
18 | output_projection=None,
19 | **kwargs
20 | ):
21 | super().__init__()
22 | self.args = args
23 | if args.share_all_embeddings:
24 | args.share_decoder_input_output_embed = True
25 |
26 | self.encoder = Encoder(
27 | args,
28 | encoder_embed_tokens,
29 | encoder_embed_positions,
30 | is_encoder_decoder=True,
31 | **kwargs
32 | )
33 |
34 | if args.share_all_embeddings and decoder_embed_tokens is None:
35 | decoder_embed_tokens = self.encoder.embed_tokens
36 |
37 | self.decoder = Decoder(
38 | args,
39 | decoder_embed_tokens,
40 | decoder_embed_positions,
41 | output_projection,
42 | is_encoder_decoder=True,
43 | **kwargs
44 | )
45 |
46 | def forward(
47 | self,
48 | src_tokens,
49 | prev_output_tokens,
50 | return_all_hiddens=False,
51 | features_only=False,
52 | **kwargs
53 | ):
54 | encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
55 | decoder_out = self.decoder(
56 | prev_output_tokens,
57 | encoder_out=encoder_out,
58 | features_only=features_only,
59 | return_all_hiddens=return_all_hiddens,
60 | )
61 | return decoder_out
62 |
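A sequence-to-sequence sketch wiring the two modules together, again assuming apex and fairscale are installed. EncoderDecoderConfig does not define block_size, so it is set explicitly before building the model.

import torch
from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
from torchscale.component.embedding import TextEmbedding

cfg = EncoderDecoderConfig(vocab_size=32000, encoder_layers=2, decoder_layers=2)
cfg.block_size = 0   # disable the blockwise attention path for this short example

model = EncoderDecoder(
    cfg,
    encoder_embed_tokens=TextEmbedding(cfg.vocab_size, cfg.encoder_embed_dim),
    decoder_embed_tokens=TextEmbedding(cfg.vocab_size, cfg.decoder_embed_dim),
)
src = torch.randint(0, cfg.vocab_size, (2, 12))    # source tokens
prev = torch.randint(0, cfg.vocab_size, (2, 7))    # shifted target tokens
logits, extra = model(src, prev)
print(logits.shape)   # torch.Size([2, 7, 32000])
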
--------------------------------------------------------------------------------
/torchscale/architecture/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch.nn as nn
5 |
6 | from torchscale.component.multihead_attention import MultiheadAttention
7 | from torchscale.component.multiway_network import MultiwayNetwork
8 |
9 |
10 | def init_bert_params(module):
11 | def normal_(data):
12 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
13 |
14 | if isinstance(module, nn.Linear):
15 | normal_(module.weight.data)
16 | if module.bias is not None:
17 | module.bias.data.zero_()
18 | if isinstance(module, nn.Embedding):
19 | normal_(module.weight.data)
20 | if module.padding_idx is not None:
21 | module.weight.data[module.padding_idx].zero_()
22 | if isinstance(module, MultiheadAttention):
23 | if isinstance(module.q_proj, MultiwayNetwork):
24 | normal_(module.q_proj.A.weight.data)
25 | normal_(module.q_proj.B.weight.data)
26 | normal_(module.k_proj.A.weight.data)
27 | normal_(module.k_proj.B.weight.data)
28 | normal_(module.v_proj.A.weight.data)
29 | normal_(module.v_proj.B.weight.data)
30 | else:
31 | normal_(module.q_proj.weight.data)
32 | normal_(module.k_proj.weight.data)
33 | normal_(module.v_proj.weight.data)
34 |
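A quick sanity check for the BERT-style initializer above; apex must be installed because utils.py imports MultiheadAttention at module level.

import torch.nn as nn
from torchscale.architecture.utils import init_bert_params

mlp = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 4))
mlp.apply(init_bert_params)                    # applied recursively to every submodule
print(round(mlp[0].weight.std().item(), 3))    # approximately 0.02
print(mlp[0].bias.abs().sum().item())          # 0.0
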
--------------------------------------------------------------------------------
/torchscale/component/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/component/droppath.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch.nn as nn
5 | from timm.models.layers import drop_path
6 |
7 |
8 | class DropPath(nn.Module):
9 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
10 |
11 | def __init__(self, drop_prob=None):
12 | super(DropPath, self).__init__()
13 | self.drop_prob = drop_prob
14 |
15 | def forward(self, x):
16 | return drop_path(x, self.drop_prob, self.training)
17 |
18 | def extra_repr(self):
19 | return "p={}".format(self.drop_prob)
20 |
--------------------------------------------------------------------------------
/torchscale/component/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class VisionLanguageEmbedding(nn.Module):
10 | def __init__(self, text_embed, vision_embed):
11 | super().__init__()
12 | self.text_embed = text_embed
13 | self.vision_embed = vision_embed
14 |
15 | def forward(self, textual_tokens, visual_tokens, **kwargs):
16 | if textual_tokens is None:
17 | return self.vision_embed(visual_tokens)
18 |
19 | if visual_tokens is None:
20 | return self.text_embed(textual_tokens)
21 |
22 | x1 = self.vision_embed(visual_tokens)
23 | x2 = self.text_embed(textual_tokens)
24 |
25 | return torch.cat([x1, x2], dim=1)
26 |
27 |
28 | class VisionEmbedding(nn.Module):
29 | """Image to Patch Embedding"""
30 |
31 | def __init__(
32 | self,
33 | img_size=224,
34 | patch_size=16,
35 | in_chans=3,
36 | embed_dim=768,
37 | contain_mask_token=False,
38 | prepend_cls_token=False,
39 | ):
40 | super().__init__()
41 | img_size = (img_size, img_size)
42 | patch_size = (patch_size, patch_size)
43 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
44 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
45 | self.img_size = img_size
46 | self.patch_size = patch_size
47 | self.num_patches = num_patches
48 |
49 | self.proj = nn.Conv2d(
50 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
51 | )
52 |
53 | if contain_mask_token:
54 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
55 | else:
56 | self.mask_token = None
57 |
58 | if prepend_cls_token:
59 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
60 | else:
61 | self.cls_token = None
62 |
63 | def forward(self, x, masked_position=None, **kwargs):
64 | B, C, H, W = x.shape
65 | assert (
66 | H == self.img_size[0] and W == self.img_size[1]
67 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
68 | x = self.proj(x).flatten(2).transpose(1, 2)
69 |
70 | batch_size, seq_len, _ = x.size()
71 |
72 | if masked_position is not None:
73 | assert self.mask_token is not None
74 | mask_token = self.mask_token.expand(batch_size, seq_len, -1)
75 | w = masked_position.unsqueeze(-1).type_as(mask_token)
76 | x = x * (1 - w) + mask_token * w
77 |
78 | if self.cls_token is not None:
79 | cls_tokens = self.cls_token.expand(
80 | batch_size, -1, -1
81 | ) # stole cls_tokens impl from Phil Wang, thanks
82 | x = torch.cat((cls_tokens, x), dim=1)
83 |
84 | return x
85 |
86 |
87 | class TextEmbedding(nn.Embedding):
88 | def reset_parameters(self):
89 | nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
90 | self._fill_padding_idx_with_zero()
91 |
92 |
93 | class PositionalEmbedding(nn.Embedding):
94 | def forward(
95 | self,
96 | x,
97 | positions=None,
98 | **kwargs,
99 | ):
100 | if positions is None:
101 |             # consistent with Fairseq, positions start from 2.
102 | positions = (
103 | torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
104 | )
105 | return F.embedding(
106 | positions,
107 | self.weight,
108 | self.padding_idx,
109 | self.max_norm,
110 | self.norm_type,
111 | self.scale_grad_by_freq,
112 | self.sparse,
113 | )
114 |
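A quick shape check for the patch embedding above (it only needs torch): a 224x224 image with 16x16 patches yields 14 * 14 = 196 patch tokens, plus one prepended CLS token.

import torch
from torchscale.component.embedding import VisionEmbedding

embed = VisionEmbedding(img_size=224, patch_size=16, in_chans=3,
                        embed_dim=768, prepend_cls_token=True)
images = torch.randn(2, 3, 224, 224)      # (batch, channels, height, width)
tokens = embed(images)
print(embed.num_patches, tokens.shape)    # 196 torch.Size([2, 197, 768])
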
--------------------------------------------------------------------------------
/torchscale/component/feedforward_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from apex.normalization import FusedLayerNorm as LayerNorm
8 |
9 |
10 | class set_torch_seed(object):
11 | def __init__(self, seed):
12 | assert isinstance(seed, int)
13 | self.rng_state = self.get_rng_state()
14 |
15 | torch.manual_seed(seed)
16 | if torch.cuda.is_available():
17 | torch.cuda.manual_seed(seed)
18 |
19 | def get_rng_state(self):
20 | state = {"torch_rng_state": torch.get_rng_state()}
21 | if torch.cuda.is_available():
22 | state["cuda_rng_state"] = torch.cuda.get_rng_state()
23 | return state
24 |
25 | def set_rng_state(self, state):
26 | torch.set_rng_state(state["torch_rng_state"])
27 | if torch.cuda.is_available():
28 | torch.cuda.set_rng_state(state["cuda_rng_state"])
29 |
30 | def __enter__(self):
31 | return self
32 |
33 | def __exit__(self, *exc):
34 | self.set_rng_state(self.rng_state)
35 |
36 |
37 | def make_experts(args, embed_dim, expert_ffn_dim):
38 | world_size = (
39 | 1
40 | if not torch.distributed.is_initialized()
41 | else torch.distributed.get_world_size()
42 | )
43 | expert_list = []
44 | ddp_rank = args.ddp_rank
45 | start_seed = torch.randint(1000000, (1,)).item()
46 |     # at least as many experts as GPUs
47 | if args.moe_expert_count >= world_size:
48 | assert (
49 | args.moe_expert_count % world_size == 0
50 | ), f"{args.moe_expert_count}, {world_size}"
51 | local_moe_expert_count = args.moe_expert_count // world_size
52 | for i in range(local_moe_expert_count):
53 | with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
54 | expert_list.append(
55 | FeedForwardNetwork(
56 | embed_dim,
57 | expert_ffn_dim,
58 | args.activation_fn,
59 | args.dropout,
60 | args.activation_dropout,
61 | args.subln,
62 | )
63 | )
64 | else:
65 | assert (
66 | world_size % args.moe_expert_count == 0
67 | ), f"{world_size}, {args.moe_expert_count}"
68 |
69 | with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
70 | expert_list.append(
71 | FeedForwardNetwork(
72 | embed_dim,
73 | expert_ffn_dim,
74 | args.activation_fn,
75 | args.dropout,
76 | args.activation_dropout,
77 | args.subln,
78 | )
79 | )
80 | experts = nn.ModuleList(expert_list)
81 | return experts
82 |
83 |
84 | def get_activation_fn(activation):
85 | if activation == "relu":
86 | return F.relu
87 | elif activation == "gelu":
88 | return F.gelu
89 | else:
90 | raise NotImplementedError
91 |
92 |
93 | class FeedForwardNetwork(nn.Module):
94 | def __init__(
95 | self,
96 | embed_dim,
97 | ffn_dim,
98 | activation_fn,
99 | dropout,
100 | activation_dropout,
101 | subln=False,
102 | ):
103 | super().__init__()
104 | self.embed_dim = embed_dim
105 | self.activation_fn = get_activation_fn(activation=str(activation_fn))
106 | self.activation_dropout_module = torch.nn.Dropout(
107 | activation_dropout, inplace=True
108 | )
109 | self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
110 | self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
111 | self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
112 | self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
113 |
114 | def reset_parameters(self):
115 | self.fc1.reset_parameters()
116 | self.fc2.reset_parameters()
117 | if self.ffn_layernorm is not None:
118 | self.ffn_layernorm.reset_parameters()
119 |
120 | def forward(self, x):
121 | x_shape = x.shape
122 | x = x.reshape(-1, x.size(-1))
123 | x = self.fc1(x)
124 | x = self.activation_fn(x.float()).type_as(x)
125 | x = self.activation_dropout_module(x)
126 | if self.ffn_layernorm is not None:
127 | x = self.ffn_layernorm(x)
128 | x = self.fc2(x)
129 | x = x.view(x_shape)
130 | x = self.dropout_module(x)
131 | return x
132 |
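A shape check for the feed-forward block above; make_experts builds the per-rank expert copies of this module for the MoE layers. apex is required because FusedLayerNorm is imported at module level; with subln=True an extra LayerNorm is applied to the hidden activations before the down-projection.

import torch
from torchscale.component.feedforward_network import FeedForwardNetwork

ffn = FeedForwardNetwork(embed_dim=768, ffn_dim=3072, activation_fn="gelu",
                         dropout=0.0, activation_dropout=0.0, subln=False)
x = torch.randn(4, 16, 768)   # any leading shape; only the last dim must equal embed_dim
print(ffn(x).shape)           # torch.Size([4, 16, 768])
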
--------------------------------------------------------------------------------
/torchscale/component/multihead_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from apex.normalization import FusedLayerNorm as LayerNorm
9 | from torch import nn
10 |
11 | from .multiway_network import MultiwayWrapper
12 |
13 | def rotate_every_two(x):
14 | x1 = x[:, :, ::2]
15 | x2 = x[:, :, 1::2]
16 | x = torch.stack((-x2, x1), dim=-1)
17 |     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
18 |
19 | def duplicate_interleave(m):
20 | """
21 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
22 | """
23 | dim0 = m.shape[0]
24 | m = m.view(-1, 1) # flatten the matrix
25 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
26 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
27 | return m
28 |
29 | def apply_rotary_pos_emb(x, sin, cos, scale=1):
30 | sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
31 | # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
32 | return (x * cos) + (rotate_every_two(x) * sin)
33 |
34 | class MultiheadAttention(nn.Module):
35 | def __init__(
36 | self,
37 | args,
38 | embed_dim,
39 | num_heads,
40 | dropout=0.0,
41 | self_attention=False,
42 | encoder_decoder_attention=False,
43 | subln=False,
44 | ):
45 | super().__init__()
46 | self.embed_dim = embed_dim
47 | self.num_heads = num_heads
48 | self.head_dim = embed_dim // num_heads
49 | self.scaling = self.head_dim**-0.5
50 |         self.block_size = getattr(args, "block_size", 0)  # encoder-side configs may not define block_size
51 | self.half_block_size = self.block_size // 2
52 |
53 | self.self_attention = self_attention
54 | self.encoder_decoder_attention = encoder_decoder_attention
55 | assert self.self_attention ^ self.encoder_decoder_attention
56 |
57 | self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
58 | self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
59 | self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
60 | self.out_proj = MultiwayWrapper(
61 | args, nn.Linear(embed_dim, embed_dim, bias=True)
62 | )
63 | self.inner_attn_ln = (
64 | MultiwayWrapper(args, LayerNorm(self.embed_dim))
65 | if subln and self.self_attention
66 | else None
67 | )
68 | self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
69 |
70 | def reset_parameters(self):
71 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
72 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
73 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
74 | nn.init.xavier_uniform_(self.out_proj.weight)
75 | nn.init.constant_(self.out_proj.bias, 0.0)
76 |
77 | def forward(
78 | self,
79 | query,
80 | key,
81 | value,
82 | incremental_state=None,
83 | key_padding_mask=None,
84 | attn_mask=None,
85 | rel_pos=None,
86 | ):
87 | tgt_len, bsz, embed_dim = query.size()
88 | src_len = tgt_len
89 | assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
90 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
91 |
92 | src_len, key_bsz, _ = key.size()
93 | assert key_bsz == bsz, f"{query.size(), key.size()}"
94 | assert value is not None
95 |         assert (src_len, bsz) == value.shape[:2]
96 |
97 | q = self.q_proj(query) # tgt_len, bsz, dim
98 | k = self.k_proj(key)
99 | v = self.v_proj(value)
100 | q *= self.scaling
101 | if self.block_size > 0 and tgt_len > self.block_size: # divide block
102 | assert tgt_len % self.half_block_size == 0
103 | if incremental_state is not None:
104 | incremental_state["prev_key"] = k.view(
105 | bsz, self.num_heads, -1, self.head_dim
106 | )
107 | incremental_state["prev_value"] = v.view(
108 | bsz, self.num_heads, -1, self.head_dim
109 | )
110 |
111 | q = q.view(-1, self.half_block_size, bsz * self.num_heads, self.head_dim).transpose(1, 2).reshape(-1, self.half_block_size, self.head_dim)
112 | k = F.pad(k, (0, 0, 0, 0, self.half_block_size, 0)).unfold(0, self.block_size, self.half_block_size).reshape(-1, self.head_dim, self.block_size).transpose(1, 2)
113 | v = F.pad(v, (0, 0, 0, 0, self.half_block_size, 0)).unfold(0, self.block_size, self.half_block_size).reshape(-1, self.head_dim, self.block_size).transpose(1, 2)
114 | bsz *= tgt_len // self.half_block_size
115 | tgt_len = self.half_block_size
116 | src_len = self.block_size
117 |
118 | else:
119 | q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
120 | k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
121 | v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
122 | if incremental_state is not None:
123 | if "prev_key" in incremental_state:
124 | prev_key = incremental_state["prev_key"].view(
125 | bsz * self.num_heads, -1, self.head_dim
126 | )
127 | prev_value = incremental_state["prev_value"].view(
128 | bsz * self.num_heads, -1, self.head_dim
129 | )
130 | k = torch.cat([prev_key, k], dim=1)
131 | v = torch.cat([prev_value, v], dim=1)
132 | incremental_state["prev_key"] = k.view(
133 | bsz, self.num_heads, -1, self.head_dim
134 | )
135 | incremental_state["prev_value"] = v.view(
136 | bsz, self.num_heads, -1, self.head_dim
137 | )
138 | src_len = k.size(1)
139 |
140 | if isinstance(rel_pos, tuple): # XPos implementation
141 | sin, cos, scale = rel_pos
142 | if self.self_attention:
143 | k = apply_rotary_pos_emb(k, sin, cos, scale = 1 / scale)
144 | q = apply_rotary_pos_emb(q, sin[-q.shape[1]:], cos[-q.shape[1]:], scale = scale[-q.shape[1]:])
145 | else:
146 | k = apply_rotary_pos_emb(k, sin[:k.shape[1]], cos[:k.shape[1]], scale = 1 / scale[:k.shape[1]])
147 | q = apply_rotary_pos_emb(q, sin[k.shape[1]:], cos[k.shape[1]:], scale = scale[k.shape[1]:])
148 |
149 | attn_weights = torch.bmm(q, k.transpose(1, 2))
150 | if attn_mask is not None:
151 | attn_weights = torch.nan_to_num(attn_weights)
152 | attn_mask = attn_mask.unsqueeze(0)
153 | attn_weights += attn_mask
154 |
155 | if key_padding_mask is not None:
156 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
157 | attn_weights = attn_weights.masked_fill(
158 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
159 | float("-inf"),
160 | )
161 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
162 |
163 | if isinstance(rel_pos, torch.Tensor):
164 | rel_pos = rel_pos.view(attn_weights.size())
165 | attn_weights = attn_weights + rel_pos
166 |
167 | attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
168 | attn_weights
169 | )
170 | attn_probs = self.dropout_module(attn_weights)
171 | attn = torch.bmm(attn_probs, v)
172 | if bsz > key_bsz: # merge block
173 | attn = attn.view(-1, key_bsz * self.num_heads, self.half_block_size, self.head_dim).transpose(1, 2).reshape(-1, key_bsz, embed_dim)
174 | else:
175 | attn = attn.transpose(0, 1).reshape(-1, bsz, embed_dim)
176 |
177 | if self.inner_attn_ln is not None:
178 | attn = self.inner_attn_ln(attn)
179 |
180 | attn = self.out_proj(attn)
181 | attn_weights = attn_weights.view(
182 | bsz, self.num_heads, tgt_len, src_len
183 | ).transpose(1, 0)
184 | return attn, attn_weights
185 |
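The blockwise branch above splits the queries into chunks of half_block_size and, via F.pad(...).unfold(...), gives each chunk an overlapping key/value window of block_size positions ending at that chunk. A toy illustration with position ids standing in for real keys (pure torch, hypothetical sizes):

import torch
import torch.nn.functional as F

block_size, half = 4, 2
k = torch.arange(1, 9, dtype=torch.float).view(8, 1, 1)   # positions 1..8 as (seq, bsz, dim)
windows = F.pad(k, (0, 0, 0, 0, half, 0)).unfold(0, block_size, half)
print(windows.squeeze())
# tensor([[0., 0., 1., 2.],   <- key window for query chunk (1, 2)
#         [1., 2., 3., 4.],   <- key window for query chunk (3, 4)
#         [3., 4., 5., 6.],   <- key window for query chunk (5, 6)
#         [5., 6., 7., 8.]])  <- key window for query chunk (7, 8)

Combined with the triu(..., half_block_size + 1) mask built in decoder.py, each query attends causally to at most block_size positions (itself and the window preceding it), which keeps the attention cost linear in the sequence length for a fixed block_size.
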
--------------------------------------------------------------------------------
/torchscale/component/multiway_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import copy
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def MultiwayWrapper(args, module, dim=0):
11 | if args.multiway:
12 | return MultiwayNetwork(module, dim=dim)
13 | return module
14 |
15 |
16 | def set_split_position(position):
17 | def apply_fn(module):
18 | if hasattr(module, "split_position"):
19 | module.split_position = position
20 |
21 | return apply_fn
22 |
23 |
24 | class MultiwayNetwork(nn.Module):
25 | def __init__(self, module, dim=0):
26 | super().__init__()
27 | self.dim = dim
28 | self.A = module
29 | self.B = copy.deepcopy(module)
30 | self.B.reset_parameters()
31 | self.split_position = -1
32 |
33 | def forward(self, x, **kwargs):
34 | if self.split_position == -1:
35 | return self.A(x, **kwargs)
36 | if self.split_position == 0:
37 | return self.B(x, **kwargs)
38 | x1, x2 = torch.split(
39 | x,
40 | [self.split_position, x.size(self.dim) - self.split_position],
41 | dim=self.dim,
42 | )
43 | # x1, x2 = x[:self.split_position], x[self.split_position:]
44 | y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
45 | return torch.cat([y1, y2], dim=self.dim)
46 |
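A minimal routing sketch for the multiway (modality-expert) wrapper above; it only needs torch. Positions before split_position go through branch A and the rest through branch B, as used to give vision and language tokens separate parameters.

import torch
import torch.nn as nn
from torchscale.component.multiway_network import MultiwayNetwork, set_split_position

layer = MultiwayNetwork(nn.Linear(4, 4), dim=1)   # split along the sequence dimension
layer.apply(set_split_position(3))                # first 3 positions -> A, the rest -> B

x = torch.randn(2, 5, 4)                          # (batch, seq_len, dim)
print(layer(x).shape)                             # torch.Size([2, 5, 4])
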
--------------------------------------------------------------------------------
/torchscale/component/relative_position_bias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class RelativePositionBias(nn.Module):
11 | def __init__(
12 | self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
13 | ):
14 | super().__init__()
15 | self.bidirectional = bidirectional
16 | self.num_buckets = num_buckets
17 | self.max_distance = max_distance
18 | self.n_heads = n_heads
19 | self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
20 |
21 | @staticmethod
22 | def _relative_position_bucket(
23 | relative_position, bidirectional=True, num_buckets=32, max_distance=128
24 | ):
25 | ret = 0
26 | n = -relative_position
27 | if bidirectional:
28 | num_buckets //= 2
29 | ret += (n < 0).to(torch.long) * num_buckets
30 | n = torch.abs(n)
31 | else:
32 | n = torch.max(n, torch.zeros_like(n))
33 |
34 | max_exact = num_buckets // 2
35 | is_small = n < max_exact
36 |
37 | val_if_large = max_exact + (
38 | torch.log(n.float() / max_exact)
39 | / math.log(max_distance / max_exact)
40 | * (num_buckets - max_exact)
41 | ).to(torch.long)
42 | val_if_large = torch.min(
43 | val_if_large, torch.full_like(val_if_large, num_buckets - 1)
44 | )
45 |
46 | ret += torch.where(is_small, n, val_if_large)
47 | return ret
48 |
49 | def compute_bias(self, qlen, klen, step=None):
50 | step = 0 if step is None else step
51 | context_position = torch.arange(
52 | step,
53 | step + qlen,
54 | dtype=torch.long,
55 | device=self.relative_attention_bias.weight.device,
56 | )[:, None]
57 | memory_position = torch.arange(
58 | klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
59 | )[None, :]
60 | relative_position = memory_position - context_position # shape (qlen, klen)
61 |
62 | rp_bucket = self._relative_position_bucket(
63 | relative_position, # shape (qlen, klen)
64 | bidirectional=self.bidirectional,
65 | num_buckets=self.num_buckets,
66 | )
67 | rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
68 | values = self.relative_attention_bias(
69 | rp_bucket
70 | ) # shape (qlen, klen, num_heads)
71 | values = values.permute([2, 0, 1]).unsqueeze(
72 | 0
73 | ) # shape (1, num_heads, qlen, klen)
74 | return values
75 |
76 | def forward(self, batch_size, qlen, klen, step=None):
77 | # shape (batch * num_heads, qlen, klen)
78 | return (
79 | self.compute_bias(qlen, klen, step)
80 | .repeat(batch_size, 1, 1, 1)
81 | .view(-1, qlen, klen)
82 | )
83 |
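A shape check for the T5-style relative position bias above; the flattened (batch * n_heads, qlen, klen) output matches the attention-logit layout used in multihead_attention.py.

from torchscale.component.relative_position_bias import RelativePositionBias

bias = RelativePositionBias(bidirectional=True, num_buckets=32,
                            max_distance=128, n_heads=8)
print(bias(batch_size=2, qlen=5, klen=5).shape)   # torch.Size([16, 5, 5])
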
--------------------------------------------------------------------------------
/torchscale/component/xmoe/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/component/xmoe/moe_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
5 | #
6 | # This source code is licensed under the BSD license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # NOTE: This is a mirror of the code in
10 | # https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
11 |
12 | import logging
13 | import time
14 | from typing import Any, Tuple, cast
15 |
16 | import torch
17 | import torch.distributed as dist
18 | from torch import Tensor
19 | from torch.nn import Module, ModuleList
20 |
21 | try:
22 | from fairseq.modules.moe import MOELayer
23 |
24 | has_fairseq = True
25 | Base = MOELayer
26 | except ModuleNotFoundError:
27 | Base = Module
28 | has_fairseq = False
29 |
30 | try:
31 | # To enable Tutel MoE optimizations:
32 | # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
33 | from tutel import moe as tutel_moe
34 |
35 | has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
36 | except ModuleNotFoundError:
37 | has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1
38 |
39 | logger = logging.getLogger(__name__)
40 |
41 |
42 | # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
43 | # See https://arxiv.org/pdf/2006.16668.pdf for details.
44 |
45 | # Based on https://github.com/pytorch/pytorch/pull/40762
46 | class _AllToAll(torch.autograd.Function):
47 | @staticmethod
48 | def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
49 | ctx.group = group
50 | input = input.contiguous()
51 | output = torch.empty_like(input)
52 | if torch.distributed.is_initialized():
53 | dist.all_to_all_single(output, input, group=group)
54 | else:
55 | assert group is None
56 | output = input
57 | return output
58 |
59 | @staticmethod
60 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
61 | return (None, _AllToAll.apply(ctx.group, *grad_output))
62 |
63 |
64 | def _find_my_group_index(grouped_ranks):
65 | my_rank = dist.get_rank()
66 | for i, group in enumerate(grouped_ranks):
67 | if my_rank in group:
68 | return i
69 | raise RuntimeError
70 |
71 |
72 | def get_moe_group(moe_expert_count):
73 | if dist.is_initialized():
74 | if not hasattr(get_moe_group, "_moe_groups"):
75 | world_size = dist.get_world_size()
76 |
77 | if world_size <= moe_expert_count:
78 | assert moe_expert_count % world_size == 0
79 | moe_groups = [[i] for i in range(world_size)]
80 |
81 | else:
82 | assert world_size % moe_expert_count == 0
83 | ranks_per_group = world_size // moe_expert_count
84 | moe_groups = [
85 | [i + j * moe_expert_count for j in range(ranks_per_group)]
86 | for i in range(moe_expert_count)
87 | ]
88 |
89 | get_moe_group._moe_group_idx = moe_groups
90 | get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
91 |
92 | my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
93 | return get_moe_group._moe_groups[my_group_idx]
94 |
95 |
96 | def get_all2all_group(moe_expert_count):
97 | if dist.is_initialized():
98 | if not hasattr(get_all2all_group, "_all2all_groups"):
99 | world_size = dist.get_world_size()
100 |
101 | # more experts than world size
102 | if world_size <= moe_expert_count:
103 | assert moe_expert_count % world_size == 0
104 | all2all_groups = [[i for i in range(world_size)]]
105 |
106 | # larger world than num experts
107 | else:
108 | assert world_size % moe_expert_count == 0
109 | ranks_per_group = world_size // moe_expert_count
110 | all2all_groups = [
111 | [i * moe_expert_count + j for j in range(moe_expert_count)]
112 | for i in range(ranks_per_group)
113 | ]
114 |
115 | get_all2all_group._all2all_group_idx = all2all_groups
116 | get_all2all_group._all2all_groups = [
117 | dist.new_group(g) for g in all2all_groups
118 | ]
119 |
120 | my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
121 | return get_all2all_group._all2all_groups[my_group_idx]
122 |
123 |
124 | class MOELayer(Base):
125 | """MOELayer module which implements MixtureOfExperts as described in Gshard_.
126 | ::
127 |
128 | gate = Top2Gate(model_dim, num_experts)
129 |         moe = MOELayer(gate, experts, args)
130 |         output, l_aux = moe(input)
131 |         # l_aux is the gate's load-balancing auxiliary loss
132 |
133 | .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
134 |
135 | Args:
136 | gate (torch.nn.Module):
137 | gate network
138 |         experts (torch.nn.Module or torch.nn.ModuleList):
139 |             expert network(s)
140 | """
141 |
142 | def __init__(self, gate, experts, args):
143 | if has_fairseq:
144 | super(Base, self).__init__()
145 | else:
146 | super().__init__()
147 | self.gate = gate
148 | if type(experts) == ModuleList:
149 | self.experts = cast(ModuleList, experts)
150 | else:
151 | self.experts = ModuleList([experts])
152 | self.expert_group = get_moe_group(args.moe_expert_count)
153 | self.all2all_group = get_all2all_group(args.moe_expert_count)
154 | self.world_size = dist.get_world_size(group=self.expert_group)
155 | self.all2all_size = dist.get_world_size(group=self.all2all_group)
156 | for p in experts.parameters():
157 | p.expert = True # type: ignore
158 | self.num_local_experts = len(self.experts)
159 | self.args = args
160 | self.in_generation = False
161 | self.a2a_cuda_event_intervals = []
162 | self.a2a_cpu_time_ms = 0.0
163 |
164 | def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
165 | assert len(input) == 1, "only single input Tensor supported"
166 | input = input[0]
167 | assert (
168 | len(input.shape) == 3
169 | ), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
170 | if input_padding_mask is not None:
171 | assert (
172 | len(input_padding_mask.shape) == 2
173 | ), "input Tensor must have dimensions: (s)equence, (t)oken"
174 | assert input_padding_mask.shape[0] == input.shape[0]
175 | assert input_padding_mask.shape[1] == input.shape[1]
176 | # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
177 |
178 | # Implement Algorithm 2 from GShard paper.
179 | d_model = input.shape[2]
180 | # Pad to expected batch size
181 | input_shape = list(input.shape)
182 | expected_bsz = (
183 | getattr(self.args, "batch_size", 0)
184 | if self.training
185 | else getattr(self.args, "batch_size_valid", 0)
186 | )
187 | # This indicates that --batch-size or --max-sentences is not specified
188 | if expected_bsz is None:
189 | expected_bsz = 0
190 | # Note: Padding is not necessary at generation time at present
191 | # because all DDP workers process the same batch. Also, batch size at generation time
192 | # can be different from that present in the checkpoint state
193 | if (
194 | not self.in_generation
195 | and expected_bsz != 0
196 | and input_shape[0] != expected_bsz
197 | ):
198 | logger.warning(
199 | f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})"
200 | )
201 | assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
202 | padded_input = torch.zeros(
203 | (expected_bsz, input_shape[1], input_shape[2]),
204 | dtype=input.dtype,
205 | layout=input.layout,
206 | device=input.device,
207 | )
208 | padded_input[: input_shape[0], :, :] = input
209 | input = padded_input
210 |
211 | padded_input_padding_mask = torch.ones(
212 | (
213 | expected_bsz,
214 | input_shape[1],
215 | ),
216 | dtype=torch.bool,
217 | device=input.device,
218 | )
219 | if input_padding_mask is not None:
220 | padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
221 | else:
222 | padded_input_padding_mask[: input_shape[0], :] = False
223 | input_padding_mask = padded_input_padding_mask
224 |
225 | # Reshape into S tokens by dropping sequence dimension.
226 | reshaped_input = input.reshape(-1, d_model)
227 | reshaped_input_shape = reshaped_input.shape
228 | reshaped_input_padding_mask = (
229 | input_padding_mask.reshape(-1) if input_padding_mask is not None else None
230 | )
231 |
232 | # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
233 | # Pro of --max-tokens: more flexible for MT variable sequence lengths
234 | # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
235 | if expected_bsz == 0:
236 | expected_dim = reshaped_input_shape[0] * torch.ones(
237 | (1,), dtype=torch.long, device=input.device
238 | )
239 | dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
240 | expected_dim = int(expected_dim.item())
241 | padded_input = torch.zeros(
242 | (expected_dim, reshaped_input_shape[1]),
243 | dtype=input.dtype,
244 | layout=input.layout,
245 | device=input.device,
246 | )
247 | padded_input[: reshaped_input_shape[0], :] = reshaped_input
248 | reshaped_input = padded_input
249 |
250 | padded_input_padding_mask = torch.ones(
251 | (expected_dim,), dtype=torch.bool, device=padded_input.device
252 | )
253 | if reshaped_input_padding_mask is not None:
254 | padded_input_padding_mask[
255 | : reshaped_input_shape[0]
256 | ] = reshaped_input_padding_mask
257 | else:
258 | padded_input_padding_mask[: reshaped_input_shape[0]] = False
259 | reshaped_input_padding_mask = padded_input_padding_mask
260 |
261 | if has_tutel:
262 | l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
263 | reshaped_input, reshaped_input_padding_mask
264 | )
265 | S, M = reshaped_input.size(0), reshaped_input.size(1)
266 |
267 | if not hasattr(self, "_tutel_dispatcher"):
268 | self._tutel_dispatcher = tutel_moe.fast_dispatcher(
269 | E, C, M, dispatch_dtype=reshaped_input.dtype
270 | )
271 | self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
272 | dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
273 | else:
274 | l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
275 | reshaped_input, reshaped_input_padding_mask
276 | )
277 |
278 | dispatch_mask = dispatch_mask.to(input.dtype).permute(
279 | 1, 2, 0
280 | ) # S,E,C -> E,C,S
281 | E, C, S = dispatch_mask.size()
282 | M = reshaped_input.size(1)
283 | assert reshaped_input.size() == (S, M)
284 | # einsum("sec,sm->ecm")
285 | dispatched_input = torch.mm(
286 | dispatch_mask.view(E * C, S), reshaped_input
287 | ) # -> (E*C),M
288 |
289 | if self.all2all_size > 1:
290 | dispatched_input = self.all_to_all_wrapper(dispatched_input)
291 |
292 | # Re-shape after all-to-all: ecm -> gecm
293 | dispatched_input = dispatched_input.reshape(
294 | self.all2all_size, self.num_local_experts, -1, d_model
295 | )
296 | chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
297 | expert_outputs = []
298 | for chunk, expert in zip(chunks, self.experts):
299 | expert_outputs += [expert(chunk)]
300 | expert_output = torch.cat(expert_outputs, dim=1)
301 |
302 | if self.all2all_size > 1:
303 | expert_output = self.all_to_all_wrapper(expert_output)
304 |
305 | # Re-shape back: gecm -> ecm
306 | expert_output = expert_output.reshape(
307 | self.all2all_size * self.num_local_experts, -1, d_model
308 | )
309 |
310 | if has_tutel:
311 | combined_output = self._tutel_dispatcher.decode(
312 | expert_output.view(E * C, M)
313 | )
314 | else:
315 | # einsum("sec,ecm->sm")
316 | combined_output = combine_weights.view(S, E * C).mm(
317 | expert_output.view(E * C, M)
318 | )
319 |
320 | # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
321 | combined_output = combined_output[: reshaped_input_shape[0], :]
322 | combined_output = combined_output.reshape(input.shape)
323 | combined_output = combined_output[: input_shape[0], :, :]
324 |
325 | self.record_all_to_all_stats()
326 |
327 | return combined_output, l_aux
328 |
329 | def prepare_for_inference_(self):
330 | self.in_generation = True
331 |
332 | def all_to_all_wrapper(self, input: Tensor):
333 | dummy_a2a = getattr(self.args, "dummy_a2a", False)
334 | if dummy_a2a:
335 | input = input.contiguous()
336 | output = input.detach().clone()
337 | return input
338 | # always record times, since it is not a lot of overhead
339 | # if we do not log it we simply clear it off in record_all_to_all_stats
340 | cuda_start = torch.cuda.Event(enable_timing=True)
341 | cuda_end = torch.cuda.Event(enable_timing=True)
342 | cpu_start = time.time() * 1000
343 | cuda_start.record()
344 | output = _AllToAll.apply(self.all2all_group, input)
345 | cuda_end.record()
346 | cpu_end = time.time() * 1000
347 | self.a2a_cpu_time_ms += cpu_end - cpu_start
348 | self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
349 | return output
350 |
351 | def record_all_to_all_stats(self):
352 | # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
353 | record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
354 | if record_a2a_perf_stats:
355 | torch.cuda.synchronize()
356 | self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
357 | a2a_cuda_time_ms = 0.0
358 | for ev_start, ev_end in self.a2a_cuda_event_intervals:
359 | a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
360 | self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
361 | # reset stats
362 | self.a2a_cpu_time_ms = 0.0
363 | self.a2a_cuda_event_intervals = []
364 |
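365 | # A toy, single-process sketch of the dense dispatch/combine math used above
366 | # (illustrative only; real usage builds the layer with a gate, experts and args,
367 | # and routes through the all-to-all path). It mirrors einsum("sec,sm->ecm") for
368 | # dispatch and einsum("sec,ecm->sm") for combine:
369 | #
370 | #     S, E, C, M = 8, 2, 4, 16                  # tokens, experts, capacity, model dim
371 | #     tokens = torch.randn(S, M)
372 | #     dispatch_mask = torch.zeros(S, E, C)      # normally produced by the gate
373 | #     combine_weights = torch.zeros(S, E, C)    # normally produced by the gate
374 | #     dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
375 | #     combined = torch.einsum("sec,ecm->sm", combine_weights, dispatched)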
--------------------------------------------------------------------------------
/torchscale/component/xmoe/routing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
5 | #
6 | # This source code is licensed under the BSD license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf
10 | # Code is inspired by Top2GatingOnLogits from lingvo:
11 | # https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
12 |
13 | # NOTE: This is a mirror of the code in
14 | # https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
15 |
16 | import math
17 | from typing import Callable, Dict, Optional, Tuple
18 |
19 | import torch
20 | import torch.nn.functional as F
21 | from torch import Tensor
22 |
23 | from .moe_layer import fused_cumsum_sub_one, has_tutel
24 |
25 | # use a fixed temperature to compute balance loss
26 | TEMPERATURE_FOR_L_UAX = 0.07
27 |
28 | # maximum capacity of 1 expert as a fraction of number of tokens in the batch
29 | # Note: setting this to 1.0 causes inference to significantly slow down
30 | EVAL_CAPACITY_TOKEN_FRACTION = 0.25
31 |
32 | # logging
33 | SAMPLE_FRACTION = 0.2
34 |
35 |
36 | def top1gating(
37 | logits: torch.Tensor,
38 | input_mask: Optional[torch.Tensor] = None,
39 | use_fp32=False,
40 | capacity_factor=1.0,
41 | eval_mode=False,
42 | moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
43 | use_xmoe=False,
44 | gate_obj=None,
45 | ) -> Tuple[Tensor, Tensor, Tensor, Dict]:
46 |     """Implements Top1Gating on logits."""
47 | metadata = {}
48 | if use_fp32:
49 | orig_dtype = logits.dtype
50 | logits = logits.float()
51 |
52 | gates = F.softmax(logits, dim=1)
53 | metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
54 |
55 | # gates has shape of SE
56 | num_tokens = gates.shape[0]
57 | num_experts = gates.shape[1]
58 | if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
59 | capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
60 | else:
61 | # capacity = capacity_factor * S/E
62 | capacity = int(capacity_factor * math.ceil(num_tokens / num_experts))
63 |
64 |     # Create a mask for the 1st expert per token
65 | indices1_s = torch.argmax(gates, dim=1)
66 | mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
67 | if input_mask is not None and input_mask.any():
68 | nonpadding = ~input_mask
69 | mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
70 |
71 | # for logging (percent of tokens routed to each expert)
72 | expert1_hist = (
73 | 100
74 | * torch.histc(
75 | (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
76 | )
77 | / num_tokens
78 | )
79 | metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
80 | expert1_hist = (
81 | torch.sort(expert1_hist, dim=0, descending=True).values
82 | + torch.finfo(torch.float32).tiny
83 | )
84 |
85 | sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
86 | metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
87 | metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
88 |
89 | gates1_s = (gates * mask1).sum(dim=1)
90 |
91 | # Compute locations in capacity buffer
92 | locations1 = fused_cumsum_sub_one(mask1)
93 |
94 | # Compute l_aux
95 | me = torch.mean(gates, dim=0)
96 | ce = torch.mean(mask1.to(gates.dtype), dim=0)
97 |
98 | l_aux = torch.mean(me * ce)
99 | l_aux = l_aux * num_experts * num_experts
100 |
101 | if has_tutel:
102 | locations1_s = torch.sum(locations1 * mask1, dim=1)
103 | return (
104 | l_aux,
105 | metadata,
106 | capacity,
107 | num_experts,
108 | [
109 | indices1_s,
110 | ],
111 | [
112 | locations1_s,
113 | ],
114 | [
115 | gates1_s,
116 | ],
117 | )
118 |
119 | # Remove locations outside capacity from mask
120 | mask1 = mask1 * torch.lt(locations1, capacity)
121 | # Store the capacity location for each token
122 | locations1_s = torch.sum(locations1 * mask1, dim=1)
123 |
124 | # Calculate combine_weights and dispatch_mask
125 | gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
126 |     # locations1_sc: one-hot over capacity slots, shape (num_tokens, capacity)
127 | locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
128 | combine1_sec = torch.bmm(
129 | # einsum("se,sc->sec")
130 | gates1.unsqueeze(-1),
131 | locations1_sc.to(gates1.dtype).unsqueeze(1),
132 | )
133 | dispatch_mask = combine1_sec.bool()
134 | if use_fp32:
135 | return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata
136 | else:
137 | return l_aux, combine1_sec, dispatch_mask, metadata
138 |
139 |
140 | class Top1Gate(torch.nn.Module):
141 |     """Gate module which implements Top1Gating as described in Gshard_.
142 |     ::
143 |
144 |         gate = Top1Gate(model_dim, num_experts)
145 |         l_aux, combine_weights, dispatch_mask, metadata = gate(input)
146 |
147 | .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
148 |
149 | Args:
150 | model_dim (int):
151 | size of model embedding dimension
152 |         num_experts (int):
153 | number of experts in model
154 | """
155 |
156 | wg: torch.nn.Linear
157 |
158 | def __init__(
159 | self,
160 | model_dim: int,
161 | num_experts: int,
162 | use_fp32=False,
163 | input_noise_type=None,
164 | capacity_factor=1.0,
165 | moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
166 | use_xmoe=False,
167 | ) -> None:
168 | # TODO: merge this to top2gate.py
169 | #
170 | super().__init__()
171 |
172 | if not use_xmoe:
173 | self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
174 | else:
175 | self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
176 | wg = torch.empty(num_experts, 16)
177 | torch.nn.init.orthogonal_(wg, gain=0.32)
178 | self.register_parameter("wg", torch.nn.Parameter(wg))
179 |
180 | self.use_xmoe = use_xmoe
181 | self.use_fp32 = use_fp32
182 | self.input_noise_type = input_noise_type
183 | self.capacity_factor = capacity_factor
184 | self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
185 |
186 | def forward(self, input, mask=None): # type: ignore
187 | if self.use_xmoe:
188 | input = self.wg_reduction(input)
189 | with torch.no_grad():
190 | wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
191 | self.wg.mul_(1.5 / wg_norm)
192 | logits = self._cosine(input, self.wg)
193 | logits = self._make_finite(logits)
194 | else:
195 | logits = self.wg(input)
196 |
197 | return top1gating(
198 | logits,
199 | mask,
200 | use_fp32=self.use_fp32,
201 | capacity_factor=self.capacity_factor,
202 | eval_mode=not self.training,
203 | moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
204 | use_xmoe=self.use_xmoe,
205 | gate_obj=self,
206 | )
207 |
208 | def _make_finite(self, scores):
209 | ok = scores.isfinite()
210 | if not ok.all():
211 | # NaNs here can break the assignment algorithm
212 | scores[~ok] = scores[ok].min()
213 | return scores
214 |
215 | def _get_gating_temperature(self, eps=1e-4):
216 | if self.gating_t.data.item() < eps:
217 | return eps
218 | return self.gating_t
219 |
220 | def _cosine(self, mat1, mat2, eps=1e-4):
221 | assert mat1.dim() == 2
222 | assert mat2.dim() == 2
223 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
224 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
225 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
226 |
227 |
228 | gumbel_map: Dict[torch.device, Callable] = {}
229 |
230 |
231 | def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
232 | gumbel = gumbel_map.get(device)
233 | if gumbel is None:
234 | one = torch.tensor(1.0, device=device)
235 | zero = torch.tensor(0.0, device=device)
236 | gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
237 | gumbel_map[device] = gumbel
238 | return gumbel(shape)
239 |
240 |
241 | def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor:
242 | if unsqueeze_indices:
243 | indices = indices.unsqueeze(-1)
244 |     assert indices.shape[-1] == 1, "last dimension of indices must have size 1"
245 | output = torch.zeros(
246 | indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
247 | )
248 | output.scatter_(len(output.shape) - 1, indices, 1)
249 | return output
250 |
251 |
252 | def entropy(probs):
253 | logits = torch.distributions.utils.probs_to_logits(probs)
254 | p_log_p = probs * logits
255 | return -p_log_p.sum(-1)
256 |
257 |
258 | def top2gating(
259 | logits: torch.Tensor,
260 | input_mask: Optional[torch.Tensor] = None,
261 | use_fp32=False,
262 | second_expert_policy="sampling",
263 | normalize_gate_prob_before_dropping=False,
264 | eval_mode=False,
265 | moe_eval_capacity_token_fraction=0.25,
266 | batch_prioritized_routing=False,
267 | ) -> Tuple[Tensor, Tensor, Tensor, Dict]:
268 | """Implements Top2Gating on logits."""
269 | metadata = {}
270 | if use_fp32:
271 | orig_dtype = logits.dtype
272 | logits = logits.float()
273 | gates = F.softmax(logits, dim=1)
274 | metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
275 | # gates has shape of SE
276 | num_tokens = gates.shape[0]
277 | num_experts = gates.shape[1]
278 | if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
279 | capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
280 | else:
281 | # capacity = 2S/E
282 | capacity = 2 * math.ceil(num_tokens / num_experts)
283 |
284 |     # Create a mask for the 1st expert per token
285 | indices1_s = torch.argmax(gates, dim=1, keepdim=True)
286 | mask1 = one_hot(indices1_s, num_experts)
287 | if second_expert_policy == "sampling":
288 |         # Create a mask for the 2nd expert per token using the Gumbel-max trick
289 | # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
290 | logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
291 | else:
292 | logits_w_noise = logits
293 | # Replace top-expert with min value
294 | logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
295 | indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True)
296 | mask2 = one_hot(indices2_s, num_experts)
297 | gates1_s = (gates * mask1).sum(dim=1)
298 | gates2_s = (gates * mask2).sum(dim=1)
299 |
300 | if normalize_gate_prob_before_dropping:
301 | # Normalize gate probabilities
302 | denom_s = gates1_s + gates2_s
303 | # Avoid divide-by-zero
304 | denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
305 | gates1_s = gates1_s / denom_s
306 | gates2_s = gates2_s / denom_s
307 |
308 | if second_expert_policy == "random":
309 | sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
310 | mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
311 |
312 | # Compute locations in capacity buffer
313 | if input_mask is not None and input_mask.any():
314 | nonpadding = ~input_mask
315 | mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
316 | mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
317 |
318 | if batch_prioritized_routing:
319 |         # sort tokens by gate confidence so high-confidence tokens claim capacity first
320 | importance_scores = -1 * gates.max(dim=1)[0]
321 | sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
322 | sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
323 | importance_sorted_locations1 = sorted_cumsum1[
324 | importance_scores.argsort(dim=0).argsort(dim=0)
325 | ]
326 |
327 | sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
328 | sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
329 | importance_sorted_locations2 = sorted_cumsum2[
330 | importance_scores.argsort(dim=0).argsort(dim=0)
331 | ]
332 |
333 | importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
334 |
335 | locations1, locations2 = (
336 | importance_sorted_locations1,
337 | importance_sorted_locations2,
338 | )
339 | else:
340 | locations1 = fused_cumsum_sub_one(mask1)
341 | locations2 = fused_cumsum_sub_one(mask2)
342 | # Update 2nd's location by accounting for locations of 1st
343 | locations2 += torch.sum(mask1, dim=0, keepdim=True)
344 |
345 | # Compute l_aux
346 | me = torch.mean(gates, dim=0)
347 | ce = torch.mean(mask1.to(gates.dtype), dim=0)
348 | l_aux = torch.mean(me * ce)
349 | l_aux = l_aux * num_experts * num_experts
350 |
351 | # for logging purposes
352 | metadata["overflow_expert1"] = (
353 | 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
354 | )
355 | metadata["overflow_expert2"] = (
356 | 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
357 | )
358 |
359 | # Remove locations outside capacity from mask
360 | mask1_, mask2_ = mask1, mask2
361 | mask1 = mask1 * torch.lt(locations1, capacity)
362 | mask2 = mask2 * torch.lt(locations2, capacity)
363 |
364 | # for logging (percent of tokens routed to each expert)
365 | expert1_hist = (
366 | 100
367 | * torch.histc(
368 | (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
369 | )
370 | / num_tokens
371 | )
372 | metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
373 | expert1_hist = (
374 | torch.sort(expert1_hist, dim=0, descending=True).values
375 | + torch.finfo(torch.float32).tiny
376 | )
377 |
378 | expert2_hist = (
379 | 100
380 | * torch.histc(
381 | (indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
382 | )
383 | / num_tokens
384 | )
385 | metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
386 | expert2_hist = (
387 | torch.sort(expert2_hist, dim=0, descending=True).values
388 | + torch.finfo(torch.float32).tiny
389 | )
390 |
391 | sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
392 | metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
393 | metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
394 |
395 | metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum()
396 | metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum()
397 |
398 | if not normalize_gate_prob_before_dropping:
399 | # Normalize gate probabilities
400 | gates1_s = (gates * mask1).sum(dim=1)
401 | gates2_s = (gates * mask2).sum(dim=1)
402 | denom_s = gates1_s + gates2_s
403 | # Avoid divide-by-zero
404 | denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
405 | gates1_s /= denom_s
406 | gates2_s /= denom_s
407 |
408 | if has_tutel:
409 | locations1_s = torch.sum(locations1 * mask1_, dim=1)
410 | locations2_s = torch.sum(locations2 * mask2_, dim=1)
411 | return (
412 | l_aux,
413 | metadata,
414 | capacity,
415 | num_experts,
416 | [indices1_s, indices2_s],
417 | [locations1_s, locations2_s],
418 | [gates1_s, gates2_s],
419 | )
420 |
421 | # Store the capacity location for each token
422 | locations1_s = torch.sum(locations1 * mask1, dim=1)
423 | locations2_s = torch.sum(locations2 * mask2, dim=1)
424 |
425 | # Calculate combine_weights and dispatch_mask
426 | gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
427 | gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se")
428 | locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
429 | locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
430 | combine1_sec = torch.bmm(
431 | # einsum("se,sc->sec")
432 | gates1.unsqueeze(-1),
433 | locations1_sc.to(gates1.dtype).unsqueeze(1),
434 | )
435 | combine2_sec = torch.bmm(
436 | # einsum("se,sc->sec")
437 | gates2.unsqueeze(-1),
438 | locations2_sc.to(gates2.dtype).unsqueeze(1),
439 | )
440 | combine_weights = combine1_sec + combine2_sec
441 | dispatch_mask = combine_weights.bool()
442 | if use_fp32:
443 | return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata
444 | else:
445 | return l_aux, combine_weights, dispatch_mask, metadata
446 |
447 |
448 | class Top2Gate(torch.nn.Module):
449 | """Gate module which implements Top2Gating as described in Gshard_.
450 | ::
451 |
452 | gate = Top2Gate(model_dim, num_experts)
453 |         l_aux, combine_weights, dispatch_mask, metadata = gate(input)
454 |
455 | .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
456 |
457 | Args:
458 | model_dim (int):
459 | size of model embedding dimension
460 |         num_experts (int):
461 | number of experts in model
462 | """
463 |
464 | wg: torch.nn.Linear
465 |
466 | def __init__(
467 | self,
468 | model_dim: int,
469 | num_experts: int,
470 | use_fp32=False,
471 | second_expert_policy="sampling",
472 | normalize_gate_prob_before_dropping=False,
473 | moe_eval_capacity_token_fraction=0.25,
474 | batch_prioritized_routing=False,
475 | use_xmoe=False,
476 | ) -> None:
477 | super().__init__()
478 | if not use_xmoe:
479 | self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
480 | else:
481 | self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
482 | wg = torch.empty(num_experts, 16)
483 | torch.nn.init.orthogonal_(wg, gain=0.32)
484 | self.register_parameter("wg", torch.nn.Parameter(wg))
485 | self.use_fp32 = use_fp32
486 | self.second_expert_policy = second_expert_policy
487 | self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping
488 | self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
489 | self.batch_prioritized_routing = batch_prioritized_routing
490 | self.use_xmoe = use_xmoe
491 |
492 | def forward(self, input, mask=None): # type: ignore
493 | if self.use_xmoe:
494 | input = self.wg_reduction(input)
495 | with torch.no_grad():
496 | wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
497 | self.wg.mul_(1.5 / wg_norm)
498 | logits = self._cosine(input, self.wg)
499 | logits = self._make_finite(logits)
500 | else:
501 | logits = self.wg(input)
502 | return top2gating(
503 | logits,
504 | mask,
505 | use_fp32=self.use_fp32,
506 | second_expert_policy=self.second_expert_policy,
507 | normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping,
508 | eval_mode=not self.training,
509 | moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
510 | batch_prioritized_routing=self.batch_prioritized_routing,
511 | )
512 |
513 | def _cosine(self, mat1, mat2, eps=1e-4):
514 | assert mat1.dim() == 2
515 | assert mat2.dim() == 2
516 | # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
517 | mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
518 | return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
519 |
520 | def _make_finite(self, scores):
521 | ok = scores.isfinite()
522 | if not ok.all():
523 | # NaNs here can break the assignment algorithm
524 | scores[~ok] = scores[ok].min()
525 | return scores
526 |
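527 | # A minimal usage sketch (illustrative only): when tutel is not installed
528 | # (has_tutel is False), both gates return a 4-tuple as shown here.
529 | #
530 | #     gate = Top2Gate(model_dim=16, num_experts=4)
531 | #     x = torch.randn(32, 16)                   # (num_tokens, model_dim)
532 | #     l_aux, combine_weights, dispatch_mask, metadata = gate(x)
533 | #     # combine_weights and dispatch_mask have shape (num_tokens, num_experts, capacity)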
--------------------------------------------------------------------------------
/torchscale/component/xpos_relative_position.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | def fixed_pos_embedding(x):
8 | seq_len, dim = x.shape
9 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
10 | sinusoid_inp = (
11 | torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
12 | )
13 | return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
14 |
15 |
16 | class XPos(nn.Module):
17 | def __init__(
18 | self, head_dim, scale_base = 512
19 | ):
20 | super().__init__()
21 | self.head_dim = head_dim
22 | self.scale_base = scale_base
23 | self.register_buffer(
24 | "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)
25 | )
26 |
27 | def forward(self, len):
28 | scale = self.scale ** (torch.arange(0, len, 1) - len // 2).to(self.scale).div(self.scale_base)[:, None]
29 | sin, cos = fixed_pos_embedding(scale)
30 | return (sin, cos, scale)
31 |
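32 | # A minimal usage sketch (illustrative only): for a length-L sequence, forward
33 | # returns sin, cos and the exponential decay scale, each of shape (L, head_dim // 2).
34 | #
35 | #     xpos = XPos(head_dim=64, scale_base=512)
36 | #     sin, cos, scale = xpos(16)                # each has shape (16, 32)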
--------------------------------------------------------------------------------
/torchscale/model/BEiT3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torchscale.architecture.encoder import Encoder
8 | from torchscale.component.embedding import (
9 | PositionalEmbedding,
10 | TextEmbedding,
11 | VisionEmbedding,
12 | )
13 | from torchscale.component.multiway_network import MultiwayWrapper
14 |
15 |
16 | class BEiT3(nn.Module):
17 | def __init__(self, args, **kwargs):
18 | super().__init__()
19 | self.args = args
20 | assert args.multiway
21 | assert args.vocab_size > 0
22 | assert not args.share_encoder_input_output_embed
23 | self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
24 | self.vision_embed = VisionEmbedding(
25 | args.img_size,
26 | args.patch_size,
27 | args.in_chans,
28 | args.encoder_embed_dim,
29 | contain_mask_token=True,
30 | prepend_cls_token=True,
31 | )
32 | embed_positions = MultiwayWrapper(
33 | args,
34 | PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
35 | dim=1,
36 | )
37 | self.encoder = Encoder(
38 | args,
39 | embed_tokens=None,
40 | embed_positions=embed_positions,
41 | output_projection=None,
42 | is_encoder_decoder=False,
43 | )
44 |
45 | def forward(
46 | self,
47 | textual_tokens=None,
48 | visual_tokens=None,
49 | text_padding_position=None,
50 | vision_masked_position=None,
51 | ):
52 | assert textual_tokens is not None or visual_tokens is not None
53 |
54 | if textual_tokens is None:
55 | x = self.vision_embed(visual_tokens, vision_masked_position)
56 | encoder_padding_mask = None
57 | multiway_split_position = -1
58 | elif visual_tokens is None:
59 | x = self.text_embed(textual_tokens)
60 | encoder_padding_mask = text_padding_position
61 | multiway_split_position = 0
62 | else:
63 | x1 = self.vision_embed(visual_tokens, vision_masked_position)
64 | multiway_split_position = x1.size(1)
65 | x2 = self.text_embed(textual_tokens)
66 | x = torch.cat([x1, x2], dim=1)
67 |
68 | if text_padding_position is not None:
69 | encoder_padding_mask = torch.cat(
70 | [
71 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
72 | text_padding_position,
73 | ],
74 | dim=1,
75 | )
76 | else:
77 | encoder_padding_mask = None
78 |
79 | encoder_out = self.encoder(
80 | src_tokens=None,
81 | encoder_padding_mask=encoder_padding_mask,
82 | token_embeddings=x,
83 | multiway_split_position=multiway_split_position,
84 | )
85 |
86 | return encoder_out
87 |
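88 | # A minimal usage sketch (illustrative only; it assumes EncoderConfig in
89 | # torchscale/architecture/config.py exposes the fields this model reads, e.g.
90 | # multiway, vocab_size, img_size, patch_size, in_chans, encoder_embed_dim, and
91 | # that the Encoder returns a dict with an "encoder_out" entry):
92 | #
93 | #     from torchscale.architecture.config import EncoderConfig
94 | #     args = EncoderConfig(multiway=True, vocab_size=64010)
95 | #     model = BEiT3(args)
96 | #     out = model(textual_tokens=torch.randint(0, 64010, (2, 32)))
97 | #     x = out["encoder_out"]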
--------------------------------------------------------------------------------
/torchscale/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------