├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── examples
│   ├── __init__.py
│   ├── fairseq
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── criterions
│   │   │   ├── __init__.py
│   │   │   └── masked_lm_moe.py
│   │   ├── generate.py
│   │   ├── interactive.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── bert.py
│   │   │   ├── language_modeling.py
│   │   │   ├── machine_translation.py
│   │   │   └── retnet.py
│   │   ├── tasks
│   │   │   ├── __init__.py
│   │   │   ├── data
│   │   │   │   ├── __init__.py
│   │   │   │   ├── basic_loader.py
│   │   │   │   ├── mlm_loader.py
│   │   │   │   └── utils.py
│   │   │   └── pretraining.py
│   │   ├── train.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── sparse_clip.py
│   └── longvit
│       ├── README.md
│       ├── data_preprocessing
│       │   ├── cache_transformed_images.py
│       │   ├── convert_wsi_to_images.py
│       │   ├── create_tcga_subtyping_index.py
│       │   ├── create_tcga_survival_index.py
│       │   ├── generate_1024_crops.py
│       │   └── split_to_small_images.py
│       ├── datasets.py
│       ├── engine_for_finetuning.py
│       ├── get_started
│       │   ├── get_started_for_tcga_pretraining.md
│       │   ├── get_started_for_tcga_subtyping.md
│       │   └── get_started_for_tcga_survival_prediction.md
│       ├── longvit.py
│       ├── modeling_finetune.py
│       ├── optim_factory.py
│       ├── pretraining
│       │   └── vision_transformer.py
│       ├── requirements.txt
│       ├── run_longvit_finetuning.py
│       └── utils.py
├── setup.py
├── tests
│   ├── __init__.py
│   ├── test_decoder.py
│   ├── test_encoder.py
│   └── test_encoder_decoder.py
└── torchscale
    ├── __init__.py
    ├── architecture
    │   ├── __init__.py
    │   ├── config.py
    │   ├── decoder.py
    │   ├── encoder.py
    │   ├── encoder_decoder.py
    │   ├── retnet.py
    │   └── utils.py
    ├── component
    │   ├── __init__.py
    │   ├── dilated_attention.py
    │   ├── droppath.py
    │   ├── embedding.py
    │   ├── feedforward_network.py
    │   ├── flash_attention.py
    │   ├── gate_linear_unit.py
    │   ├── multihead_attention.py
    │   ├── multiscale_retention.py
    │   ├── multiway_network.py
    │   ├── relative_position_bias.py
    │   ├── rms_norm.py
    │   ├── utils.py
    │   ├── xmoe
    │   │   ├── __init__.py
    │   │   ├── global_groups.py
    │   │   ├── moe_layer.py
    │   │   └── routing.py
    │   └── xpos_relative_position.py
    └── model
        ├── BEiT3.py
        ├── LongNet.py
        └── __init__.py
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python 3.10
13 | uses: actions/setup-python@v2
14 | with:
15 | python-version: "3.10"
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
20 | if [ -f setup.py ]; then pip install .; fi
21 | - name: Install pytest
22 | run: |
23 | pip install pytest
24 | - name: Run tests
25 | run: |
26 | pytest tests/
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Aa][Rr][Mm]/
27 | [Aa][Rr][Mm]64/
28 | bld/
29 | [Bb]in/
30 | [Oo]bj/
31 | [Ll]og/
32 | [Ll]ogs/
33 |
34 | # Visual Studio 2015/2017 cache/options directory
35 | .vs/
36 | # Uncomment if you have tasks that create the project's static files in wwwroot
37 | #wwwroot/
38 |
39 | # Visual Studio 2017 auto generated files
40 | Generated\ Files/
41 |
42 | # MSTest test Results
43 | [Tt]est[Rr]esult*/
44 | [Bb]uild[Ll]og.*
45 |
46 | # NUnit
47 | *.VisualState.xml
48 | TestResult.xml
49 | nunit-*.xml
50 |
51 | # Build Results of an ATL Project
52 | [Dd]ebugPS/
53 | [Rr]eleasePS/
54 | dlldata.c
55 |
56 | # Benchmark Results
57 | BenchmarkDotNet.Artifacts/
58 |
59 | # .NET Core
60 | project.lock.json
61 | project.fragment.lock.json
62 | artifacts/
63 |
64 | # StyleCop
65 | StyleCopReport.xml
66 |
67 | # Files built by Visual Studio
68 | *_i.c
69 | *_p.c
70 | *_h.h
71 | *.ilk
72 | *.meta
73 | *.obj
74 | *.iobj
75 | *.pch
76 | *.pdb
77 | *.ipdb
78 | *.pgc
79 | *.pgd
80 | *.rsp
81 | *.sbr
82 | *.tlb
83 | *.tli
84 | *.tlh
85 | *.tmp
86 | *.tmp_proj
87 | *_wpftmp.csproj
88 | *.log
89 | *.vspscc
90 | *.vssscc
91 | .builds
92 | *.pidb
93 | *.svclog
94 | *.scc
95 |
96 | # Chutzpah Test files
97 | _Chutzpah*
98 |
99 | # Visual C++ cache files
100 | ipch/
101 | *.aps
102 | *.ncb
103 | *.opendb
104 | *.opensdf
105 | *.sdf
106 | *.cachefile
107 | *.VC.db
108 | *.VC.VC.opendb
109 |
110 | # Visual Studio profiler
111 | *.psess
112 | *.vsp
113 | *.vspx
114 | *.sap
115 |
116 | # Visual Studio Trace Files
117 | *.e2e
118 |
119 | # TFS 2012 Local Workspace
120 | $tf/
121 |
122 | # Guidance Automation Toolkit
123 | *.gpState
124 |
125 | # ReSharper is a .NET coding add-in
126 | _ReSharper*/
127 | *.[Rr]e[Ss]harper
128 | *.DotSettings.user
129 |
130 | # TeamCity is a build add-in
131 | _TeamCity*
132 |
133 | # DotCover is a Code Coverage Tool
134 | *.dotCover
135 |
136 | # AxoCover is a Code Coverage Tool
137 | .axoCover/*
138 | !.axoCover/settings.json
139 |
140 | # Visual Studio code coverage results
141 | *.coverage
142 | *.coveragexml
143 |
144 | # NCrunch
145 | _NCrunch_*
146 | .*crunch*.local.xml
147 | nCrunchTemp_*
148 |
149 | # MightyMoose
150 | *.mm.*
151 | AutoTest.Net/
152 |
153 | # Web workbench (sass)
154 | .sass-cache/
155 |
156 | # Installshield output folder
157 | [Ee]xpress/
158 |
159 | # DocProject is a documentation generator add-in
160 | DocProject/buildhelp/
161 | DocProject/Help/*.HxT
162 | DocProject/Help/*.HxC
163 | DocProject/Help/*.hhc
164 | DocProject/Help/*.hhk
165 | DocProject/Help/*.hhp
166 | DocProject/Help/Html2
167 | DocProject/Help/html
168 |
169 | # Click-Once directory
170 | publish/
171 |
172 | # Publish Web Output
173 | *.[Pp]ublish.xml
174 | *.azurePubxml
175 | # Note: Comment the next line if you want to checkin your web deploy settings,
176 | # but database connection strings (with potential passwords) will be unencrypted
177 | *.pubxml
178 | *.publishproj
179 |
180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
181 | # checkin your Azure Web App publish settings, but sensitive information contained
182 | # in these scripts will be unencrypted
183 | PublishScripts/
184 |
185 | # NuGet Packages
186 | *.nupkg
187 | # NuGet Symbol Packages
188 | *.snupkg
189 | # The packages folder can be ignored because of Package Restore
190 | **/[Pp]ackages/*
191 | # except build/, which is used as an MSBuild target.
192 | !**/[Pp]ackages/build/
193 | # Uncomment if necessary however generally it will be regenerated when needed
194 | #!**/[Pp]ackages/repositories.config
195 | # NuGet v3's project.json files produces more ignorable files
196 | *.nuget.props
197 | *.nuget.targets
198 |
199 | # Microsoft Azure Build Output
200 | csx/
201 | *.build.csdef
202 |
203 | # Microsoft Azure Emulator
204 | ecf/
205 | rcf/
206 |
207 | # Windows Store app package directories and files
208 | AppPackages/
209 | BundleArtifacts/
210 | Package.StoreAssociation.xml
211 | _pkginfo.txt
212 | *.appx
213 | *.appxbundle
214 | *.appxupload
215 |
216 | # Visual Studio cache files
217 | # files ending in .cache can be ignored
218 | *.[Cc]ache
219 | # but keep track of directories ending in .cache
220 | !?*.[Cc]ache/
221 |
222 | # Others
223 | ClientBin/
224 | ~$*
225 | *~
226 | *.dbmdl
227 | *.dbproj.schemaview
228 | *.jfm
229 | *.pfx
230 | *.publishsettings
231 | orleans.codegen.cs
232 |
233 | # Including strong name files can present a security risk
234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
235 | #*.snk
236 |
237 | # Since there are multiple workflows, uncomment next line to ignore bower_components
238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
239 | #bower_components/
240 |
241 | # RIA/Silverlight projects
242 | Generated_Code/
243 |
244 | # Backup & report files from converting an old project file
245 | # to a newer Visual Studio version. Backup files are not needed,
246 | # because we have git ;-)
247 | _UpgradeReport_Files/
248 | Backup*/
249 | UpgradeLog*.XML
250 | UpgradeLog*.htm
251 | ServiceFabricBackup/
252 | *.rptproj.bak
253 |
254 | # SQL Server files
255 | *.mdf
256 | *.ldf
257 | *.ndf
258 |
259 | # Business Intelligence projects
260 | *.rdl.data
261 | *.bim.layout
262 | *.bim_*.settings
263 | *.rptproj.rsuser
264 | *- [Bb]ackup.rdl
265 | *- [Bb]ackup ([0-9]).rdl
266 | *- [Bb]ackup ([0-9][0-9]).rdl
267 |
268 | # Microsoft Fakes
269 | FakesAssemblies/
270 |
271 | # GhostDoc plugin setting file
272 | *.GhostDoc.xml
273 |
274 | # Node.js Tools for Visual Studio
275 | .ntvs_analysis.dat
276 | node_modules/
277 |
278 | # Visual Studio 6 build log
279 | *.plg
280 |
281 | # Visual Studio 6 workspace options file
282 | *.opt
283 |
284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
285 | *.vbw
286 |
287 | # Visual Studio LightSwitch build output
288 | **/*.HTMLClient/GeneratedArtifacts
289 | **/*.DesktopClient/GeneratedArtifacts
290 | **/*.DesktopClient/ModelManifest.xml
291 | **/*.Server/GeneratedArtifacts
292 | **/*.Server/ModelManifest.xml
293 | _Pvt_Extensions
294 |
295 | # Paket dependency manager
296 | .paket/paket.exe
297 | paket-files/
298 |
299 | # FAKE - F# Make
300 | .fake/
301 |
302 | # CodeRush personal settings
303 | .cr/personal
304 |
305 | # Python Tools for Visual Studio (PTVS)
306 | __pycache__/
307 | *.pyc
308 |
309 | # Cake - Uncomment if you are using it
310 | # tools/**
311 | # !tools/packages.config
312 |
313 | # Tabs Studio
314 | *.tss
315 |
316 | # Telerik's JustMock configuration file
317 | *.jmconfig
318 |
319 | # BizTalk build output
320 | *.btp.cs
321 | *.btm.cs
322 | *.odx.cs
323 | *.xsd.cs
324 |
325 | # OpenCover UI analysis results
326 | OpenCover/
327 |
328 | # Azure Stream Analytics local run output
329 | ASALocalRun/
330 |
331 | # MSBuild Binary and Structured Log
332 | *.binlog
333 |
334 | # NVidia Nsight GPU debugger configuration file
335 | *.nvuser
336 |
337 | # MFractors (Xamarin productivity tool) working folder
338 | .mfractor/
339 |
340 | # Local History for Visual Studio
341 | .localhistory/
342 |
343 | # BeatPulse healthcheck temp database
344 | healthchecksdb
345 |
346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
347 | MigrationBackup/
348 |
349 | # Ionide (cross platform F# VS Code tools) working folder
350 | .ionide/
351 |
352 |
353 | # Byte-compiled / optimized / DLL files
354 | __pycache__/
355 | *.py[cod]
356 | *$py.class
357 |
358 | # C extensions
359 | *.so
360 |
361 | # Distribution / packaging
362 | .Python
363 | build/
364 | develop-eggs/
365 | dist/
366 | downloads/
367 | eggs/
368 | .eggs/
369 | lib/
370 | lib64/
371 | parts/
372 | sdist/
373 | var/
374 | wheels/
375 | share/python-wheels/
376 | *.egg-info/
377 | .installed.cfg
378 | *.egg
379 | MANIFEST
380 |
381 | # PyInstaller
382 | # Usually these files are written by a python script from a template
383 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
384 | *.manifest
385 | *.spec
386 |
387 | # Installer logs
388 | pip-log.txt
389 | pip-delete-this-directory.txt
390 |
391 | # Unit test / coverage reports
392 | htmlcov/
393 | .tox/
394 | .nox/
395 | .coverage
396 | .coverage.*
397 | .cache
398 | nosetests.xml
399 | coverage.xml
400 | *.cover
401 | *.py,cover
402 | .hypothesis/
403 | .pytest_cache/
404 | cover/
405 |
406 | # Translations
407 | *.mo
408 | *.pot
409 |
410 | # Django stuff:
411 | *.log
412 | local_settings.py
413 | db.sqlite3
414 | db.sqlite3-journal
415 |
416 | # Flask stuff:
417 | instance/
418 | .webassets-cache
419 |
420 | # Scrapy stuff:
421 | .scrapy
422 |
423 | # Sphinx documentation
424 | docs/_build/
425 |
426 | # PyBuilder
427 | .pybuilder/
428 | target/
429 |
430 | # Jupyter Notebook
431 | .ipynb_checkpoints
432 |
433 | # IPython
434 | profile_default/
435 | ipython_config.py
436 |
437 | # pyenv
438 | # For a library or package, you might want to ignore these files since the code is
439 | # intended to run in multiple environments; otherwise, check them in:
440 | # .python-version
441 |
442 | # pipenv
443 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
444 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
445 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
446 | # install all needed dependencies.
447 | #Pipfile.lock
448 |
449 | # poetry
450 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
451 | # This is especially recommended for binary packages to ensure reproducibility, and is more
452 | # commonly ignored for libraries.
453 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
454 | #poetry.lock
455 |
456 | # pdm
457 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
458 | #pdm.lock
459 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
460 | # in version control.
461 | # https://pdm.fming.dev/#use-with-ide
462 | .pdm.toml
463 |
464 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
465 | __pypackages__/
466 |
467 | # Celery stuff
468 | celerybeat-schedule
469 | celerybeat.pid
470 |
471 | # SageMath parsed files
472 | *.sage.py
473 |
474 | # Environments
475 | .env
476 | .venv
477 | env/
478 | venv/
479 | ENV/
480 | env.bak/
481 | venv.bak/
482 |
483 | # Spyder project settings
484 | .spyderproject
485 | .spyproject
486 |
487 | # Rope project settings
488 | .ropeproject
489 |
490 | # mkdocs documentation
491 | /site
492 |
493 | # mypy
494 | .mypy_cache/
495 | .dmypy.json
496 | dmypy.json
497 |
498 | # Pyre type checker
499 | .pyre/
500 |
501 | # pytype static type analyzer
502 | .pytype/
503 |
504 | # Cython debug symbols
505 | cython_debug/
506 |
507 | # PyCharm
508 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
509 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
510 | # and can be added to the global gitignore or merged into this file. For a more nuclear
511 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
512 | #.idea/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TorchScale - A Library of Foundation Architectures
2 |
3 |
4 |
5 |
6 |
7 |
8 | TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively.
9 |
10 | It is driven by fundamental research to develop new architectures for foundation models and A(G)I, focusing on modeling generality and capability, as well as training stability and efficiency.
11 | - Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
12 | - Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal)
13 | - Capability - A [**Length-Extrapolatable**](https://arxiv.org/abs/2212.10554) Transformer
14 | - Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
15 |
16 | ### The Revolution of Model Architecture
17 | - [**BitNet**](https://arxiv.org/abs/2310.11453): 1-bit Transformers for Large Language Models
18 | - [**RetNet**](https://arxiv.org/abs/2307.08621): Retentive Network: A Successor to Transformer for Large Language Models
19 | - [**LongNet**](https://arxiv.org/abs/2307.02486): Scaling Transformers to 1,000,000,000 Tokens
20 |
21 | ## News
22 |
23 | - December, 2023: [LongNet](torchscale/model/LongNet.py) and [LongViT](examples/longvit/README.md) released
24 | - October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
25 | - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
26 |
27 | ## Installation
28 |
29 | To install:
30 | ```
31 | pip install torchscale
32 | ```
33 |
34 | Alternatively, you can develop it locally:
35 | ```
36 | git clone https://github.com/microsoft/torchscale.git
37 | cd torchscale
38 | pip install -e .
39 | ```
40 |
41 | For faster training, install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs:
42 | ```
43 | pip install flash-attn
44 | ```
45 | or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs:
46 | ```
47 | # cuda 11.8 version
48 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
49 | # cuda 12.1 version
50 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
51 | ```
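
To quickly check which of these optional backends are visible to Python, the following hypothetical snippet (not part of TorchScale itself) can be used:

```python
# Hypothetical check: reports whether torchscale and the optional attention
# backends mentioned above are importable in the current environment.
import importlib.util

for pkg in ("torchscale", "flash_attn", "xformers"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'installed' if found else 'not installed'}")
```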
52 |
53 | ## Getting Started
54 |
55 | It takes only a few lines of code to create a model with the fundamental research features above enabled. Here is how to quickly obtain a BERT-like encoder:
56 |
57 | ```python
58 | >>> from torchscale.architecture.config import EncoderConfig
59 | >>> from torchscale.architecture.encoder import Encoder
60 |
61 | >>> config = EncoderConfig(vocab_size=64000)
62 | >>> model = Encoder(config)
63 |
64 | >>> print(model)
65 | ```
66 |
67 | We also support the `Decoder` architecture and the `EncoderDecoder` architecture:
68 |
69 | ```python
70 | # Creating a decoder model
71 | >>> from torchscale.architecture.config import DecoderConfig
72 | >>> from torchscale.architecture.decoder import Decoder
73 |
74 | >>> config = DecoderConfig(vocab_size=64000)
75 | >>> decoder = Decoder(config)
76 | >>> print(decoder)
77 |
78 | # Creating an encoder-decoder model
79 | >>> from torchscale.architecture.config import EncoderDecoderConfig
80 | >>> from torchscale.architecture.encoder_decoder import EncoderDecoder
81 |
82 | >>> config = EncoderDecoderConfig(vocab_size=64000)
83 | >>> encdec = EncoderDecoder(config)
84 | >>> print(encdec)
85 | ```
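
For a quick smoke test of the decoder created above, token embeddings can be fed in directly, since no embedding table was passed to `Decoder`. The sketch below is hedged: the argument names (`prev_output_tokens`, `token_embeddings`, `features_only`) mirror the repository's unit tests in `tests/test_decoder.py` and may differ across versions.

```python
# Hedged smoke test for the decoder above; argument names follow tests/test_decoder.py.
>>> import torch

>>> prev_output_tokens = torch.ones(2, 10)             # (batch, length) dummy tokens
>>> token_embeddings = torch.rand(2, 10, 768)          # 768 = default decoder_embed_dim
>>> out = decoder(
...     prev_output_tokens=prev_output_tokens,
...     token_embeddings=token_embeddings,
...     features_only=True,                            # skip the output projection
... )
```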
86 |
87 | It also takes only a few lines of code to create a RetNet model:
88 |
89 | ```python
90 | # Creating a RetNet model
91 | >>> import torch
92 | >>> from torchscale.architecture.config import RetNetConfig
93 | >>> from torchscale.architecture.retnet import RetNetDecoder
94 |
95 | >>> config = RetNetConfig(vocab_size=64000)
96 | >>> retnet = RetNetDecoder(config)
97 |
98 | >>> print(retnet)
99 | ```
100 |
101 | For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required):
102 | ```python
103 | >>> import torch
104 | >>> from torchscale.architecture.config import EncoderConfig, DecoderConfig
105 | >>> from torchscale.model.LongNet import LongNetEncoder, LongNetDecoder
106 |
107 | # Creating a LongNet encoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
108 | >>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
109 | >>> longnet = LongNetEncoder(config)
110 |
111 | # Creating a LongNet decoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
112 | >>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
113 | >>> longnet = LongNetDecoder(config)
114 | ```
115 |
116 | ## Key Features
117 |
118 | - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
119 | * enabled by setting *deepnorm=True* in the `Config` class.
120 | * It adjusts both the residual connection and the initialization method according to the model architecture (i.e., encoder, decoder, or encoder-decoder).
121 |
122 | - [SubLN for the model generality and the training stability](https://arxiv.org/abs/2210.06423)
123 | * enabled by *subln=True*, which is the default.
124 | * It introduces another LayerNorm to each sublayer and adjusts the initialization according to the model architecture.
125 | * Note that SubLN and DeepNorm cannot be used in one single model.
126 |
127 | - [X-MoE: efficient and finetunable sparse MoE modeling](https://arxiv.org/abs/2204.09179)
128 | * enabled by *use_xmoe=True*.
129 | * It replaces the `FeedForwardNetwork` layer with an X-MoE layer in every *moe_freq*-th block.
130 |
131 | - [Multiway architecture for multimodality](https://arxiv.org/abs/2208.10442)
132 | * enabled by *multiway=True*.
133 | * It provides a pool of Transformer parameters used for different modalities.
134 |
135 | - [Extrapolatable position embedding (Xpos)](https://arxiv.org/abs/2212.10554)
136 | * enabled by *xpos_rel_pos=True*.
137 |
138 | - [Relative position bias](https://arxiv.org/abs/1910.10683)
139 | * enabled by adjusting *rel_pos_buckets* and *max_rel_pos*.
140 |
141 | - [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184)
142 | * we provide [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq repo (or others).
143 |
144 | - [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621)
145 | * created by `config = RetNetConfig(vocab_size=64000)` and `retnet = RetNetDecoder(config)`.
146 |
147 | - [LongNet: Scaling Transformers to 1,000,000,000 Tokens](https://arxiv.org/abs/2307.02486)
148 |
149 | Most of the features above can be used by simply passing the corresponding parameters to the config. For example:
150 |
151 | ```python
152 | >>> from torchscale.architecture.config import EncoderConfig
153 | >>> from torchscale.architecture.encoder import Encoder
154 |
155 | >>> config = EncoderConfig(vocab_size=64000, deepnorm=True, multiway=True)
156 | >>> model = Encoder(config)
157 |
158 | >>> print(model)
159 | ```
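
Several flags can also be combined in a single config. The sketch below is illustrative rather than an officially tested recipe; the exact defaults and the interplay of *use_xmoe*, *moe_freq*, and *moe_expert_count* should be checked against `torchscale/architecture/config.py`.

```python
# Hedged sketch combining several of the documented flags; values are illustrative only.
>>> from torchscale.architecture.config import DecoderConfig
>>> from torchscale.architecture.decoder import Decoder

>>> config = DecoderConfig(
...     vocab_size=64000,
...     xpos_rel_pos=True,      # length-extrapolatable position embedding
...     use_xmoe=True,          # X-MoE sparse feed-forward layers
...     moe_freq=2,             # assumed: every 2nd FFN block becomes an MoE layer
...     moe_expert_count=2,     # assumed expert count (small, as in the unit tests)
... )
>>> decoder = Decoder(config)
>>> print(decoder)
```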
160 |
161 | ## Examples
162 |
163 | We have examples of how to use TorchScale in the following scenarios/tasks:
164 |
165 | - Language
166 |
167 | * [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining)
168 |
169 | * [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation)
170 |
171 | * [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining)
172 |
173 | - Vision
174 |
175 | * [LongViT](examples/longvit/README.md)
176 |
177 | * ViT/BEiT [In progress]
178 |
179 | - Speech
180 |
181 | - Multimodal
182 |
183 | * [Multiway Transformers/BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3)
184 |
185 | We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
186 |
187 |
188 | ## Acknowledgments
189 |
190 | Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository.
191 |
192 | ## Citations
193 |
194 | If you find this repository useful, please consider citing our work:
195 |
196 | ```
197 | @article{torchscale,
198 | author = {Shuming Ma and Hongyu Wang and Shaohan Huang and Wenhui Wang and Zewen Chi and Li Dong and Alon Benhaim and Barun Patra and Vishrav Chaudhary and Xia Song and Furu Wei},
199 | title = {{TorchScale}: {Transformers} at Scale},
200 | journal = {CoRR},
201 | volume = {abs/2211.13184},
202 | year = {2022}
203 | }
204 | ```
205 |
206 | ```
207 | @article{deepnet,
208 | author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
209 | title = {{DeepNet}: Scaling {Transformers} to 1,000 Layers},
210 | journal = {CoRR},
211 | volume = {abs/2203.00555},
212 | year = {2022},
213 | }
214 | ```
215 |
216 | ```
217 | @article{magneto,
218 | author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
219 | title = {Foundation {Transformers}},
220 | journal = {CoRR},
221 | volume = {abs/2210.06423},
222 | year = {2022}
223 | }
224 | ```
225 |
226 | ```
227 | @inproceedings{xmoe,
228 | title={On the Representation Collapse of Sparse Mixture of Experts},
229 | author={Zewen Chi and Li Dong and Shaohan Huang and Damai Dai and Shuming Ma and Barun Patra and Saksham Singhal and Payal Bajaj and Xia Song and Xian-Ling Mao and Heyan Huang and Furu Wei},
230 | booktitle={Advances in Neural Information Processing Systems},
231 | year={2022},
232 | url={https://openreview.net/forum?id=mWaYC6CZf5}
233 | }
234 | ```
235 |
236 | ```
237 | @article{retnet,
238 | author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
239 | title = {Retentive Network: A Successor to {Transformer} for Large Language Models},
240 | journal = {ArXiv},
241 | volume = {abs/2307.08621},
242 | year = {2023}
243 | }
244 | ```
245 |
246 | ```
247 | @article{longnet,
248 | author={Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei},
249 | title = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens},
250 | journal = {ArXiv},
251 | volume = {abs/2307.02486},
252 | year = {2023}
253 | }
254 | ```
255 |
256 | ```
257 | @article{longvit,
258 | title = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
259 | author = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei},
260 | journal = {ArXiv},
261 | volume = {abs/2312.03558},
262 | year = {2023}
263 | }
264 | ```
265 |
266 | ## Contributing
267 |
268 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
269 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
270 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
271 |
272 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
273 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
274 | provided by the bot. You will only need to do this once across all repos using our CLA.
275 |
276 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
277 | For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
278 | contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
279 |
280 | ## Trademarks
281 |
282 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
283 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
284 | Any use of third-party trademarks or logos is subject to those third parties' policies.
285 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/README.md:
--------------------------------------------------------------------------------
1 | # Example: Integration with FairSeq
2 |
3 | ## Setup
4 |
5 | ```bash
6 | # Install the repo as a package:
7 | git clone https://github.com/microsoft/torchscale.git
8 | cd torchscale
9 | pip install -e .
10 | pip install git+https://github.com/shumingma/fairseq.git@moe
11 | pip install git+https://github.com/shumingma/infinibatch.git
12 | pip install iopath
13 | pip install numpy==1.23.0
14 | ```
15 |
16 | ## Example: BERT Pretraining
17 |
18 | ### Data Format
19 |
20 | We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on-the-fly from disk. It requires the data to be sharded into multiple small files (e.g. 10K lines per file), along with a JSON file that contains some metadata and the paths to these files.
21 |
22 | The overall data directory should be organized as follows:
23 | ```
24 | Data/
25 | ├── json/
26 | │ ├── train.json
27 | │ └── valid.json
28 | ├── shard/
29 | │ ├── train/
30 | │ │ ├── 00000.txt
31 | │ │ ├── 00001.txt
32 | │ │ └── ...
33 | │ └── valid/
34 | │ ├── 00000.txt
35 | │ ├── 00001.txt
36 | │ └── ...
37 | ├── dict.txt
38 | └── sentencepiece.bpe.model
39 | ```
40 |
41 | We recommend that each sharded data file contain no more than 10K lines, with one sentence per line. Two documents should be separated by an empty line.
42 | ```
43 | Document 1 Line 1
44 | Document 1 Line 2
45 | Document 1 Line 3
46 |
47 | Document 2 Line 1
48 | Document 2 Line 2
49 |
50 | ...
51 | ```
52 |
53 | Also, the JSON file should look like this:
54 | ```
55 | [
56 | {
57 | "source": [
58 | "shard/train/00000.txt",
59 | "shard/train/00001.txt",
60 | ...
61 | ],
62 | "source_lang": "en",
63 | "weight": 1.0
64 | }
65 | ]
66 | ```
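
A minimal sketch for generating this index from already-sharded files follows; the paths and the `weight` value are placeholders and not produced by any official tooling.

```python
# Hypothetical helper: build train.json from the sharded files written above.
import glob
import json

shards = sorted(glob.glob("shard/train/*.txt"))
index = [{"source": shards, "source_lang": "en", "weight": 1.0}]

with open("json/train.json", "w", encoding="utf-8") as f:
    json.dump(index, f, indent=2)
```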
67 |
68 | You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model] and [dict.txt]. Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentencepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by running:
69 | ```
70 | spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
71 | ```
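
If you prefer to stay in Python, a hedged equivalent of the shell pipeline above using the `sentencepiece` package is sketched below; it skips the first three special pieces (mirroring `tail -n +4`) and writes the piece scores as the second column.

```python
# Hedged Python equivalent of the spm_export_vocab pipeline above.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("sentencepiece.bpe.model")

with open("dict.txt", "w", encoding="utf-8") as f:
    # skip <unk>, <s>, </s> (the first three pieces), like `tail -n +4`
    for i in range(3, sp.get_piece_size()):
        f.write(f"{sp.id_to_piece(i)} {sp.get_score(i)}\n")
```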
72 |
73 | ### Dense Model
74 | ```bash
75 | cd examples/fairseq/
76 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
77 | --task pretraining \
78 | --tokens-per-sample 512 \
79 | --mask-prob 0.15 \
80 | --span-length 3.0 \
81 | --leave-unmasked-prob 0.0 \
82 | --random-token-prob 0.0 \
83 | --criterion masked_lm \
84 | --arch mlm_base \
85 | --share-encoder-input-output-embed \
86 | --required-batch-size-multiple 8 \
87 | --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
88 | --dict-file ${PATH_TO_DATA}/dict.txt \
89 | --optimizer adam \
90 | --adam-betas '(0.9,0.98)' \
91 | --adam-eps 1e-6 \
92 | --clip-norm 2.0 \
93 | --lr-scheduler polynomial_decay \
94 | --lr 0.0005 \
95 | --warmup-updates 10000 \
96 | --total-num-update 125000 \
97 | --max-update 125000 \
98 | --max-sentences 32 \
99 | --update-freq 1 \
100 | --log-format simple \
101 | --log-interval 100 \
102 | --disable-validation \
103 | --save-interval-updates 5000 \
104 | --no-epoch-checkpoints \
105 | --fp16 \
106 | --fp16-init-scale 4 \
107 | --fp16-scale-window 256 \
108 | --min-loss-scale 0.0001 \
109 | --seed 1 \
110 | --save-dir ${PATH_TO_CKPT} \
111 | --ddp-backend=no_c10d \
112 | --distributed-no-spawn \
113 | --reset-dataloader \
114 | --batch-read-ahead 10000 \
115 | --rel-pos-buckets 32 \
116 | --max-rel-pos 128 \
117 | --deepnorm
118 | ```
119 |
120 | ### Sparse (MoE) Model
121 | ```bash
122 | cd examples/fairseq/
123 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
124 | --task pretraining \
125 | --tokens-per-sample 512 \
126 | --mask-prob 0.15 \
127 | --span-length 3.0 \
128 | --leave-unmasked-prob 0.0 \
129 | --random-token-prob 0.0 \
130 | --arch mlm_base \
131 | --share-encoder-input-output-embed \
132 | --required-batch-size-multiple 8 \
133 | --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
134 | --dict-file ${PATH_TO_DATA}/dict.txt \
135 | --optimizer adam \
136 | --adam-betas '(0.9,0.98)' \
137 | --adam-eps 1e-6 \
138 | --clip-norm 2.0 \
139 | --lr-scheduler polynomial_decay \
140 | --lr 0.0005 \
141 | --warmup-updates 10000 \
142 | --total-num-update 125000 \
143 | --max-update 125000 \
144 | --max-sentences 32 \
145 | --update-freq 1 \
146 | --log-format simple \
147 | --log-interval 100 \
148 | --disable-validation \
149 | --save-interval-updates 5000 \
150 | --no-epoch-checkpoints \
151 | --fp16 \
152 | --fp16-init-scale 4 \
153 | --fp16-scale-window 256 \
154 | --min-loss-scale 0.0001 \
155 | --seed 1 \
156 | --save-dir ${PATH_TO_CKPT} \
157 | --ddp-backend=no_c10d \
158 | --distributed-no-spawn \
159 | --reset-dataloader \
160 | --batch-read-ahead 10000 \
161 | --rel-pos-buckets 32 \
162 | --max-rel-pos 128 \
163 | --deepnorm \
164 | --moe-expert-count 64 --moe-freq 2 \
165 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
166 | --moe-eval-capacity-token-fraction -1.0 \
167 | --criterion masked_lm_moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
168 | --use-xmoe --pad-to-max-length
169 | ```
170 |
171 | ## Example: GPT Pretraining
172 |
173 | ### Data Format
174 |
175 | We use the same data format as FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data).
176 |
177 | ### Dense Model
178 |
179 | ```bash
180 | cd examples/fairseq/
181 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
182 | ${PATH_TO_DATA} \
183 | --num-workers 2 \
184 | --activation-fn gelu \
185 | --share-decoder-input-output-embed \
186 | --validate-interval-updates 1000 \
187 | --save-interval-updates 1000 \
188 | --no-epoch-checkpoints \
189 | --memory-efficient-fp16 \
190 | --fp16-init-scale 4 \
191 | --arch lm_base \
192 | --task language_modeling \
193 | --sample-break-mode none \
194 | --tokens-per-sample 128 \
195 | --optimizer adam --adam-betas "(0.9, 0.98)" \
196 | --adam-eps 1e-08 \
197 | --clip-norm 0.0 \
198 | --lr 5e-4 \
199 | --lr-scheduler polynomial_decay \
200 | --warmup-updates 750 \
201 | --dropout 0.1 \
202 | --attention-dropout 0.1 \
203 | --weight-decay 0.01 \
204 | --batch-size 4 \
205 | --update-freq 1 \
206 | --required-batch-size-multiple 1 \
207 | --total-num-update 50000 \
208 | --max-update 50000 \
209 | --seed 1 \
210 | --ddp-backend=c10d
211 | ```
212 |
213 | ### Sparse (MoE) Model
214 |
215 | ```bash
216 | cd examples/fairseq/
217 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
218 | ${PATH_TO_DATA} \
219 | --num-workers 2 \
220 | --activation-fn gelu \
221 | --share-decoder-input-output-embed \
222 | --validate-interval-updates 1000 \
223 | --save-interval-updates 1000 \
224 | --no-epoch-checkpoints \
225 | --memory-efficient-fp16 \
226 | --fp16-init-scale 4 \
227 | --arch lm_base \
228 | --task language_modeling \
229 | --sample-break-mode none \
230 | --tokens-per-sample 128 \
231 | --optimizer adam --adam-betas "(0.9, 0.98)" \
232 | --adam-eps 1e-08 \
233 | --clip-norm 0.0 \
234 | --lr 5e-4 \
235 | --lr-scheduler polynomial_decay \
236 | --warmup-updates 750 \
237 | --dropout 0.1 \
238 | --attention-dropout 0.1 \
239 | --weight-decay 0.01 \
240 | --batch-size 4 \
241 | --update-freq 1 \
242 | --required-batch-size-multiple 1 \
243 | --total-num-update 50000 \
244 | --max-update 50000 \
245 | --seed 1 \
246 | --ddp-backend=no_c10d \
247 | --moe-expert-count 2 --moe-freq 2 \
248 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
249 | --moe-eval-capacity-token-fraction -1.0 \
250 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
251 | --use-xmoe
252 | ```
253 |
254 | ### LongNet Model
255 |
256 | ```bash
257 | cd examples/fairseq/
258 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
259 | ${PATH_TO_DATA} \
260 | --num-workers 2 \
261 | --activation-fn gelu \
262 | --share-decoder-input-output-embed \
263 | --validate-interval-updates 1000 \
264 | --save-interval-updates 1000 \
265 | --no-epoch-checkpoints \
266 | --memory-efficient-fp16 \
267 | --fp16-init-scale 4 \
268 | --arch lm_base \
269 | --task language_modeling \
270 | --sample-break-mode none \
271 | --tokens-per-sample 4096 \
272 | --optimizer adam --adam-betas "(0.9, 0.98)" \
273 | --adam-eps 1e-08 \
274 | --clip-norm 0.0 \
275 | --lr 5e-4 \
276 | --lr-scheduler polynomial_decay \
277 | --warmup-updates 750 \
278 | --dropout 0.1 \
279 | --attention-dropout 0.1 \
280 | --weight-decay 0.01 \
281 | --batch-size 4 \
282 | --update-freq 1 \
283 | --required-batch-size-multiple 1 \
284 | --total-num-update 50000 \
285 | --max-update 50000 \
286 | --seed 1 \
287 | --ddp-backend=c10d \
288 | --flash-attention \
289 | --segment-length '[2048,4096]' \
290 | --dilated-ratio '[1,2]'
291 | ```
292 |
293 | ## Example: Machine Translation
294 |
295 | ### Data Format
296 |
297 | We follow FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data.
298 |
299 | ### Dense Model
300 |
301 | ```bash
302 | cd examples/fairseq/
303 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
304 | ${PATH_TO_DATA} \
305 | --arch mt_base --share-decoder-input-output-embed \
306 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
307 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
308 | --dropout 0.3 --weight-decay 0.0001 \
309 | --max-tokens 4096 --fp16
310 | ```
311 |
312 | ### Sparse (MoE) Model
313 |
314 | ```bash
315 | cd examples/fairseq/
316 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
317 | ${PATH_TO_DATA} \
318 | --arch mt_base --share-decoder-input-output-embed \
319 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
320 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
321 | --dropout 0.3 --weight-decay 0.0001 \
322 | --moe-expert-count 2 --moe-freq 2 \
323 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
324 | --moe-eval-capacity-token-fraction -1.0 \
325 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
326 | --use-xmoe \
327 | --max-tokens 4096 --fp16
328 | ```
329 |
--------------------------------------------------------------------------------
/examples/fairseq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | # automatically import any Python files in the criterions/ directory
5 | for file in sorted(os.listdir(os.path.dirname(__file__))):
6 | if file.endswith(".py") and not file.startswith("_"):
7 | file_name = file[: file.find(".py")]
8 | importlib.import_module("criterions." + file_name)
--------------------------------------------------------------------------------
/examples/fairseq/criterions/masked_lm_moe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 | from fairseq import metrics, utils
10 | from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig
11 |
12 |
13 | @register_criterion("masked_lm_moe_cross_entropy", dataclass=MoECriterionConfig)
14 | class MaskedLMMoECrossEntropyCriterion(MoECriterion):
15 |
16 | def compute_inner_loss(self, model, sample, reduce=True):
17 | masked_tokens = sample["target"].ne(self.padding_idx)
18 | sample_size = masked_tokens.int().sum()
19 |
20 | masked_tokens = torch.where(
21 | masked_tokens.any(),
22 | masked_tokens,
23 | masked_tokens.new([True]),
24 | )
25 |
26 | net_output = model(**sample["net_input"], masked_tokens=masked_tokens)
27 | lprobs = model.get_normalized_probs(net_output, log_probs=True)
28 | lprobs = lprobs.view(-1, lprobs.size(-1))
29 | target = model.get_targets(sample, net_output)
30 |
31 | if masked_tokens is not None:
32 | target = target[masked_tokens]
33 |
34 | nll_loss = F.nll_loss(
35 | lprobs,
36 | target.view(-1),
37 | ignore_index=self.padding_idx,
38 | reduction="sum" if reduce else "none",
39 | )
40 | logging_output = {
41 | "inner_loss": nll_loss.data,
42 | "ntokens": sample["ntokens"],
43 | "nsentences": sample["target"].size(0),
44 | "sample_size": sample_size,
45 | }
46 | return net_output, nll_loss, sample_size, logging_output
47 |
48 | @staticmethod
49 | def reduce_metrics(logging_outputs) -> None:
50 | """Aggregate logging outputs from data parallel training."""
51 | MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
52 |
53 | loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs)
54 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
55 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
56 |
57 | # we divide by log(2) to convert the loss from base e to base 2
58 | metrics.log_scalar(
59 | "inner_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
60 | )
61 | if sample_size != ntokens:
62 | metrics.log_scalar(
63 | "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
64 | )
65 | metrics.log_derived(
66 | "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
67 | )
68 | else:
69 | metrics.log_derived(
70 | "ppl", lambda meters: utils.get_perplexity(meters["inner_loss"].avg)
71 | )
--------------------------------------------------------------------------------
/examples/fairseq/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # flake8: noqa
5 | import models
6 | import tasks
7 | import criterions
8 | from fairseq_cli.generate import cli_main
9 |
10 | if __name__ == "__main__":
11 | cli_main()
12 |
--------------------------------------------------------------------------------
/examples/fairseq/interactive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # flake8: noqa
5 | import models
6 | import tasks
7 | import criterions
8 | from fairseq_cli.interactive import cli_main
9 |
10 | if __name__ == "__main__":
11 | cli_main()
12 |
--------------------------------------------------------------------------------
/examples/fairseq/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import argparse
5 | import importlib
6 | import os
7 |
8 | MODEL_REGISTRY = {}
9 | MODEL_DATACLASS_REGISTRY = {}
10 | ARCH_MODEL_REGISTRY = {}
11 | ARCH_MODEL_NAME_REGISTRY = {}
12 | ARCH_MODEL_INV_REGISTRY = {}
13 | ARCH_CONFIG_REGISTRY = {}
14 |
15 | # automatically import any Python files in the models/ directory
16 | models_dir = os.path.dirname(__file__)
17 | for file in os.listdir(models_dir):
18 | path = os.path.join(models_dir, file)
19 | if (
20 | not file.startswith("_")
21 | and not file.startswith(".")
22 | and (file.endswith(".py") or os.path.isdir(path))
23 | ):
24 | model_name = file[: file.find(".py")] if file.endswith(".py") else file
25 | module = importlib.import_module("models." + model_name)
26 |
27 | # extra `model_parser` for sphinx
28 | if model_name in MODEL_REGISTRY:
29 | parser = argparse.ArgumentParser(add_help=False)
30 | group_archs = parser.add_argument_group("Named architectures")
31 | group_archs.add_argument(
32 | "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
33 | )
34 | group_args = parser.add_argument_group("Additional command-line arguments")
35 | MODEL_REGISTRY[model_name].add_args(group_args)
36 | globals()[model_name + "_parser"] = parser
37 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import argparse
5 | import importlib
6 | import os
7 |
8 | # register dataclass
9 | TASK_DATACLASS_REGISTRY = {}
10 | TASK_REGISTRY = {}
11 | TASK_CLASS_NAMES = set()
12 |
13 | # automatically import any Python files in the tasks/ directory
14 | tasks_dir = os.path.dirname(__file__)
15 | for file in os.listdir(tasks_dir):
16 | path = os.path.join(tasks_dir, file)
17 | if (
18 | not file.startswith("_")
19 | and not file.startswith(".")
20 | and (file.endswith(".py") or os.path.isdir(path))
21 | ):
22 | task_name = file[: file.find(".py")] if file.endswith(".py") else file
23 | module = importlib.import_module("tasks." + task_name)
24 |
25 | # expose `task_parser` for sphinx
26 | if task_name in TASK_REGISTRY:
27 | parser = argparse.ArgumentParser(add_help=False)
28 | group_task = parser.add_argument_group("Task name")
29 | # fmt: off
30 | group_task.add_argument('--task', metavar=task_name,
31 | help='Enable this task with: ``--task=' + task_name + '``')
32 | # fmt: on
33 | group_args = parser.add_argument_group("Additional command-line arguments")
34 | TASK_REGISTRY[task_name].add_args(group_args)
35 | globals()[task_name + "_parser"] = parser
36 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/basic_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | from infinibatch.iterators import CheckpointableIterator
6 |
7 | from . import utils
8 |
9 |
10 | class BaseBatchGen(CheckpointableIterator):
11 | """
12 | This is a base class for batch generators that use infinibatch
13 | """
14 |
15 | def __init__(self):
16 | self._iter = None
17 | self.epoch = 1
18 | self.next_epoch_idx = 1
19 | self.sharded_checkpoint = True
20 | self.should_close_after_finished = True
21 |
22 | def _build_iter(self):
23 | """
24 | Build infinibatch iterator and assign to self._iter
25 | """
26 | raise NotImplementedError()
27 |
28 | def _move_to_tensor(self, batch):
29 | def to_tensor(x):
30 | return torch.tensor(x)
31 |
32 | return utils.apply_to_sample(to_tensor, batch)
33 |
34 | @property
35 | def iterator(self):
36 | if self._iter is None:
37 | raise NotImplementedError("_build_iter() must be called first")
38 | return self._iter
39 |
40 | def __iter__(self):
41 | if self._iter is None:
42 | raise NotImplementedError("_build_iter() must be called first")
43 | return self._iter
44 |
45 | def __next__(self):
46 | return next(self._iter)
47 |
48 | def setstate(self, value):
49 | self._iter.setstate(value)
50 |
51 | def getstate(self):
52 | return self._iter.getstate()
53 |
54 | def close(self):
55 | self._iter.close()
56 |
57 | def __len__(self) -> int:
58 | return 819200000
59 |
60 | def next_epoch_itr(
61 | self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
62 | ):
63 | return self
64 |
65 | def end_of_epoch(self) -> bool:
66 | return False
67 |
68 | def state_dict(self):
69 | """Returns a dictionary containing a whole state of the iterator."""
70 | return self.getstate()
71 |
72 | def load_state_dict(self, state_dict):
73 | """Copies the state of the iterator from the given *state_dict*."""
74 | self.setstate(state_dict)
75 |
76 | @property
77 | def first_batch(self):
78 | return "DUMMY"
79 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/mlm_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import copy
5 | import itertools
6 | import os
7 |
8 | import numpy as np
9 | from infinibatch import iterators
10 |
11 | from .basic_loader import BaseBatchGen
12 | from .utils import NativeCheckpointableIterator, WeightIterator
13 |
14 |
15 | class MLMLoader(BaseBatchGen):
16 | def __init__(
17 | self,
18 | args,
19 | dataset,
20 | dictionary,
21 | tokenizer,
22 | max_tokens=None,
23 | max_sentences=None,
24 | max_positions=None,
25 | ignore_invalid_inputs=False,
26 | required_batch_size_multiple=1,
27 | seed=1,
28 | num_shards=1,
29 | shard_id=0,
30 | ):
31 | super().__init__()
32 | self.args = args
33 | self.data = dataset.data
34 | self.data_dir = dataset.data_dir
35 | self.shuffle = dataset.shuffle
36 | self.dictionary = dictionary
37 | self.tokenizer = tokenizer
38 |
39 | self.max_tokens = max_tokens
40 | self.max_sentences = max_sentences
41 | self.max_positions = max_positions
42 | self.tokens_per_sample = args.tokens_per_sample
43 | self.sample_break_mode = args.sample_break_mode
44 | self.ignore_invalid_inputs = ignore_invalid_inputs
45 | self.required_batch_size_multiple = required_batch_size_multiple
46 | self.seed = str(seed)
47 | self.num_shards = num_shards
48 | self.shard_id = shard_id
49 |
50 | self.batch_read_ahead = args.batch_read_ahead
51 |
52 | self._build_iter()
53 |
54 | def _build_iter(self):
55 | tokenized_lines = self._multilingual_tokenize()
56 | self.padded_batches = self._batchify(tokenized_lines)
57 |
58 | prefetch_batches = iterators.PrefetchIterator(
59 | self.padded_batches,
60 | buffer_size=10000,
61 | buffer_in_main_process=True,
62 | log_empty_buffer_warning=True and self.shard_id == 0,
63 | )
64 |
65 | prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor)
66 |
67 | self._iter = prefetch_batches
68 |
69 | def _multilingual_tokenize(self):
70 | multilingual_iters = []
71 | weights = []
72 |
73 | for data in self.data:
74 | multilingual_iters.append(self._tokenize(data))
75 | if "weight" in data:
76 | weights.append(float(data["weight"]))
77 | else:
78 | weights.append(int(data["count"]))
79 |
80 | if len(multilingual_iters) == 1:
81 | return multilingual_iters[0]
82 |
83 |         sampling_iterator = WeightIterator(weights, self.seed)
84 | control_iterator = NativeCheckpointableIterator(sampling_iterator)
85 | tokenized_lines = iterators.MultiplexIterator(
86 | control_iterator, multilingual_iters
87 | )
88 |
89 | return tokenized_lines
90 |
91 | def _tokenize(self, data):
92 | """
93 | data:
94 | {
95 | 'source': list[Path],
96 | 'source_lang': str,
97 | 'count': int,
98 | 'weight': float,
99 | 'name': str,
100 | }
101 | """
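        # For illustration only (not from the repo): a `data` entry of this shape
        # would be consumed by the loader, e.g.
        # {
        #     "source": ["shard_000.txt", "shard_001.txt"],
        #     "source_lang": "en",
        #     "count": 2,
        #     "weight": 1.0,
        #     "name": "en_wikipedia",
        # }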
102 | dataset = list(
103 | zip(
104 | data["source"],
105 | itertools.repeat(data["source_lang"]),
106 | )
107 | )
108 |
109 | if self.shuffle:
110 | chunk_files = iterators.InfinitePermutationSourceIterator(
111 | dataset,
112 | seed=self.seed,
113 | shuffle=self.shuffle,
114 | num_instances=self.num_shards,
115 | instance_rank=self.shard_id,
116 | )
117 | else:
118 | chunk_files = iterators.ChunkedSourceIterator(
119 | dataset,
120 | num_instances=self.num_shards,
121 | instance_rank=self.shard_id,
122 | )
123 |
124 | tokenized_lines = iterators.SelectManyIterator(
125 | chunk_files, lambda files: self._read_from_files(*files)
126 | )
127 | tokenized_lines = iterators.SamplingRandomMapIterator(
128 | tokenized_lines, self._prepare, self.seed
129 | )
130 |
131 | return tokenized_lines
132 |
133 | def _batchify(self, lines):
134 |
135 | if self.max_sentences is not None:
136 | if self.batch_read_ahead > 0:
137 | lines = iterators.BlockwiseShuffleIterator(
138 | lines, self.batch_read_ahead, self.seed
139 | )
140 | batches = iterators.FixedBatchIterator(lines, self.max_sentences)
141 | else:
142 |
143 | def dynamic_batch_size(sample):
144 | lengths = [len(x) for x in sample]
145 | batch_size = self.max_tokens // max(lengths)
146 | batch_size = (
147 | batch_size
148 | // self.required_batch_size_multiple
149 | * self.required_batch_size_multiple
150 | )
151 | return max(1, batch_size)
152 |
153 | batches = iterators.BucketedReadaheadBatchIterator(
154 | lines,
155 | read_ahead=self.batch_read_ahead,
156 | key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
157 | batch_size=dynamic_batch_size,
158 | shuffle=self.shuffle,
159 | seed=self.seed,
160 | )
161 |
162 | def collate(batch):
163 | batch_size = len(batch)
164 |
165 | mlm_source_max_length = max([len(x[0]) for x in batch])
166 | mlm_target_max_length = max([len(x[1]) for x in batch])
167 | s2s_source_max_length = max([len(x[2]) for x in batch])
168 | s2s_target_max_length = max([len(x[3]) for x in batch])
169 | if self.args.pad_to_max_length:
170 | mlm_source_max_length = self.args.tokens_per_sample
171 | mlm_target_max_length = self.args.tokens_per_sample
172 |
173 | mlm_source_ids = np.full(
174 | shape=(batch_size, mlm_source_max_length),
175 | dtype=np.int32,
176 | fill_value=self.dictionary.pad(),
177 | )
178 | mlm_target_ids = np.full(
179 | shape=(batch_size, mlm_target_max_length),
180 | dtype=np.int32,
181 | fill_value=self.dictionary.pad(),
182 | )
183 | s2s_source_ids = np.full(
184 | shape=(batch_size, s2s_source_max_length),
185 | dtype=np.int32,
186 | fill_value=self.dictionary.pad(),
187 | )
188 | s2s_target_ids = np.full(
189 | shape=(batch_size, s2s_target_max_length - 1),
190 | dtype=np.int32,
191 | fill_value=self.dictionary.pad(),
192 | )
193 | s2s_prev_input_ids = np.full(
194 | shape=(batch_size, s2s_target_max_length - 1),
195 | dtype=np.int32,
196 | fill_value=self.dictionary.pad(),
197 | )
198 |
199 | for i, (
200 | mlm_input_ids,
201 | mlm_label_ids,
202 | s2s_input_ids,
203 | s2s_label_ids,
204 | ) in enumerate(batch):
205 | mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids
206 | mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids
207 | s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids
208 | s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:]
209 | s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1]
210 |
211 | ret_batch = {
212 | "net_input": {
213 | "src_tokens": mlm_source_ids.astype(np.int64),
214 | },
215 | "target": mlm_target_ids.astype(np.int64),
216 | "nsentences": batch_size,
217 | "ntokens": sum([len(x[0]) for x in batch]),
218 | }
219 |
220 | return ret_batch
221 |
222 | padded_batches = iterators.MapIterator(batches, collate)
223 |
224 | return padded_batches
225 |
226 | def _prepare(self, _random, doc):
227 | nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
228 | nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
229 | return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
230 |
231 | def _mask_lm(self, _random, doc):
232 | def mask_tokens():
233 |             return "<mask>"
234 |
235 | length = len(doc)
236 | mask_tokens_num = int(length * self.args.mask_prob)
237 | mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
238 | possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
239 | possible_mask_positions = sorted(possible_mask_positions)
240 |
241 | nonmasked_tokens = copy.deepcopy(doc)
242 | masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]
243 |
244 | for position in possible_mask_positions:
245 | # masked_tokens.append(nonmasked_tokens[position])
246 | masked_tokens[position] = nonmasked_tokens[position]
247 | nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
248 |
249 | return nonmasked_tokens, masked_tokens
250 |
251 | def _span_corruption(self, _random, doc):
252 | def mask_tokens(i):
253 |             return f"<mask_{i}>"
254 |
255 | length = len(doc)
256 | noise_tokens_num = int(length * self.args.mask_prob)
257 | noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
258 | noise_spans_num = int(noise_tokens_num / self.args.span_length)
259 | noise_spans_num = max(noise_spans_num, 1)
260 | nonnoise_tokens_num = length - noise_tokens_num
261 |
262 | if noise_spans_num == 1:
263 | noise_split_positions = [0, noise_tokens_num]
264 | else:
265 | possible_split_positions = list(range(1, noise_tokens_num))
266 | _random.shuffle(possible_split_positions)
267 | noise_split_positions = sorted(
268 | possible_split_positions[: noise_spans_num - 1]
269 | )
270 | noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
271 |
272 | possible_insert_positions = list(range(nonnoise_tokens_num))
273 | _random.shuffle(possible_insert_positions)
274 | noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])
275 |
276 | nonnoise_spans, noise_spans = [], []
277 | last_end = 0
278 | for i in range(noise_spans_num):
279 | start_pos = noise_insert_positions[i] + noise_split_positions[i]
280 | end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
281 | mask_id = self.dictionary.indices[mask_tokens(i)]
282 |
283 | if getattr(self.args, "remove_target_sentinel", False):
284 | noise_spans.append(doc[start_pos:end_pos])
285 | else:
286 | noise_spans.append([mask_id] + doc[start_pos:end_pos])
287 |
288 | if getattr(self.args, "remove_source_sentinel", False):
289 | nonnoise_spans.extend(doc[last_end:start_pos])
290 | else:
291 | nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
292 |
293 | last_end = end_pos
294 |
295 | nonnoise_spans.extend(doc[last_end:])
296 | noise_spans = sum(noise_spans, [])
297 |
298 | return nonnoise_spans, noise_spans
299 |
300 | def _read_from_files(self, source_file, source_lang):
301 | # data = []
302 | file_path = os.path.join(self.data_dir, source_file)
303 |
304 | if not os.path.exists(file_path):
305 |             print("| file {} does not exist".format(file_path), flush=True)
306 | return iter([]) # skip bad file
307 |
308 | with open(file_path, "r", encoding="utf8") as f:
309 | lines = f.read().strip().split("\n")
310 |
311 | doc = [self.dictionary.bos()]
312 | for line in lines:
313 | if line == "":
314 | if self.sample_break_mode == "complete_doc":
315 | # data.append(doc)
316 | yield doc
317 | doc = [self.dictionary.bos()]
318 | continue
319 |
320 | tokenized_line = self.tokenizer.EncodeAsPieces(line)
321 | tokenized_id = [
322 | self.dictionary.index(token) for token in tokenized_line
323 | ] + [self.dictionary.eos_index]
324 |
325 | if len(tokenized_id) > self.tokens_per_sample:
326 | continue
327 | if len(doc) + len(tokenized_id) > self.tokens_per_sample:
328 | # data.append(doc)
329 | yield doc
330 | doc = [self.dictionary.bos()]
331 | doc.extend(tokenized_id)
332 |
333 | if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
334 | # data.append(doc)
335 | yield doc
336 |
337 | # return data
338 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/data/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import collections
5 | from random import Random
6 | from typing import Dict, Iterable, Optional
7 |
8 | import numpy as np
9 | from infinibatch import iterators
10 |
11 |
12 | def apply_to_sample(f, sample):
13 | if hasattr(sample, "__len__") and len(sample) == 0:
14 | return {}
15 |
16 | def _apply(x):
17 | if isinstance(x, np.ndarray):
18 | return f(x)
19 | elif isinstance(x, collections.OrderedDict):
20 | # OrderedDict has attributes that needs to be preserved
21 | od = collections.OrderedDict(
22 | (key, _apply(value)) for key, value in x.items()
23 | )
24 | od.__dict__ = x.__dict__
25 | return od
26 | elif isinstance(x, dict):
27 | return {key: _apply(value) for key, value in x.items()}
28 | elif isinstance(x, list):
29 | return [_apply(x) for x in x]
30 | elif isinstance(x, tuple):
31 | return tuple(_apply(x) for x in x)
32 | elif isinstance(x, set):
33 | return {_apply(x) for x in x}
34 | else:
35 | return x
36 |
37 | return _apply(sample)
38 |
39 |
40 | class NativeCheckpointableIterator(iterators.CheckpointableIterator):
41 | def __init__(self, iterable: Iterable):
42 | self._input_iterable = iterable
43 | self.setstate(None)
44 |
45 | def getstate(self) -> Dict:
46 | return {"num_items_yielded": self._num_items_yielded}
47 |
48 | def setstate(self, checkpoint: Optional[Dict]):
49 | self._iterator = iter(self._input_iterable)
50 | self._num_items_yielded = (
51 | iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
52 | if checkpoint is not None
53 | else 0
54 | )
55 |
56 | def __next__(self):
57 | item = next(self._iterator)
58 | self._num_items_yielded += 1
59 | return item
60 |
61 | def close(self):
62 | pass
63 |
64 |
65 | class WeightIterator(object):
66 | def __init__(self, weights, seed):
67 | self.weights = weights
68 | self.seed = seed
69 | self.control_index = list(range(len(weights)))
70 | self.setstate(None)
71 |
72 | def __iter__(self):
73 | return self
74 |
75 | def getstate(self):
76 | return {"random_state": self._random_state}
77 |
78 | def setstate(self, checkpoint):
79 | self._random_state = checkpoint["random_state"] if checkpoint else None
80 | self._random = (
81 | None # this will trigger the lazy initialization in self.__next__
82 | )
83 |
84 | def __next__(self):
85 | if self._random is None:
86 | self._random = Random(self.seed)
87 | if self._random_state is not None:
88 | self._random.setstate(self._random_state)
89 | idx = self._random.choices(self.control_index, self.weights)[0]
90 | self._random_state = self._random.getstate()
91 | return idx
92 |
93 | def close(self):
94 | pass
95 |
--------------------------------------------------------------------------------
/examples/fairseq/tasks/pretraining.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import json
5 | import logging
6 | import os
7 | from argparse import Namespace
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | #
11 | # This source code is licensed under the MIT license found in the
12 | # LICENSE file in the root directory of this source tree.
13 | from dataclasses import dataclass, field
14 |
15 | import sentencepiece as spm
16 | from fairseq import utils
17 | from fairseq.data import Dictionary
18 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass
19 | from fairseq.tasks import FairseqTask, register_task
20 | from omegaconf import II, MISSING
21 |
22 | from .data.mlm_loader import MLMLoader
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
27 | SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
28 |
29 |
30 | @dataclass
31 | class PretrainingConfig(FairseqDataclass):
32 | data: str = field(
33 | default=MISSING,
34 | metadata={
35 | "help": "colon separated path to data directories list, \
36 | will be iterated upon during epochs in round-robin manner"
37 | },
38 | )
39 | sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
40 | default="complete",
41 | metadata={
42 | "help": 'If omitted or "none", fills each sample with tokens-per-sample '
43 | 'tokens. If set to "complete", splits samples only at the end '
44 | "of sentence, but may include multiple sentences per sample. "
45 | '"complete_doc" is similar but respects doc boundaries. '
46 | 'If set to "eos", includes only one sentence per sample.'
47 | },
48 | )
49 | tokens_per_sample: int = field(
50 | default=1024,
51 | metadata={"help": "max number of tokens per sample for LM dataset"},
52 | )
53 | mask_prob: float = field(
54 | default=0.15,
55 | metadata={"help": "probability of replacing a token with mask"},
56 | )
57 | leave_unmasked_prob: float = field(
58 | default=0.1,
59 | metadata={"help": "probability that a masked token is unmasked"},
60 | )
61 | random_token_prob: float = field(
62 | default=0.1,
63 | metadata={"help": "probability of replacing a token with a random token"},
64 | )
65 | freq_weighted_replacement: bool = field(
66 | default=False,
67 | metadata={"help": "sample random replacement words based on word frequencies"},
68 | )
69 | mask_whole_words: bool = field(
70 | default=False,
71 | metadata={"help": "mask whole words; you may also want to set --bpe"},
72 | )
73 | mask_multiple_length: int = field(
74 | default=1,
75 | metadata={"help": "repeat the mask indices multiple times"},
76 | )
77 | mask_stdev: float = field(
78 | default=0.0,
79 | metadata={"help": "stdev of the mask length"},
80 | )
81 | shorten_method: SHORTEN_METHOD_CHOICES = field(
82 | default="none",
83 | metadata={
84 | "help": "if not none, shorten sequences that exceed --tokens-per-sample"
85 | },
86 | )
87 | shorten_data_split_list: str = field(
88 | default="",
89 | metadata={
90 | "help": "comma-separated list of dataset splits to apply shortening to, "
91 | 'e.g., "train,valid" (default: all dataset splits)'
92 | },
93 | )
94 | seed: int = II("common.seed")
95 | span_length: float = field(
96 | default=3.0,
97 | metadata={"help": "average span length for masking"},
98 | )
99 | remove_source_sentinel: bool = field(
100 | default=False,
101 | metadata={"help": "remove the source sentinel for the span corruption task"},
102 | )
103 | remove_target_sentinel: bool = field(
104 | default=False,
105 | metadata={"help": "remove the target sentinel for the span corruption task"},
106 | )
107 | batch_read_ahead: int = field(
108 | default=100000,
109 | metadata={"help": "batch read ahead size for infinibatch"},
110 | )
111 | required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
112 | spm_model: str = field(
113 | default="",
114 |         metadata={"help": "sentencepiece model used to tokenize the data"},
115 | )
116 | dict_file: str = field(
117 | default="",
118 |         metadata={"help": "path to the dictionary file"},
119 | )
120 | pad_to_max_length: bool = field(
121 | default=False,
122 | )
123 |
124 |
125 | @register_task("pretraining", dataclass=PretrainingConfig)
126 | class PLMTask(FairseqTask):
127 | def __init__(self, cfg, dictionary, tokenizer):
128 | super().__init__(cfg)
129 | self.cfg = cfg
130 | self.dictionary = dictionary
131 | self.tokenizer = tokenizer
132 | self.seed = cfg.seed
133 |         self.mask_idx = dictionary.index("<mask>")
134 |
135 | @classmethod
136 | def setup_task(cls, cfg, **kwargs):
137 | paths = utils.split_paths(cfg.data)
138 | assert len(paths) > 0
139 | if cfg.dict_file != "":
140 | dictionary = Dictionary.load(cfg.dict_file)
141 | else:
142 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
143 |
144 | # add mask token
145 |         dictionary.add_symbol("<mask>")
146 | for i in range(100):
147 |             dictionary.add_symbol(f"<mask_{i}>")
148 |
149 | dictionary.pad_to_multiple_(cfg.required_batch_size_multiple)
150 | logger.info("dictionary: {} types".format(len(dictionary)))
151 |
152 | # tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=cfg.spm_model))
153 | tokenizer = spm.SentencePieceProcessor()
154 | tokenizer.Load(cfg.spm_model)
155 | return cls(cfg, dictionary, tokenizer)
156 |
157 | def load_dataset(self, split, epoch=1, combine=False, **kwargs):
158 | self.datasets[split] = {
159 | "data": json.load(open(f"{self.cfg.data}/json/{split}.json")),
160 | "data_dir": self.cfg.data,
161 | "shuffle": True if split == "train" else False,
162 | }
163 | self.datasets[split] = Namespace(**self.datasets[split])
164 |
165 | def dataset(self, split):
166 | if split not in self.datasets:
167 | raise KeyError("Dataset not loaded: " + split)
168 |
169 | return self.datasets[split]
170 |
171 | def get_batch_iterator(
172 | self,
173 | dataset,
174 | max_tokens=None,
175 | max_sentences=None,
176 | max_positions=None,
177 | ignore_invalid_inputs=False,
178 | required_batch_size_multiple=1,
179 | seed=1,
180 | num_shards=1,
181 | shard_id=0,
182 | num_workers=0,
183 | epoch=1,
184 | data_buffer_size=0,
185 | disable_iterator_cache=False,
186 | **kwargs,
187 | ):
188 | return MLMLoader(
189 | self.cfg,
190 | dataset,
191 | self.dictionary,
192 | self.tokenizer,
193 | max_tokens=max_tokens,
194 | max_sentences=max_sentences,
195 | max_positions=max_positions,
196 | ignore_invalid_inputs=ignore_invalid_inputs,
197 | required_batch_size_multiple=required_batch_size_multiple,
198 | seed=seed,
199 | num_shards=num_shards,
200 | shard_id=shard_id,
201 | )
202 |
203 | @property
204 | def source_dictionary(self):
205 | return self.dictionary
206 |
207 | @property
208 | def target_dictionary(self):
209 | return self.dictionary
210 |
--------------------------------------------------------------------------------
/examples/fairseq/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | # flake8: noqa
5 | import models
6 | import tasks
7 | import criterions
8 | from fairseq_cli.train import cli_main
9 |
10 | if __name__ == "__main__":
11 | cli_main()
12 |
--------------------------------------------------------------------------------
/examples/fairseq/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/examples/fairseq/utils/sparse_clip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 | import warnings
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
10 |
11 |
12 | @torch.no_grad()
13 | def clip_grad_norm_(
14 | params, max_norm, moe_expert_count, aggregate_norm_fn=None
15 | ) -> torch.Tensor:
16 | def grad_exists(p):
17 | return p is not None and getattr(p, "grad", None) is not None
18 |
19 | if isinstance(params, torch.Tensor):
20 | params = [params]
21 | params = list(params)
22 | params = list(filter(grad_exists, params))
23 | grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
24 |     denom = math.sqrt(max(dist.get_world_size() if dist.is_initialized() else 1, moe_expert_count))
25 | for p in params:
26 | if hasattr(p, "expert"):
27 | expert_grads.append(p.grad.detach() / denom)
28 | elif hasattr(p, "base_expert"):
29 | base_expert_grads.append(p.grad.detach())
30 | elif hasattr(p, "_is_sharded"):
31 | sharded_grads.append(p.grad.detach())
32 | else:
33 | grads.append(p.grad.detach())
34 | if len(grads) == 0:
35 | if len(params) > 0:
36 | total_norm = params[0].new_tensor(0.0)
37 | else:
38 | total_norm = torch.tensor(0.0)
39 | elif len(grads) == 1:
40 | total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
41 | else:
42 | if multi_tensor_l2norm_available:
43 | total_norm = multi_tensor_total_norm(grads)
44 | else:
45 | if torch.cuda.is_available():
46 | warnings.warn(
47 | "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
48 | "you may get better performance by installing NVIDIA's apex library"
49 | )
50 | device = torch.cuda.current_device()
51 | elif grads[0].device.type == "xla":
52 | device = grads[0].device
53 | else:
54 | device = torch.device("cpu")
55 | total_norm = torch.norm(
56 | torch.stack(
57 | [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
58 | )
59 | )
60 |
61 | # calculate split_norm and all_reduce with other workers
62 | norms = [total_norm]
63 | for split_grads in [expert_grads, sharded_grads]:
64 | if len(split_grads) == 0:
65 | continue
66 | split_norm = torch.norm(
67 | torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
68 | )
69 | if dist.is_initialized():
70 | split_norm.pow_(2)
71 | dist.all_reduce(split_norm)
72 | split_norm.sqrt_()
73 | norms.append(split_norm)
74 | if len(norms) > 1:
75 | total_norm = torch.norm(torch.stack(norms))
76 |
77 | if aggregate_norm_fn is not None:
78 | total_norm = aggregate_norm_fn(total_norm)
79 |
80 | if max_norm > 0:
81 | max_norm = float(max_norm)
82 | clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
83 | for g in grads + expert_grads + sharded_grads + base_expert_grads:
84 | g.mul_(clip_coef)
85 | return total_norm
86 |
--------------------------------------------------------------------------------
/examples/longvit/README.md:
--------------------------------------------------------------------------------
1 | # [(LongViT) When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology](https://arxiv.org/abs/2312.03558)
2 |
3 | **LongViT** is a vision Transformer that can process gigapixel images (e.g., 32,768x32,768 images) in an end-to-end manner. We split the image into millions of patches and employ [LongNet](https://arxiv.org/abs/2307.02486) to directly model the extremely long sequence. We apply LongViT in the field of computational pathology and achieve remarkable performance on cancer subtyping and survival prediction tasks.
4 |
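As a quick sanity check on the sequence length involved (a back-of-the-envelope sketch; the 32x32 patch size is taken from the released checkpoint listed below):

```python
# Back-of-the-envelope count of the patches (tokens) LongNet has to model,
# assuming the 32x32 patch size of the released LongViT checkpoint.
image_size = 32768
patch_size = 32
patches_per_side = image_size // patch_size   # 1024
num_patches = patches_per_side ** 2           # 1,048,576 (~1M patches)
print(num_patches)
```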
5 |
6 | ## Setup
7 | ```
8 | pip install -r requirements.txt
9 | pip install git+https://github.com/shumingma/fairseq.git@moe
10 | pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.20#egg=xformers
11 | ```
12 |
13 |
14 | ## Pretraining
15 |
16 | We perform self-supervised pretraining on TCGA diagnostic slides using the [DINO](https://arxiv.org/abs/2104.14294) objective. Detailed instructions can be found in [`get_started_for_tcga_pretraining.md`](get_started/get_started_for_tcga_pretraining.md).
17 |
18 | The link to the pretrained LongViT model on TCGA diagnostic slides:
19 | - [`LongViT`](https://github.com/wenhui0924/longvit_ckpts/releases/download/longvit/longvit_small_patch32_1024.pth): #layer=12; hidden=384; FFN factor=4x; #head=16; patch=32x32
20 |
21 |
22 | ## Fine-tuning on Subtyping Classification
23 |
24 | We fine-tune LongViT for cancer subtyping on images with sizes up to 32,768x32,768 (1M patches). Detailed instructions can be found in [`get_started_for_tcga_subtyping.md`](get_started/get_started_for_tcga_subtyping.md).
25 |
26 |
27 | ## Fine-tuning on Survival Prediction
28 |
29 | We fine-tune LongViT for survival prediction on images with sizes up to 32,768x32,768 (1M patches). Detailed instructions can be found in [`get_started_for_tcga_survival_prediction.md`](get_started/get_started_for_tcga_survival_prediction.md).
30 |
31 |
32 | ## Citation
33 |
34 | If you find this repository useful, please consider citing our work:
35 | ```
36 | @article{longvit,
37 | title={When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
38 | author={Wang, Wenhui and Ma, Shuming and Xu, Hanwen and Usuyama, Naoto and Ding, Jiayu and Poon, Hoifung and Wei, Furu},
39 | journal={arXiv preprint arXiv:2312.03558},
40 | year={2023}
41 | }
42 |
43 | @article{longnet,
44 | title={LongNet: Scaling transformers to 1,000,000,000 tokens},
45 | author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Zheng, Nanning and Wei, Furu},
46 | journal={arXiv preprint arXiv:2307.02486},
47 | year={2023}
48 | }
49 |
50 | @article{torchscale,
51 | title={TorchScale: Transformers at scale},
52 | author={Ma, Shuming and Wang, Hongyu and Huang, Shaohan and Wang, Wenhui and Chi, Zewen and Dong, Li and Benhaim, Alon and Patra, Barun and Chaudhary, Vishrav and Song, Xia and others},
53 | journal={arXiv preprint arXiv:2211.13184},
54 | year={2022}
55 | }
56 | ```
57 |
58 |
59 | ## Acknowledgement
60 |
61 | This repository is built using the [BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3), [MCAT](https://github.com/mahmoodlab/MCAT), [DINO](https://github.com/facebookresearch/dino), and [HIPT](https://github.com/mahmoodlab/HIPT) repositories and the [timm](https://github.com/rwightman/pytorch-image-models) library.
62 |
63 |
64 | ## License
65 | This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
66 |
67 | [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
68 |
69 | ### Contact Information
70 |
71 | For help or issues using LongViT models, please submit a GitHub issue.
72 |
--------------------------------------------------------------------------------
/examples/longvit/data_preprocessing/cache_transformed_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import random
5 | import argparse
6 | from PIL import Image, ImageFilter, ImageOps
7 | from multiprocessing import Pool, cpu_count
8 | from timm.data.transforms import RandomResizedCropAndInterpolation
9 | import torchvision.transforms as transforms
10 |
11 | Image.MAX_IMAGE_PIXELS = 6400000000
12 |
13 |
14 | def build_transform(input_size):
15 | train_interpolation = "bicubic"
16 | t = [
17 | RandomResizedCropAndInterpolation(input_size, scale=(0.5, 1.0), interpolation=train_interpolation),
18 | transforms.RandomHorizontalFlip(),
19 | ]
20 | t = transforms.Compose(t)
21 |
22 | return t
23 |
24 |
25 | def pil_loader(path):
26 | with open(path, "rb") as f:
27 | img = Image.open(f)
28 | return img.convert("RGB")
29 |
30 |
31 | def save_image(transformed_img, output_image_path):
32 | if isinstance(transformed_img, torch.Tensor):
33 | transformed_img = transforms.ToPILImage()(transformed_img)
34 | transformed_img.save(output_image_path)
35 |
36 |
37 | def get_image_files(input_dir):
38 | for root, _, files in os.walk(input_dir):
39 | for file in files:
40 | if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
41 | yield os.path.join(root, file)
42 |
43 |
44 | def transform_and_save_crops(args):
45 | input_path, input_dir, output_dir, transform = args
46 | print(input_path)
47 | file_basename = os.path.basename(input_path)
48 |
49 | img = pil_loader(input_path)
50 | transformed_img = transform(img)
51 | output_image_path = os.path.join(output_dir, file_basename)
52 | save_image(transformed_img, output_image_path)
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser(description='Save transformed images in a directory.')
57 | parser.add_argument('input_dir', help='Path to the input directory.')
58 | parser.add_argument('output_dir', help='Path to the output directory.')
59 | parser.add_argument('-p', '--processes', type=int, default=cpu_count(), help='Number of processes to use. Default: number of CPU cores')
60 | parser.add_argument('--input_size', type=int, default=16384, help='input image size')
61 | args = parser.parse_args()
62 |
63 | input_dir = args.input_dir
64 | output_dir = args.output_dir
65 | num_processes = args.processes
66 | input_size = args.input_size
67 | print("num_processes: {}".format(num_processes))
68 | print("input_size: {}".format(input_size))
69 |
70 | transform = build_transform(input_size=input_size)
71 |
72 | image_files = list(get_image_files(input_dir))
73 | task_args = [(file, input_dir, output_dir, transform) for file in image_files]
74 |
75 | os.makedirs(output_dir, exist_ok=True)
76 |
77 | with Pool(processes=num_processes) as pool:
78 | pool.map(transform_and_save_crops, task_args)
79 |
--------------------------------------------------------------------------------
/examples/longvit/data_preprocessing/convert_wsi_to_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import openslide
5 |
6 | from PIL import Image
7 | from concurrent.futures import ProcessPoolExecutor
8 |
9 |
10 | def convert_wsi_to_images(slide_path, image_path, target_size, level=0):
11 | slide = openslide.open_slide(slide_path)
12 | level_dims = slide.level_dimensions
13 | region = slide.read_region((0,0), level, level_dims[level])
14 | region = region.convert("RGB")
15 | print("convert: {}({}) -> {}".format(slide_path, region.size, image_path))
16 | resized_img = region.resize((target_size, target_size), Image.BICUBIC)
17 | resized_img.save(image_path)
18 |
19 |
20 | def process_slides(input_folder, output_folder, target_size, level=0):
21 | if not os.path.exists(output_folder):
22 | os.makedirs(output_folder)
23 |
24 | slide_paths = glob.glob(os.path.join(input_folder, "*.svs"))
25 |
26 | with ProcessPoolExecutor(max_workers=1) as executor:
27 | for slide_path in slide_paths:
28 | image_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0] + ".jpg")
29 | executor.submit(convert_wsi_to_images, slide_path, image_path, target_size, level=level)
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser(description="Convert slides into images")
34 |     parser.add_argument("input_folder", type=str, help="folder containing the .svs whole slide images")
35 |     parser.add_argument("output_folder", type=str, help="folder to save the converted images")
36 |     parser.add_argument("target_size", type=int, help="target width and height (in pixels) of the output images")
37 |     parser.add_argument("level", type=int, help="WSI pyramid level to read (0 is the highest resolution)")
38 |
39 | args = parser.parse_args()
40 | input_folder = args.input_folder
41 | output_folder = args.output_folder
42 | target_size = args.target_size
43 | level = args.level
44 |
45 | process_slides(input_folder, output_folder, target_size, level=level)
46 |
--------------------------------------------------------------------------------
/examples/longvit/data_preprocessing/create_tcga_subtyping_index.py:
--------------------------------------------------------------------------------
1 | from datasets import TCGASubtypingDataset
2 |
3 | tcga_task = "tcga_brca"
4 | for k_fold in range(10):
5 | TCGASubtypingDataset.make_dataset_index(
6 | task=tcga_task,
7 | csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task),
8 | csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold),
9 | k_fold=k_fold,
10 | index_path="./subtyping_split_index/{}".format(tcga_task),
11 | ignore=['MDLC', 'PD', 'ACBC', 'IMMC', 'BRCNOS', 'BRCA', 'SPC', 'MBC', 'MPT'],
12 | label_dict = {'IDC':0, 'ILC':1},
13 | )
14 |
15 | tcga_task = "tcga_lung"
16 | for k_fold in range(10):
17 | TCGASubtypingDataset.make_dataset_index(
18 | task=tcga_task,
19 | csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task),
20 | csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold),
21 | k_fold=k_fold,
22 | index_path="./subtyping_split_index/{}".format(tcga_task),
23 | ignore=[],
24 | label_dict = {'LUAD':0, 'LUSC':1},
25 | )
26 |
27 | tcga_task = "tcga_kidney"
28 | for k_fold in range(10):
29 | TCGASubtypingDataset.make_dataset_index(
30 | task=tcga_task,
31 | csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task),
32 | csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold),
33 | k_fold=k_fold,
34 | index_path="./subtyping_split_index/{}".format(tcga_task),
35 | ignore=[],
36 | label_dict = {'CCRCC':0, 'PRCC':1, 'CHRCC':2},
37 | )
38 |
--------------------------------------------------------------------------------
/examples/longvit/data_preprocessing/create_tcga_survival_index.py:
--------------------------------------------------------------------------------
1 | from datasets import TCGASurvivalDataset
2 |
3 | for tcga_task in ["tcga_ucec", "tcga_luad", "tcga_brca"]:
4 | for k_fold in range(5):
5 | TCGASurvivalDataset.make_dataset_index(
6 | task=tcga_task,
7 | csv_path="./survival_dataset_csv/{}_all_clean.csv.zip".format(tcga_task),
8 | csv_split_path="./survival_splits/5foldcv/{}/splits_{}.csv".format(tcga_task, k_fold),
9 | k_fold=k_fold,
10 | index_path="./survival_split_index/{}".format(tcga_task),
11 | )
--------------------------------------------------------------------------------
/examples/longvit/data_preprocessing/generate_1024_crops.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import json
5 | import numpy as np
6 | import openslide
7 | import time
8 | import torch
9 |
10 | import argparse
11 | import random
12 | import shutil
13 |
14 | import glob
15 | from concurrent.futures import ProcessPoolExecutor
16 |
17 | from PIL import Image
18 | from torchvision import transforms
19 | from torchvision.transforms import InterpolationMode
20 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
21 |
22 |
23 | def is_similar_pixel(pixel1, pixel2, threshold=30):
24 | return np.linalg.norm(pixel1 - pixel2) < threshold
25 |
26 |
27 | def should_discard_image(image_path, target_pixel=np.array([243, 243, 243]), threshold=30, similarity_ratio=0.99):
28 | image = cv2.imread(image_path)
29 | height, width, _ = image.shape
30 |
31 | similar_pixels = 0
32 | total_pixels = height * width
33 |
34 | for y in range(height):
35 | for x in range(width):
36 | pixel = image[y, x]
37 |
38 | if is_similar_pixel(pixel, target_pixel, threshold):
39 | similar_pixels += 1
40 |
41 | ratio = similar_pixels / total_pixels
42 | return ratio > similarity_ratio
43 |
44 |
45 | def random_crop(slide_path, output_path, min_crop_size, max_crop_size, level=0):
46 | slide = openslide.open_slide(slide_path)
47 | level_dim = slide.level_dimensions
48 | slide_width, slide_height = slide.dimensions
49 |
50 | crop_width = random.randint(min_crop_size, max_crop_size)
51 | crop_height = random.randint(min_crop_size, max_crop_size)
52 |
53 | x = random.randint(0, slide_width - crop_width)
54 | y = random.randint(0, slide_height - crop_height)
55 |
56 | region = slide.read_region((x, y), level, (crop_width, crop_height))
57 | region = region.convert("RGB")
58 | region.save(output_path)
59 |
60 |
61 | def get_crops(slide_path, output_folder, crop_number, min_crop_size, max_crop_size):
62 | print(slide_path)
63 |
64 | index = 0
65 | while index < crop_number:
66 | output_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0], f"{str(index).zfill(8)}.JPEG")
67 |
68 | dir_path = os.path.dirname(output_path)
69 | if not os.path.exists(dir_path):
70 | os.makedirs(dir_path)
71 |
72 | random_crop(slide_path, output_path, min_crop_size, max_crop_size)
73 | if not should_discard_image(output_path):
74 | index += 1
75 |
76 |
77 | def process_slides(input_folder, output_folder, crop_number=100, min_crop_size=1024, max_crop_size=1536):
78 | if not os.path.exists(output_folder):
79 | os.makedirs(output_folder)
80 |
81 | slide_paths = glob.glob(f"{input_folder}/**/*.svs", recursive=True)
82 |
83 | with ProcessPoolExecutor(max_workers=4) as executor:
84 | for slide_path in slide_paths:
85 | executor.submit(get_crops, slide_path, output_folder, crop_number, min_crop_size, max_crop_size)
86 |
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser(description="Generate crops from slides")
90 |     parser.add_argument("input_folder", type=str, help="folder containing the .svs whole slide images")
91 |     parser.add_argument("output_folder", type=str, help="folder to save the generated crops")
92 |     parser.add_argument("crop_number", type=int, help="number of crops to generate per slide")
93 |
94 | args = parser.parse_args()
95 | input_folder = args.input_folder
96 | output_folder = args.output_folder
97 | crop_number = args.crop_number
98 |
99 | process_slides(input_folder, output_folder, crop_number=crop_number)
100 |
--------------------------------------------------------------------------------
/examples/longvit/data_preprocessing/split_to_small_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import argparse
5 | from PIL import Image
6 | from concurrent.futures import ProcessPoolExecutor
7 |
8 | Image.MAX_IMAGE_PIXELS = 6400000000
9 |
10 |
11 | def split_image(image_path, input_folder, output_folder, num_splits):
12 | print(image_path)
13 | file_name, file_ext = os.path.splitext(os.path.basename(image_path))
14 |
15 | img = Image.open(image_path)
16 | width, height = img.size
17 |
18 | block_width = width
19 | block_height = height // num_splits
20 |
21 | for i in range(num_splits):
22 | left = 0
23 | upper = i * block_height
24 | right = block_width
25 | lower = (i + 1) * block_height
26 | cropped_img = img.crop((left, upper, right, lower))
27 | cropped_img.save(f"{output_folder}/{file_name}_{i}{file_ext}")
28 |
29 |
30 | def find_images(input_folder):
31 | image_files = []
32 | for root, _, files in os.walk(input_folder):
33 | for f in files:
34 | if f.lower().endswith(('.png', '.jpg', '.jpeg')):
35 | image_files.append(os.path.join(root, f))
36 | return image_files
37 |
38 |
39 | def process_images(image_files, input_folder, output_folder, num_splits, num_processes):
40 | with ProcessPoolExecutor(max_workers=num_processes) as executor:
41 | for image_file in image_files:
42 | executor.submit(split_image, image_file, input_folder, output_folder, num_splits)
43 |
44 |
45 | def main():
46 | parser = argparse.ArgumentParser(description='Split images into smaller tiles')
47 | parser.add_argument('--input', type=str, required=True, help='Path to the input folder containing images')
48 | parser.add_argument('--output', type=str, required=True, help='Path to the output folder for saving the tiles')
49 |     parser.add_argument('--num_splits', type=int, default=16, help='Number of horizontal strips to split each image into (default: 16)')
50 |     parser.add_argument('--processes', type=int, default=1, help='Number of processes to use (default: 1)')
51 | args = parser.parse_args()
52 |
53 | input_folder = args.input
54 | output_folder = args.output
55 | num_splits = args.num_splits
56 | num_processes = args.processes
57 |
58 | if not os.path.exists(output_folder):
59 | os.makedirs(output_folder)
60 |
61 | image_files = find_images(input_folder)
62 | process_images(image_files, input_folder, output_folder, num_splits, num_processes)
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
68 |
--------------------------------------------------------------------------------
/examples/longvit/get_started/get_started_for_tcga_pretraining.md:
--------------------------------------------------------------------------------
1 | # Pretraining LongViT on TCGA using DINO
2 |
3 | ## Setup
4 |
5 | 1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/).
6 |
7 | 2. Generate 1,024x1,024 regions from WSIs:
8 | ```
9 | # we randomly generate 100 small regions for each whole slide image
10 | python data_preprocessing/generate_1024_crops.py /path/to/your_WSIs /path/to/your_crops 100
11 | ```
12 |
13 | ## Pretraining LongViT
14 |
16 | Replace the `vision_transformer.py` in [DINO](https://github.com/facebookresearch/dino) with the [LongViT vision_transformer.py](../pretraining/vision_transformer.py), and modify the `global crop size` to 1024 and the `local crop size` to 512 to perform LongViT pretraining with the DINO framework. A sketch of the crop-size change is shown below.
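
The crop sizes are hard-coded in DINO's `DataAugmentationDINO` class (`main_dino.py`). The snippet below is a minimal sketch of the kind of edit, assuming the public DINO code builds its global/local crops with torchvision's `RandomResizedCrop` and keeping DINO's default crop scales; adapt it to the corresponding lines in your DINO checkout.

```python
# Sketch of the crop-size change inside DINO's DataAugmentationDINO (main_dino.py).
# Assumption: the public DINO repo, where the crop sizes are hard-coded constants.
from torchvision import transforms
from torchvision.transforms import InterpolationMode

GLOBAL_CROP_SIZE = 1024  # global crop size for LongViT pretraining (DINO default: 224)
LOCAL_CROP_SIZE = 512    # local crop size for LongViT pretraining (DINO default: 96)

global_crop = transforms.RandomResizedCrop(
    GLOBAL_CROP_SIZE, scale=(0.4, 1.0), interpolation=InterpolationMode.BICUBIC
)
local_crop = transforms.RandomResizedCrop(
    LOCAL_CROP_SIZE, scale=(0.05, 0.4), interpolation=InterpolationMode.BICUBIC
)
```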
--------------------------------------------------------------------------------
/examples/longvit/get_started/get_started_for_tcga_subtyping.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning LongViT on TCGA Subtyping
2 |
3 | ## Setup
4 |
5 | 1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/), and organize the dataset (e.g., BRCA WSIs) into the following structure:
6 |
7 | ```
8 | /path/to/your_WSIs/
9 | TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291.svs
10 | ...
11 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9.svs
12 | ...
13 | ```
14 |
15 | 2. Download [dataset annotation csv](https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/dataset_csv) and [splits for cross validation](https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/splits/10foldcv_subtype) from the HIPT repository.
16 |
17 | 3. Generate the index json files of each split using the following command.
18 | ```
19 | # Modify the `csv_path` and `csv_split_path` to your path.
20 | python data_preprocessing/create_tcga_subtyping_index.py
21 | ```
22 |
23 | 4. Resize whole slide images to the desired size for finetuning.
24 | ```
25 | python data_preprocessing/convert_wsi_to_images.py /path/to/your_WSIs /path/to/your_resized_WSIs ${target_size} ${wsi_level}
26 | ```
27 |
28 | 5. (Optional) For very large images (e.g., 32,768x32,768), we suggest parallelizing the training across multiple GPU devices due to computation and memory constraints. The sequence of millions of patches is split along the sequence dimension, with one chunk per GPU.
29 | ```
30 | # num_splits is equal to the number of GPUs you used (e.g., 8 in our experiment)
31 | python data_preprocessing/split_to_small_images.py /path/to/your_resized_WSIs /path/to/your_splited_WSIs --num_splits ${num_splits}
32 | ```
33 |
34 | 6. (Optional) We find that image augmentation slightly improves performance. For very large images (e.g., 32,768x32,768), we perform the augmentation offline and cache the resulting images for each epoch (a loop sketch is given after the command).
35 | ```
36 | # Run the command 10 times (number of epochs in finetuning) using i from 0-9
37 | python data_preprocessing/cache_transformed_images.py /path/to/your_resized_WSIs /path/to/your_augmentated_WSIs/epoch_$i --input_size 32768
38 | ```
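
If it helps, the ten runs above can be scripted; the following is a small convenience sketch (not part of this repository) that invokes the same command for epochs 0-9 with the same placeholder paths:

```python
# Convenience sketch: cache the augmented images for all 10 finetuning epochs
# by invoking the command above once per epoch (placeholder paths as above).
import subprocess

for i in range(10):
    subprocess.run(
        [
            "python", "data_preprocessing/cache_transformed_images.py",
            "/path/to/your_resized_WSIs",
            f"/path/to/your_augmentated_WSIs/epoch_{i}",
            "--input_size", "32768",
        ],
        check=True,
    )
```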
39 |
40 | Split these cached images as in step 5 and organize the data into the following structure:
41 | ```
42 | /path/to/your_splited_WSIs/
43 | epoch_0/
44 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_0.jpg
45 | ...
46 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_7.jpg
47 | ...
48 | epoch_1/
49 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_0.jpg
50 | ...
51 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_7.jpg
52 | ...
53 | ...
54 | epoch_5/
55 | ...
56 | epoch_9/
57 | wo_augmentation/
58 | ```
59 |
60 |
61 | ## Example: Fine-tuning LongViT on TCGA Subtyping
62 |
63 | The LongViT model can be fine-tuned using 8 V100-32GB GPUs. For images with a size less than or equal to 16,384x16,384, finetuning can be performed directly without sequence parallelism.
64 |
65 | ```bash
66 | # IMAGE_SIZE - {1024, 4096, 8192, 16384}
67 | # TASK - {"brca", "kidney", "lung"}
68 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
69 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \
70 | --input_size ${IMAGE_SIZE} \
71 | --model longvit_small_patch32_${IMAGE_SIZE} \
72 | --task tcga_${TASK}_subtyping \
73 | --batch_size 1 \
74 | --layer_decay 1.0 \
75 | --lr 5e-5 \
76 | --update_freq 1 \
77 | --epochs 10 \
78 | --warmup_epochs 1 \
79 | --drop_path 0.1 \
80 |         --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth \
81 | --data_path ./subtyping_split_index/tcga_${TASK} \
82 | --image_dir /path/to/your_resized_WSIs \
83 | --output_dir /path/to/save/your_model \
84 | --log_dir /path/to/save/your_model/log \
85 | --weight_decay 0.05 \
86 | --seed 42 \
87 | --save_ckpt_freq 5 \
88 | --k_fold ${K_FOLD} \
89 | --num_workers 1 \
90 | --enable_deepspeed \
91 | --model_key teacher \
92 | --randaug
93 | ```
94 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining).
95 | - `--randaug`: perform image augmentation.
96 |
97 |
98 | Sequence-parallel training on 32,768x32,768 images:
99 |
100 | ```bash
101 | # TASK - {"brca", "kidney", "lung"}
102 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
103 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \
104 | --input_size 32768 \
105 | --model longvit_small_patch32_32768 \
106 | --task tcga_${TASK}_subtyping \
107 | --batch_size 2 \
108 | --layer_decay 1.0 \
109 | --lr 5e-5 \
110 | --update_freq 4 \
111 | --epochs 10 \
112 | --warmup_epochs 1 \
113 | --drop_path 0.1 \
114 |         --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth \
115 | --data_path ./subtyping_split_index/tcga_${TASK} \
116 | --image_dir /path/to/your_splited_WSIs \
117 | --output_dir /path/to/save/your_model \
118 | --log_dir /path/to/save/your_model/log \
119 | --weight_decay 0.05 \
120 | --seed 42 \
121 | --save_ckpt_freq 5 \
122 | --k_fold ${K_FOLD} \
123 | --num_workers 1 \
124 | --enable_deepspeed \
125 | --model_key teacher \
126 | --seq_parallel \
127 | --cached_randaug
128 | ```
129 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining).
130 | - `--seq_parallel`: parallelize the training for very large images.
131 | - `--cached_randaug`: perform training on the cached augmented images.
132 |
133 |
134 | ## Example: Evaluate LongViT on TCGA Subtyping
135 |
136 | ```bash
137 | # IMAGE_SIZE - {1024, 4096, 8192, 16384, 32768}
138 | # TASK - {"brca", "kidney", "lung"}
139 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
140 | python -m torch.distributed.launch --nproc_per_node=1 run_longvit_finetuning.py \
141 | --input_size ${IMAGE_SIZE} \
142 | --model longvit_small_patch32_${IMAGE_SIZE} \
143 | --task tcga_${TASK}_subtyping \
144 | --batch_size 1 \
145 | --layer_decay 1.0 \
146 | --lr 5e-5 \
147 | --update_freq 1 \
148 | --epochs 10 \
149 | --warmup_epochs 1 \
150 | --drop_path 0.1 \
151 | --finetune /path/to/save/your_model/checkpoint-best/mp_rank_00_model_states.pt \
152 | --data_path ./subtyping_split_index/tcga_${TASK} \
153 | --image_dir /path/to/your_resized_WSIs \
154 | --output_dir /path/to/save/your_model \
155 | --log_dir /path/to/save/your_model/log \
156 | --weight_decay 0.05 \
157 | --seed 42 \
158 | --save_ckpt_freq 5 \
159 | --k_fold ${K_FOLD} \
160 | --num_workers 1 \
161 | --enable_deepspeed \
162 | --model_key module \
163 | --eval \
164 | --no_auto_resume
165 | ```
166 | - `--eval`: perform evaluation on the test set.
167 | - `--finetune`: path to the best validation checkpoint, used for testing.
168 |
169 | For a model trained with sequence parallelism, add `--seq_parallel` and use the same number of GPUs as in training to perform evaluation.
--------------------------------------------------------------------------------
/examples/longvit/get_started/get_started_for_tcga_survival_prediction.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning LongViT on TCGA Survival Prediction
2 |
3 | ## Setup
4 |
5 | 1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/), and organize the dataset (e.g., BRCA WSIs) into the following structure:
6 |
7 | ```
8 | /path/to/your_WSIs/
9 | TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291.svs
10 | ...
11 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9.svs
12 | ...
13 | ```
14 |
15 | 2. Download [dataset annotation csv](https://github.com/mahmoodlab/MCAT/tree/master/datasets_csv_sig) and [splits for cross validation](https://github.com/mahmoodlab/MCAT/tree/master/splits/5foldcv) from the MCAT repository.
16 |
17 | 3. Generate the index json files of each split using the following command.
18 | ```
19 | # Modify the `csv_path` and `csv_split_path` to your path.
20 | python data_preprocessing/create_tcga_survival_index.py
21 | ```
22 |
23 | 4. Resize whole slide images to the desired size for finetuning.
24 | ```
25 | python data_preprocessing/convert_wsi_to_images.py /path/to/your_WSIs /path/to/your_resized_WSIs ${target_size} ${wsi_level}
26 | ```
27 |
28 | 5. (Optional) For very large images (e.g., 32,768x32,768), we suggest parallelizing the training across multiple GPU devices due to computation and memory constraints. The sequence of millions of patches is split along the sequence dimension, with one chunk per GPU.
29 | ```
30 | # num_splits is equal to the number of GPUs you used (e.g., 8 in our experiment)
31 | python data_preprocessing/split_to_small_images.py /path/to/your_resized_WSIs /path/to/your_splited_WSIs --num_splits ${num_splits}
32 | ```
33 |
34 |
35 | ## Example: Fine-tuning LongViT on TCGA Survival Prediction
36 |
37 | The LongViT model can be fine-tuned using 8 V100-32GB GPUs. For images with a size less than or equal to 16,384x16,384, finetuning can be performed directly without sequence parallelism.
38 |
39 | ```bash
40 | # IMAGE_SIZE - {1024, 4096, 8192, 16384}
41 | # TASK - {"brca", "luad", "ucec"}
42 | # K_FOLD - {0, 1, 2, 3, 4}
43 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \
44 | --input_size ${IMAGE_SIZE} \
45 | --model longvit_small_patch32_${IMAGE_SIZE} \
46 | --task tcga_${TASK}_survival \
47 | --batch_size 1 \
48 | --layer_decay 1.0 \
49 | --lr 5e-5 \
50 | --update_freq 1 \
51 | --epochs 10 \
52 | --warmup_epochs 1 \
53 | --drop_path 0.1 \
54 |         --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth \
55 | --data_path ./survival_split_index/tcga_${TASK} \
56 | --image_dir /path/to/your_resized_WSIs \
57 | --output_dir /path/to/save/your_model \
58 | --log_dir /path/to/save/your_model/log \
59 | --weight_decay 0.05 \
60 | --seed 42 \
61 | --save_ckpt_freq 5 \
62 | --k_fold ${K_FOLD} \
63 | --num_workers 1 \
64 | --enable_deepspeed \
65 | --model_key teacher \
66 | --randaug
67 | ```
68 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining).
69 | - `--randaug`: perform image augmentation.
70 |
71 |
72 | Parallelize the training on 32,768x32,768 images:
73 |
74 | ```bash
75 | # TASK - {"brca", "luad", "ucec"}
76 | # K_FOLD - {0, 1, 2, 3, 4}
77 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \
78 | --input_size 32768 \
79 | --model longvit_small_patch32_32768 \
80 | --task tcga_${TASK}_survival \
81 | --batch_size 2 \
82 | --layer_decay 1.0 \
83 | --lr 5e-5 \
84 | --update_freq 4 \
85 | --epochs 10 \
86 | --warmup_epochs 1 \
87 | --drop_path 0.1 \
88 |         --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth \
89 |         --data_path ./survival_split_index/tcga_${TASK} \
90 | --image_dir /path/to/your_splited_WSIs \
91 | --output_dir /path/to/save/your_model \
92 | --log_dir /path/to/save/your_model/log \
93 | --weight_decay 0.05 \
94 | --seed 42 \
95 | --save_ckpt_freq 5 \
96 | --k_fold ${K_FOLD} \
97 | --num_workers 1 \
98 | --enable_deepspeed \
99 | --model_key teacher \
100 | --seq_parallel
101 | ```
102 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining).
103 | - `--seq_parallel`: parallelize the training for very large images.
104 |
105 |
106 | ## Example: Evaluate LongViT on TCGA Survival Prediction
107 |
108 | ```bash
109 | # IMAGE_SIZE - {1024, 4096, 8192, 16384, 32768}
110 | # TASK - {"brca", "luad", "ucec"}
111 | # K_FOLD - {0, 1, 2, 3, 4}
112 | python -m torch.distributed.launch --nproc_per_node=1 run_longvit_finetuning.py \
113 | --input_size ${IMAGE_SIZE} \
114 | --model longvit_small_patch32_${IMAGE_SIZE} \
115 | --task tcga_${TASK}_survival \
116 | --batch_size 1 \
117 | --layer_decay 1.0 \
118 | --lr 5e-5 \
119 | --update_freq 1 \
120 | --epochs 10 \
121 | --warmup_epochs 1 \
122 | --drop_path 0.1 \
123 | --finetune /path/to/save/your_model/checkpoint-best/mp_rank_00_model_states.pt \
124 | --data_path ./survival_split_index/tcga_${TASK} \
125 | --image_dir /path/to/your_resized_WSIs \
126 | --output_dir /path/to/save/your_model \
127 | --log_dir /path/to/save/your_model/log \
128 | --weight_decay 0.05 \
129 | --seed 42 \
130 | --save_ckpt_freq 5 \
131 | --k_fold ${K_FOLD} \
132 | --num_workers 1 \
133 | --enable_deepspeed \
134 | --model_key module \
135 | --eval \
136 | --no_auto_resume
137 | ```
138 | - `--eval`: perform evaluation on the test set.
139 | - `--finetune`: path to the best validation checkpoint, used for testing.
140 |
141 | For a model trained with sequence parallelism, add `--seq_parallel` and use the same number of GPUs as in training to perform evaluation.
--------------------------------------------------------------------------------
/examples/longvit/longvit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import utils
22 | import torch
23 | import torch.nn as nn
24 |
25 | from torchscale.architecture.encoder import Encoder
26 | from torchscale.model.LongNet import LongNetEncoder
27 | from torchscale.architecture.config import EncoderConfig
28 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
29 |
30 |
31 | def trunc_normal_(tensor, mean=0., std=1.):
32 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
33 |
34 |
35 | def drop_path(x, drop_prob: float = 0., training: bool = False):
36 | if drop_prob == 0. or not training:
37 | return x
38 | keep_prob = 1 - drop_prob
39 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
40 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
41 | random_tensor.floor_() # binarize
42 | output = x.div(keep_prob) * random_tensor
43 | return output
44 |
45 |
46 | class DropPath(nn.Module):
47 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
48 | """
49 | def __init__(self, drop_prob=None):
50 | super(DropPath, self).__init__()
51 | self.drop_prob = drop_prob
52 |
53 | def forward(self, x):
54 | return drop_path(x, self.drop_prob, self.training)
55 |
56 |
57 | class Mlp(nn.Module):
58 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
59 | super().__init__()
60 | out_features = out_features or in_features
61 | hidden_features = hidden_features or in_features
62 | self.fc1 = nn.Linear(in_features, hidden_features)
63 | self.act = act_layer()
64 | self.fc2 = nn.Linear(hidden_features, out_features)
65 | self.drop = nn.Dropout(drop)
66 |
67 | def forward(self, x):
68 | x = self.fc1(x)
69 | x = self.act(x)
70 | x = self.drop(x)
71 | x = self.fc2(x)
72 | x = self.drop(x)
73 | return x
74 |
75 |
76 | class Attention(nn.Module):
77 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
78 | super().__init__()
79 | self.num_heads = num_heads
80 | head_dim = dim // num_heads
81 | self.scale = qk_scale or head_dim ** -0.5
82 |
83 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
84 | self.attn_drop = nn.Dropout(attn_drop)
85 | self.proj = nn.Linear(dim, dim)
86 | self.proj_drop = nn.Dropout(proj_drop)
87 |
88 | def forward(self, x):
89 | B, N, C = x.shape
90 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
91 | q, k, v = qkv[0], qkv[1], qkv[2]
92 |
93 | attn = (q @ k.transpose(-2, -1)) * self.scale
94 | attn = attn.softmax(dim=-1)
95 | attn = self.attn_drop(attn)
96 |
97 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
98 | x = self.proj(x)
99 | x = self.proj_drop(x)
100 | return x, attn
101 |
102 |
103 | class Block(nn.Module):
104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
105 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
106 | super().__init__()
107 | self.norm1 = norm_layer(dim)
108 | self.attn = Attention(
109 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
110 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
111 | self.norm2 = norm_layer(dim)
112 | mlp_hidden_dim = int(dim * mlp_ratio)
113 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
114 |
115 | def forward(self, x, return_attention=False):
116 | y, attn = self.attn(self.norm1(x))
117 | if return_attention:
118 | return attn
119 | x = x + self.drop_path(y)
120 | x = x + self.drop_path(self.mlp(self.norm2(x)))
121 | return x
122 |
123 |
124 | class PatchEmbed(nn.Module):
125 | """ Image to Patch Embedding
126 | """
127 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
128 | super().__init__()
129 | num_patches = (img_size // patch_size) * (img_size // patch_size)
130 | self.img_size = img_size
131 | self.patch_size = patch_size
132 | self.num_patches = num_patches
133 |
134 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
135 |
136 | def forward(self, x):
137 | B, C, H, W = x.shape
138 | x = self.proj(x).flatten(2).transpose(1, 2)
139 | return x
140 |
141 |
142 | class LongViT(nn.Module):
143 | """ Vision Transformer """
144 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
145 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
146 | drop_path_rate=0., norm_layer=nn.LayerNorm, flash_attention=True, dilated_ratio="[1,2,4,8,16]", segment_length="[64,128,256,512,1024]", checkpoint_activations=False, seq_parallel=False, **kwargs):
147 | super().__init__()
148 | self.num_features = self.embed_dim = embed_dim
149 |
150 | self.patch_embed = PatchEmbed(
151 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
152 | num_patches = self.patch_embed.num_patches
153 |
154 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
155 | self.pos_drop = nn.Dropout(p=drop_rate)
156 |
157 | if img_size == 4096:
158 | segment_length = "[1024,2048,4096,8192,16384]"
159 | elif img_size == 8192:
160 | segment_length = "[1024,4096,8192,16384,65536]"
161 | elif img_size == 16384:
162 | segment_length = "[1024,4096,16384,65536,262144]"
163 | elif img_size == 32768:
164 | segment_length = "[1024,4096,32768,262144,1048576]"
165 |
166 | self.seq_parallel = seq_parallel
167 | encoder_config = EncoderConfig(
168 | img_size=img_size, patch_size=patch_size, vocab_size=64010, multiway=False,
169 | layernorm_embedding=False, normalize_output=False, no_output_layer=True,
170 | drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim, encoder_attention_heads=num_heads,
171 | encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=depth,
172 | checkpoint_activations=checkpoint_activations, flash_attention=flash_attention,
173 | dilated_ratio=dilated_ratio, segment_length=segment_length, seq_parallel=seq_parallel,
174 | )
175 | if flash_attention:
176 |             print("Using Torchscale LongNetEncoder")
177 | print("segment_length: {}".format(encoder_config.segment_length))
178 | print("dilated_ratio: {}".format(encoder_config.dilated_ratio))
179 | print("checkpoint_activations: {}".format(encoder_config.checkpoint_activations))
180 | print("drop_path_rate: {}".format(encoder_config.drop_path_rate))
181 | self.encoder = LongNetEncoder(encoder_config, embed_tokens=None, embed_positions=None,
182 | output_projection=None, is_encoder_decoder=False,)
183 | else:
184 | print("Using Torchscale Encoder")
185 | self.encoder = Encoder(encoder_config, embed_tokens=None, embed_positions=None,
186 | output_projection=None, is_encoder_decoder=False,)
187 |
188 | trunc_normal_(self.pos_embed, std=.02)
189 | self.apply(self._init_weights)
190 |
191 | def _init_weights(self, m):
192 | if isinstance(m, nn.Linear):
193 | trunc_normal_(m.weight, std=.02)
194 | if isinstance(m, nn.Linear) and m.bias is not None:
195 | nn.init.constant_(m.bias, 0)
196 | elif isinstance(m, nn.LayerNorm):
197 | nn.init.constant_(m.bias, 0)
198 | nn.init.constant_(m.weight, 1.0)
199 |
200 | def interpolate_pos_encoding(self, x, w, h):
201 | npatch = x.shape[1]
202 | N = self.pos_embed.shape[1]
203 | if npatch == N and w == h:
204 | return self.pos_embed
205 | patch_pos_embed = self.pos_embed
206 | dim = x.shape[-1]
207 | w0 = w // self.patch_embed.patch_size
208 | h0 = h // self.patch_embed.patch_size
209 | # we add a small number to avoid floating point error in the interpolation
210 | # see discussion at https://github.com/facebookresearch/dino/issues/8
211 | w0, h0 = w0 + 0.1, h0 + 0.1
212 | patch_pos_embed = nn.functional.interpolate(
213 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
214 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
215 | mode='bicubic',
216 | )
217 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
218 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
219 | return patch_pos_embed
220 |
221 | def prepare_tokens(self, x):
222 | B, nc, w, h = x.shape
223 | x = self.patch_embed(x) # patch linear embedding
224 |
225 | # add positional encoding to each token
226 | if self.seq_parallel:
227 | rank_seq_len = x.size(1)
228 | cur_rank = utils.get_rank()
229 | start_idx = cur_rank * rank_seq_len
230 | end_idx = (cur_rank + 1) * rank_seq_len
231 | x = x + self.pos_embed[:, start_idx:end_idx, :]
232 | else:
233 | x = x + self.interpolate_pos_encoding(x, w, h)
234 |
235 | return self.pos_drop(x)
236 |
237 | def forward(self, x):
238 | x = self.prepare_tokens(x)
239 | x = self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"]
240 | return x
241 |
--------------------------------------------------------------------------------
/examples/longvit/modeling_finetune.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4 | # Copyright (c) 2023 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # --------------------------------------------------------'
7 |
8 | import utils
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import numpy as np
13 |
14 | from timm.models.registry import register_model
15 | from functools import partial
16 | from longvit import LongViT
17 | from torchscale.architecture.config import EncoderConfig
18 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
19 |
20 |
21 | def _get_small_config(
22 | img_size=1024, patch_size=32, drop_path_rate=0,
23 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
24 | ):
25 | return EncoderConfig(
26 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=False,
27 | layernorm_embedding=False, normalize_output=False, no_output_layer=True,
28 | drop_path_rate=drop_path_rate, encoder_embed_dim=384, encoder_attention_heads=16,
29 | encoder_ffn_embed_dim=int(384 * mlp_ratio), encoder_layers=12,
30 | checkpoint_activations=checkpoint_activations,
31 | )
32 |
33 |
34 | def trunc_normal_(tensor, mean=0., std=1.):
35 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
36 |
37 |
38 | class LongViTForTCGAClassification(nn.Module):
39 | def __init__(
40 | self,
41 | args,
42 | num_classes,
43 | norm_layer=nn.LayerNorm,
44 | seq_parallel=False,
45 | **kwargs
46 | ):
47 | super().__init__()
48 | self.model = LongViT(
49 | img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.encoder_embed_dim,
50 | depth=args.encoder_layers, num_heads=args.encoder_attention_heads,
51 | mlp_ratio=4, drop_path_rate=args.drop_path_rate,
52 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
53 | checkpoint_activations=args.checkpoint_activations, seq_parallel=seq_parallel
54 | )
55 | embed_dim = args.encoder_embed_dim
56 | self.depth = args.encoder_layers
57 | self.fc_norm = norm_layer(embed_dim)
58 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
59 |
60 | self.fc_norm.apply(self._init_weights)
61 | self.head.apply(self._init_weights)
62 |
63 | def _init_weights(self, m):
64 | if isinstance(m, nn.Linear):
65 | trunc_normal_(m.weight, std=.02)
66 | if isinstance(m, nn.Linear) and m.bias is not None:
67 | nn.init.constant_(m.bias, 0)
68 | elif isinstance(m, nn.LayerNorm):
69 | nn.init.constant_(m.bias, 0)
70 | nn.init.constant_(m.weight, 1.0)
71 |
72 | def get_num_layers(self):
73 | return self.depth
74 |
75 | @torch.jit.ignore
76 | def no_weight_decay(self):
77 | return {'model.pos_embed'}
78 |
79 | def forward(self, image, **kwargs):
80 | x = self.model(image)
81 | t = x[:, :, :]
82 | cls_x = self.fc_norm(t.mean(1))
83 | return self.head(cls_x)
84 |
85 |
86 | class LongViTForTCGAClassificationSeqParallel(nn.Module):
87 | def __init__(
88 | self,
89 | args,
90 | num_classes,
91 | norm_layer=nn.LayerNorm,
92 | seq_parallel=False,
93 | **kwargs
94 | ):
95 | super().__init__()
96 | self.model = LongViT(
97 | img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.encoder_embed_dim,
98 | depth=args.encoder_layers, num_heads=args.encoder_attention_heads,
99 | mlp_ratio=4, drop_path_rate=args.drop_path_rate,
100 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
101 | checkpoint_activations=args.checkpoint_activations, seq_parallel=seq_parallel,
102 | )
103 | embed_dim = args.encoder_embed_dim
104 | self.depth = args.encoder_layers
105 | self.fc_norm = norm_layer(embed_dim)
106 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
107 |
108 | self.fc_norm.apply(self._init_weights)
109 | self.head.apply(self._init_weights)
110 |
111 | def _init_weights(self, m):
112 | if isinstance(m, nn.Linear):
113 | trunc_normal_(m.weight, std=.02)
114 | if isinstance(m, nn.Linear) and m.bias is not None:
115 | nn.init.constant_(m.bias, 0)
116 | elif isinstance(m, nn.LayerNorm):
117 | nn.init.constant_(m.bias, 0)
118 | nn.init.constant_(m.weight, 1.0)
119 |
120 | def get_num_layers(self):
121 | return self.depth
122 |
123 | @torch.jit.ignore
124 | def no_weight_decay(self):
125 | return {'model.pos_embed'}
126 |
127 | def forward(self, image, **kwargs):
128 | x = self.model(image)
129 | t = x[:, :, :].contiguous()
130 |         gathered_t = utils.gather_tcga_features(t)
131 |         cls_x = self.fc_norm(gathered_t.mean(1))
132 | return self.head(cls_x)
133 |
134 |
135 | @register_model
136 | def longvit_small_patch32_1024_tcga_subtyping(pretrained=False, task=None, **kwargs):
137 | args = _get_small_config(img_size=1024, patch_size=32, **kwargs)
138 | if task == "tcga_kidney_subtyping":
139 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs)
140 | else:
141 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs)
142 | return model
143 |
144 |
145 | @register_model
146 | def longvit_small_patch32_4096_tcga_subtyping(pretrained=False, task=None, **kwargs):
147 | args = _get_small_config(img_size=4096, patch_size=32, **kwargs)
148 | if task == "tcga_kidney_subtyping":
149 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs)
150 | else:
151 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs)
152 | return model
153 |
154 |
155 | @register_model
156 | def longvit_small_patch32_8192_tcga_subtyping(pretrained=False, task=None, **kwargs):
157 | args = _get_small_config(img_size=8192, patch_size=32, **kwargs)
158 | args.checkpoint_activations = True
159 | if task == "tcga_kidney_subtyping":
160 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs)
161 | else:
162 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs)
163 | return model
164 |
165 |
166 | @register_model
167 | def longvit_small_patch32_16384_tcga_subtyping(pretrained=False, task=None, **kwargs):
168 | args = _get_small_config(img_size=16384, patch_size=32, **kwargs)
169 | args.checkpoint_activations = True
170 | if task == "tcga_kidney_subtyping":
171 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs)
172 | else:
173 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs)
174 | return model
175 |
176 |
177 | @register_model
178 | def longvit_small_patch32_32768_tcga_subtyping(pretrained=False, task=None, seq_parallel=False, **kwargs):
179 | args = _get_small_config(img_size=32768, patch_size=32, **kwargs)
180 | args.checkpoint_activations = True
181 | if task == "tcga_kidney_subtyping":
182 | model = LongViTForTCGAClassificationSeqParallel(args, num_classes=3, seq_parallel=seq_parallel, **kwargs)
183 | else:
184 | model = LongViTForTCGAClassificationSeqParallel(args, num_classes=2, seq_parallel=seq_parallel, **kwargs)
185 | return model
186 |
187 |
188 | @register_model
189 | def longvit_small_patch32_1024_tcga_survival(pretrained=False, task=None, **kwargs):
190 | args = _get_small_config(img_size=1024, patch_size=32, **kwargs)
191 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs)
192 | return model
193 |
194 |
195 | @register_model
196 | def longvit_small_patch32_4096_tcga_survival(pretrained=False, task=None, **kwargs):
197 | args = _get_small_config(img_size=4096, patch_size=32, **kwargs)
198 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs)
199 | return model
200 |
201 |
202 | @register_model
203 | def longvit_small_patch32_8192_tcga_survival(pretrained=False, task=None, **kwargs):
204 | args = _get_small_config(img_size=8192, patch_size=32, **kwargs)
205 | args.checkpoint_activations = True
206 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs)
207 | return model
208 |
209 |
210 | @register_model
211 | def longvit_small_patch32_16384_tcga_survival(pretrained=False, task=None, **kwargs):
212 | args = _get_small_config(img_size=16384, patch_size=32, **kwargs)
213 | args.checkpoint_activations = True
214 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs)
215 | return model
216 |
217 |
218 | @register_model
219 | def longvit_small_patch32_32768_tcga_survival(pretrained=False, task=None, seq_parallel=False, **kwargs):
220 | args = _get_small_config(img_size=32768, patch_size=32, **kwargs)
221 | args.checkpoint_activations = True
222 | model = LongViTForTCGAClassificationSeqParallel(args, num_classes=4, seq_parallel=seq_parallel, **kwargs)
223 | return model
224 |
--------------------------------------------------------------------------------
/examples/longvit/optim_factory.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4 | # Copyright (c) 2023 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # --------------------------------------------------------'
7 |
8 | from torch import optim as optim
9 | from timm.optim.lookahead import Lookahead
10 |
11 | import json
12 |
13 |
14 | def get_num_layer_for_vit(var_name, num_max_layer):
15 | if "embed" in var_name:
16 | return 0
17 | elif var_name in (
18 | "cls_token", "mask_token", "pos_embed", "model.pos_embed", "language_pos_embed",
19 | "word_embeddings.weight", "vision_cls_token", "vision_pos_embed"
20 | ):
21 | return 0
22 | elif var_name.startswith("patch_embed"):
23 | return 0
24 | elif var_name.startswith("rel_pos_bias"):
25 | return num_max_layer - 1
26 | elif "layers." in var_name:
27 | layer_id = int(var_name.split('layers.')[1].split('.')[0])
28 | return layer_id + 1
29 | else:
30 | return num_max_layer - 1
31 |
32 |
33 | def get_is_head_flag_for_vit(var_name, num_max_layer):
34 | if var_name.startswith("head"):
35 | return 1
36 | # elif var_name.startswith("pooler"):
37 | # return 1
38 | else:
39 | return 0
40 |
41 |
42 | class LayerDecayValueAssigner(object):
43 | def __init__(self, values, scale_handler=None):
44 | self.scale_handler = scale_handler or get_num_layer_for_vit
45 | self.values = values
46 |
47 | def get_scale(self, layer_id):
48 | return self.values[layer_id]
49 |
50 | def get_layer_id(self, var_name):
51 | return self.scale_handler(var_name, len(self.values))
52 |
53 |
54 | # The implementation code is modified from Timm (https://github.com/huggingface/pytorch-image-models/tree/main/timm)
55 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
56 | parameter_group_names = {}
57 | parameter_group_vars = {}
58 |
59 | for name, param in model.named_parameters():
60 | if not param.requires_grad:
61 | continue # frozen weights
62 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
63 | group_name = "no_decay"
64 | this_weight_decay = 0.
65 | else:
66 | group_name = "decay"
67 | this_weight_decay = weight_decay
68 | if get_num_layer is not None:
69 | layer_id = get_num_layer(name)
70 | group_name = "layer_%d_%s" % (layer_id, group_name)
71 | else:
72 | layer_id = None
73 |
74 | if group_name not in parameter_group_names:
75 | if get_layer_scale is not None:
76 | scale = get_layer_scale(layer_id)
77 | else:
78 | scale = 1.
79 |
80 | parameter_group_names[group_name] = {
81 | "weight_decay": this_weight_decay,
82 | "params": [],
83 | "lr_scale": scale
84 | }
85 | parameter_group_vars[group_name] = {
86 | "weight_decay": this_weight_decay,
87 | "params": [],
88 | "lr_scale": scale
89 | }
90 |
91 | parameter_group_vars[group_name]["params"].append(param)
92 | parameter_group_names[group_name]["params"].append(name)
93 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
94 | return list(parameter_group_vars.values())
95 |
96 |
97 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
98 | opt_lower = args.opt.lower()
99 | weight_decay = args.weight_decay
100 | if weight_decay and filter_bias_and_bn:
101 | skip = {}
102 | if skip_list is not None:
103 | skip = skip_list
104 | elif hasattr(model, 'no_weight_decay'):
105 | skip = model.no_weight_decay()
106 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
107 | weight_decay = 0.
108 | else:
109 | parameters = model.parameters()
110 |
111 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
112 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
113 | opt_args['eps'] = args.opt_eps
114 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
115 | opt_args['betas'] = args.opt_betas
116 |
117 | opt_split = opt_lower.split('_')
118 | opt_lower = opt_split[-1]
119 | if opt_lower == 'adamw':
120 | optimizer = optim.AdamW(parameters, **opt_args)
121 | else:
122 | raise ValueError("Invalid optimizer")
123 |
124 | if len(opt_split) > 1:
125 | if opt_split[0] == 'lookahead':
126 | optimizer = Lookahead(optimizer)
127 |
128 | return optimizer
129 |
--------------------------------------------------------------------------------
/examples/longvit/pretraining/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | from utils import trunc_normal_
25 | from torchscale.architecture.encoder import Encoder
26 | from torchscale.model.LongNet import LongNetEncoder
27 | from torchscale.architecture.config import EncoderConfig
28 |
29 |
30 | def drop_path(x, drop_prob: float = 0., training: bool = False):
31 | if drop_prob == 0. or not training:
32 | return x
33 | keep_prob = 1 - drop_prob
34 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
35 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
36 | random_tensor.floor_() # binarize
37 | output = x.div(keep_prob) * random_tensor
38 | return output
39 |
40 |
41 | class DropPath(nn.Module):
42 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
43 | """
44 | def __init__(self, drop_prob=None):
45 | super(DropPath, self).__init__()
46 | self.drop_prob = drop_prob
47 |
48 | def forward(self, x):
49 | return drop_path(x, self.drop_prob, self.training)
50 |
51 |
52 | class Mlp(nn.Module):
53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
54 | super().__init__()
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | self.fc1 = nn.Linear(in_features, hidden_features)
58 | self.act = act_layer()
59 | self.fc2 = nn.Linear(hidden_features, out_features)
60 | self.drop = nn.Dropout(drop)
61 |
62 | def forward(self, x):
63 | x = self.fc1(x)
64 | x = self.act(x)
65 | x = self.drop(x)
66 | x = self.fc2(x)
67 | x = self.drop(x)
68 | return x
69 |
70 |
71 | class Attention(nn.Module):
72 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
73 | super().__init__()
74 | self.num_heads = num_heads
75 | head_dim = dim // num_heads
76 | self.scale = qk_scale or head_dim ** -0.5
77 |
78 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
79 | self.attn_drop = nn.Dropout(attn_drop)
80 | self.proj = nn.Linear(dim, dim)
81 | self.proj_drop = nn.Dropout(proj_drop)
82 |
83 | def forward(self, x):
84 | B, N, C = x.shape
85 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
86 | q, k, v = qkv[0], qkv[1], qkv[2]
87 |
88 | attn = (q @ k.transpose(-2, -1)) * self.scale
89 | attn = attn.softmax(dim=-1)
90 | attn = self.attn_drop(attn)
91 |
92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
93 | x = self.proj(x)
94 | x = self.proj_drop(x)
95 | return x, attn
96 |
97 |
98 | class Block(nn.Module):
99 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
100 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
101 | super().__init__()
102 | self.norm1 = norm_layer(dim)
103 | self.attn = Attention(
104 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
105 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
106 | self.norm2 = norm_layer(dim)
107 | mlp_hidden_dim = int(dim * mlp_ratio)
108 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
109 |
110 | def forward(self, x, return_attention=False):
111 | y, attn = self.attn(self.norm1(x))
112 | if return_attention:
113 | return attn
114 | x = x + self.drop_path(y)
115 | x = x + self.drop_path(self.mlp(self.norm2(x)))
116 | return x
117 |
118 |
119 | class PatchEmbed(nn.Module):
120 | """ Image to Patch Embedding
121 | """
122 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
123 | super().__init__()
124 | num_patches = (img_size // patch_size) * (img_size // patch_size)
125 | self.img_size = img_size
126 | self.patch_size = patch_size
127 | self.num_patches = num_patches
128 |
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | x = self.proj(x).flatten(2).transpose(1, 2)
134 | return x
135 |
136 |
137 | class VisionTransformer(nn.Module):
138 | """ Vision Transformer """
139 | def __init__(self, img_size=1024, patch_size=32, in_chans=3, num_classes=0, embed_dim=768, depth=12,
140 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
141 | drop_path_rate=0., norm_layer=nn.LayerNorm, flash_attention=True, dilated_ratio="[1,2,4,8,16]", segment_length="[64,128,256,512,1024]", checkpoint_activations=False, **kwargs):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim
144 |
145 | self.patch_embed = PatchEmbed(
146 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
147 | num_patches = self.patch_embed.num_patches
148 |
149 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
150 | self.pos_drop = nn.Dropout(p=drop_rate)
151 |
152 | encoder_config = EncoderConfig(
153 | img_size=img_size, patch_size=patch_size, vocab_size=64010, multiway=False,
154 | layernorm_embedding=False, normalize_output=False, no_output_layer=True,
155 | drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim, encoder_attention_heads=num_heads,
156 | encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=depth,
157 | checkpoint_activations=checkpoint_activations, flash_attention=flash_attention,
158 | dilated_ratio=dilated_ratio, segment_length=segment_length, seq_parallel=False,
159 | )
160 | if flash_attention:
161 |             print("Using Torchscale LongNetEncoder")
162 | self.encoder = LongNetEncoder(encoder_config, embed_tokens=None, embed_positions=None,
163 | output_projection=None, is_encoder_decoder=False,)
164 | else:
165 | print("Using Torchscale Encoder")
166 | self.encoder = Encoder(encoder_config, embed_tokens=None, embed_positions=None,
167 | output_projection=None, is_encoder_decoder=False,)
168 |
169 | self.norm = norm_layer(embed_dim)
170 |
171 | # Classifier head
172 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
173 |
174 | trunc_normal_(self.pos_embed, std=.02)
175 | self.apply(self._init_weights)
176 |
177 | def _init_weights(self, m):
178 | if isinstance(m, nn.Linear):
179 | trunc_normal_(m.weight, std=.02)
180 | if isinstance(m, nn.Linear) and m.bias is not None:
181 | nn.init.constant_(m.bias, 0)
182 | elif isinstance(m, nn.LayerNorm):
183 | nn.init.constant_(m.bias, 0)
184 | nn.init.constant_(m.weight, 1.0)
185 |
186 | def interpolate_pos_encoding(self, x, w, h):
187 | npatch = x.shape[1]
188 | N = self.pos_embed.shape[1]
189 | if npatch == N and w == h:
190 | return self.pos_embed
191 | patch_pos_embed = self.pos_embed
192 | dim = x.shape[-1]
193 | w0 = w // self.patch_embed.patch_size
194 | h0 = h // self.patch_embed.patch_size
195 | # we add a small number to avoid floating point error in the interpolation
196 | # see discussion at https://github.com/facebookresearch/dino/issues/8
197 | w0, h0 = w0 + 0.1, h0 + 0.1
198 | patch_pos_embed = nn.functional.interpolate(
199 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
200 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
201 | mode='bicubic',
202 | )
203 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
204 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
205 | return patch_pos_embed
206 |
207 | def prepare_tokens(self, x):
208 | B, nc, w, h = x.shape
209 | x = self.patch_embed(x) # patch linear embedding
210 |
211 | # add positional encoding to each token
212 | x = x + self.interpolate_pos_encoding(x, w, h)
213 |
214 | return self.pos_drop(x)
215 |
216 | def forward(self, x):
217 | x = self.prepare_tokens(x)
218 | x = self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"]
219 | x = self.norm(x)
220 | t = x[:, :, :]
221 | cls_x = t.mean(1)
222 | return cls_x
223 |
224 |
225 | def vit_small(patch_size=32, **kwargs):
226 | model = VisionTransformer(
227 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=16, mlp_ratio=4,
228 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
229 | return model
230 |
231 |
232 | class DINOHead(nn.Module):
233 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
234 | super().__init__()
235 | nlayers = max(nlayers, 1)
236 | if nlayers == 1:
237 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
238 | else:
239 | layers = [nn.Linear(in_dim, hidden_dim)]
240 | if use_bn:
241 | layers.append(nn.BatchNorm1d(hidden_dim))
242 | layers.append(nn.GELU())
243 | for _ in range(nlayers - 2):
244 | layers.append(nn.Linear(hidden_dim, hidden_dim))
245 | if use_bn:
246 | layers.append(nn.BatchNorm1d(hidden_dim))
247 | layers.append(nn.GELU())
248 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
249 | self.mlp = nn.Sequential(*layers)
250 | self.apply(self._init_weights)
251 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
252 | self.last_layer.weight_g.data.fill_(1)
253 | if norm_last_layer:
254 | self.last_layer.weight_g.requires_grad = False
255 |
256 | def _init_weights(self, m):
257 | if isinstance(m, nn.Linear):
258 | trunc_normal_(m.weight, std=.02)
259 | if isinstance(m, nn.Linear) and m.bias is not None:
260 | nn.init.constant_(m.bias, 0)
261 |
262 | def forward(self, x):
263 | x = self.mlp(x)
264 | x = nn.functional.normalize(x, dim=-1, p=2)
265 | x = self.last_layer(x)
266 | return x
267 |
--------------------------------------------------------------------------------
/examples/longvit/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | timm==0.6.13
3 | Pillow==10.0.0
4 | blobfile==2.0.2
5 | mypy==1.4.1
6 | numpy==1.22.4
7 | pytest==7.2.2
8 | requests==2.31.0
9 | einops==0.6.1
10 | tensorboardX==1.8
11 | scipy==1.6.3
12 | ftfy==6.1.1
13 | opencv-python==4.8.0.74
14 | pyarrow==9.0.0
15 | transformers==4.8.1
16 | deepspeed==0.4.0
17 | scikit-survival==0.22.1
18 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | from io import open
5 |
6 | from setuptools import find_packages, setup
7 |
8 | setup(
9 | name="torchscale",
10 | version="0.2.0",
11 | author="TorchScale Team",
12 | author_email="Shuming.Ma@microsoft.com",
13 | description="Transformers at any scale",
14 | long_description=open("README.md", "r", encoding="utf-8").read(),
15 | long_description_content_type="text/markdown",
16 | keywords="Transformers at any scale",
17 | license="MIT",
18 | url="https://github.com/microsoft/torchscale",
19 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
20 | install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13", "einops"],
21 | python_requires=">=3.8.0",
22 | classifiers=[
23 | "Programming Language :: Python :: 3",
24 | ],
25 | )
26 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/tests/test_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import pytest
5 | import torch
6 |
7 | from torchscale.architecture.config import DecoderConfig
8 | from torchscale.architecture.decoder import Decoder
9 |
10 | testcases = [
11 | {},
12 | {"vocab_size": 64000},
13 | {"activation_fn": "relu"},
14 | {"drop_path_rate": 0.1},
15 | {"decoder_normalize_before": False},
16 | {"no_scale_embedding": False},
17 | {"layernorm_embedding": True},
18 | {"rel_pos_buckets": 32, "max_rel_pos": 256},
19 | {"deepnorm": True, "subln": False, "decoder_normalize_before": False},
20 | {"bert_init": True},
21 | {"multiway": True},
22 | {"share_decoder_input_output_embed": True},
23 | {"checkpoint_activations": True},
24 | {"fsdp": True},
25 | ]
26 |
27 |
28 | @pytest.mark.parametrize("args", testcases)
29 | def test_decoder(args):
30 | config = DecoderConfig(**args)
31 | model = Decoder(config)
32 | prev_output_tokens = torch.ones(2, 10)
33 | token_embeddings = torch.rand(2, 10, config.decoder_embed_dim)
34 | model(
35 | prev_output_tokens=prev_output_tokens,
36 | token_embeddings=token_embeddings,
37 | features_only=True,
38 | )
39 |
--------------------------------------------------------------------------------
/tests/test_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import pytest
5 | import torch
6 |
7 | from torchscale.architecture.config import EncoderConfig
8 | from torchscale.architecture.encoder import Encoder
9 |
10 | testcases = [
11 | {},
12 | {"vocab_size": 64000},
13 | {"activation_fn": "relu"},
14 | {"drop_path_rate": 0.1},
15 | {"encoder_normalize_before": False},
16 | {"no_scale_embedding": False},
17 | {"layernorm_embedding": True},
18 | {"rel_pos_buckets": 32, "max_rel_pos": 256},
19 | {"deepnorm": True, "subln": False, "encoder_normalize_before": False},
20 | {"bert_init": True},
21 | {"multiway": True},
22 | {"share_encoder_input_output_embed": True},
23 | {"checkpoint_activations": True},
24 | {"fsdp": True},
25 | ]
26 |
27 |
28 | @pytest.mark.parametrize("args", testcases)
29 | def test_encoder(args):
30 | config = EncoderConfig(**args)
31 | model = Encoder(config)
32 | token_embeddings = torch.rand(2, 10, config.encoder_embed_dim)
33 | model(src_tokens=None, token_embeddings=token_embeddings)
34 |
--------------------------------------------------------------------------------
/tests/test_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import pytest
5 | import torch
6 |
7 | from torchscale.architecture.config import EncoderDecoderConfig
8 | from torchscale.architecture.encoder_decoder import EncoderDecoder
9 | from torchscale.component.embedding import PositionalEmbedding, TextEmbedding
10 |
11 | testcases = [
12 | {},
13 | {"vocab_size": 64000},
14 | {"activation_fn": "relu"},
15 | {"drop_path_rate": 0.1},
16 | {"encoder_normalize_before": False, "decoder_normalize_before": False},
17 | {"no_scale_embedding": False},
18 | {"layernorm_embedding": True},
19 | {"rel_pos_buckets": 32, "max_rel_pos": 256},
20 | {
21 | "deepnorm": True,
22 | "subln": False,
23 | "encoder_normalize_before": False,
24 | "decoder_normalize_before": False,
25 | },
26 | {"bert_init": True},
27 | {"multiway": True},
28 | {"share_decoder_input_output_embed": True},
29 | {"share_all_embeddings": True},
30 | {"checkpoint_activations": True},
31 | {"fsdp": True},
32 | ]
33 |
34 |
35 | @pytest.mark.parametrize("args", testcases)
36 | def test_encoder_decoder(args):
37 | config = EncoderDecoderConfig(**args)
38 | model = EncoderDecoder(
39 | config,
40 | encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
41 | decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
42 | encoder_embed_positions=PositionalEmbedding(
43 | config.max_source_positions, config.encoder_embed_dim
44 | ),
45 | decoder_embed_positions=PositionalEmbedding(
46 | config.max_target_positions, config.decoder_embed_dim
47 | ),
48 | )
49 |
50 | src_tokens = torch.ones(2, 20).long()
51 | prev_output_tokens = torch.ones(2, 10).long()
52 |
53 | model(
54 | src_tokens=src_tokens,
55 | prev_output_tokens=prev_output_tokens,
56 | features_only=True,
57 | )
58 |
--------------------------------------------------------------------------------
/torchscale/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/architecture/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/architecture/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch.nn as nn
5 |
6 | from torchscale.architecture.decoder import Decoder
7 | from torchscale.architecture.encoder import Encoder
8 |
9 |
10 | class EncoderDecoder(nn.Module):
11 | def __init__(
12 | self,
13 | args,
14 | encoder_embed_tokens=None,
15 | encoder_embed_positions=None,
16 | decoder_embed_tokens=None,
17 | decoder_embed_positions=None,
18 | output_projection=None,
19 | **kwargs
20 | ):
21 | super().__init__()
22 | self.args = args
23 | if args.share_all_embeddings:
24 | args.share_decoder_input_output_embed = True
25 |
26 | self.encoder = Encoder(
27 | args,
28 | encoder_embed_tokens,
29 | encoder_embed_positions,
30 | is_encoder_decoder=True,
31 | **kwargs
32 | )
33 |
34 | if args.share_all_embeddings and decoder_embed_tokens is None:
35 | decoder_embed_tokens = self.encoder.embed_tokens
36 |
37 | self.decoder = Decoder(
38 | args,
39 | decoder_embed_tokens,
40 | decoder_embed_positions,
41 | output_projection,
42 | is_encoder_decoder=True,
43 | **kwargs
44 | )
45 |
46 | def forward(
47 | self,
48 | src_tokens,
49 | prev_output_tokens,
50 | return_all_hiddens=False,
51 | features_only=False,
52 | **kwargs
53 | ):
54 | encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
55 | decoder_out = self.decoder(
56 | prev_output_tokens,
57 | encoder_out=encoder_out,
58 | features_only=features_only,
59 | return_all_hiddens=return_all_hiddens,
60 | )
61 | return decoder_out
62 |
--------------------------------------------------------------------------------
/torchscale/architecture/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch.nn as nn
5 |
6 | from torchscale.component.multihead_attention import MultiheadAttention
7 | from torchscale.component.multiway_network import MultiwayNetwork
8 |
9 |
10 | def init_bert_params(module):
11 | def normal_(data):
12 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
13 |
14 | if isinstance(module, nn.Linear):
15 | normal_(module.weight.data)
16 | if module.bias is not None:
17 | module.bias.data.zero_()
18 | if isinstance(module, nn.Embedding):
19 | normal_(module.weight.data)
20 | if module.padding_idx is not None:
21 | module.weight.data[module.padding_idx].zero_()
22 | if isinstance(module, MultiheadAttention):
23 | if isinstance(module.q_proj, MultiwayNetwork):
24 | normal_(module.q_proj.A.weight.data)
25 | normal_(module.q_proj.B.weight.data)
26 | normal_(module.k_proj.A.weight.data)
27 | normal_(module.k_proj.B.weight.data)
28 | normal_(module.v_proj.A.weight.data)
29 | normal_(module.v_proj.B.weight.data)
30 | else:
31 | normal_(module.q_proj.weight.data)
32 | normal_(module.k_proj.weight.data)
33 | normal_(module.v_proj.weight.data)
34 |
--------------------------------------------------------------------------------
/torchscale/component/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/component/dilated_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from einops import rearrange
9 |
10 | from .multihead_attention import MultiheadAttention
11 | from .utils import padding_to_multiple_of, all_gather_func, get_data_parallel_rank, get_data_parallel_world_size
12 |
13 |
14 | class DilatedAttention(MultiheadAttention):
15 |
16 | def dense_to_sparse(self, x, ratio):
17 | length = x.size(1)
18 | padding = padding_to_multiple_of(length, ratio)
19 | head_padding = padding_to_multiple_of(self.num_heads, ratio)
20 |
21 | if padding > 0 or head_padding > 0:
22 | x = F.pad(x, (0, 0, 0, head_padding, 0, padding), value = 0.)
23 |
24 | x = rearrange(x, 'b (l r1) (r2 h) d -> b l h d r1 r2', r1=ratio, r2=ratio)
25 | x = torch.diagonal(x, offset=0, dim1=4, dim2=5)
26 | x = rearrange(x, 'b l h d r -> b l (r h) d')
27 |
28 | if head_padding > 0:
29 | x = x[:, :, :self.num_heads]
30 |
31 | return x
32 |
33 | def sparse_to_dense(self, out, lse, ratio):
34 | head_padding = padding_to_multiple_of(self.num_heads, ratio)
35 |
36 | if head_padding > 0:
37 | out = F.pad(out, (0, 0, 0, head_padding), value = 0.)
38 | lse = F.pad(lse, (0, 0, 0, head_padding), value = -1e8)
39 |
40 | out = rearrange(out, 'b l (r h) d -> b l h d r', r=ratio)
41 | out = torch.diag_embed(out, offset=0, dim1=4, dim2=5)
42 | out = rearrange(out, 'b l h d r1 r2 -> b (r2 h) (l r1) d', r1=ratio, r2=ratio)
43 |
44 | lse = rearrange(lse, 'b (r h) l -> b l h r', r=ratio)
45 | lse = torch.diag_embed(lse, offset=0, dim1=3, dim2=4)
46 | lse = lse.masked_fill_(lse==0, -1e8)
47 | lse = rearrange(lse, 'b l h r1 r2 -> b (r2 h) (l r1) 1', r1=ratio, r2=ratio)
48 |
49 | if head_padding > 0:
50 | out = out[:, :self.num_heads]
51 | lse = lse[:, :self.num_heads]
52 |
53 | return out, lse
54 |
55 | def gather_kv(self, x, sl, seq_len, is_causal=True):
56 | bsz = x.size(0)
57 | assert sl % seq_len == 0
58 | num_rank_per_segment = sl // seq_len
59 |
60 | x = all_gather_func(x)
61 | current_rank = get_data_parallel_rank()
62 | x = rearrange(x, '(w b) l h d -> w b l h d', b=bsz)
63 |
64 | if is_causal:
65 | if current_rank > 0:
66 | x = x[:current_rank]
67 | else:
68 | x = x[:1] * 0
69 |
70 | current_segment = current_rank // num_rank_per_segment * num_rank_per_segment
71 | x = x[current_segment:current_segment+num_rank_per_segment]
72 |
73 | x = rearrange(x, 'w b l h d -> b (w l) h d')
74 | return x
75 |
76 | def gathering(self, x, dr, sl, is_causal=True, offset=0, is_kv=False, seq_parall=True):
77 |
78 | curr_x = x
79 | if offset > 0:
80 | curr_x = F.pad(curr_x, (0, 0, 0, 0, offset % sl, 0), value=0.)
81 | seq_len = curr_x.size(1)
82 | should_gather_kv = is_kv and (get_data_parallel_world_size() > 1) and (sl > seq_len) and seq_parall
83 | _sl = sl
84 | sl = min(sl, seq_len)
85 | padding = padding_to_multiple_of(seq_len, sl)
86 |
87 | if padding > 0:
88 | curr_x = F.pad(curr_x, (0, 0, 0, 0, 0, padding), value = 0.)
89 |
90 | curr_x = rearrange(curr_x, 'b (n g) h d -> (b n) g h d', g=sl)
91 | curr_x = self.dense_to_sparse(curr_x, dr)
92 |
93 | if should_gather_kv:
94 | curr_x = self.gather_kv(curr_x, _sl, seq_len, is_causal)
95 |
96 | curr_x = rearrange(curr_x, 'b l h d -> (b h) l d')
97 |
98 | return curr_x
99 |
100 | def scattering(self, outs, lses, seq_len, bsz, offset=0):
101 | assert len(outs) == len(lses)
102 | assert len(outs) % len(self.args.dilated_ratio) == 0
103 | all_outs, all_lses = [], []
104 | drs = self.args.dilated_ratio
105 | if len(outs) > len(drs):
106 | drs = drs * (len(outs) // len(drs))
107 |
108 | for dr, o, lse in zip(drs, outs, lses):
109 | o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads)
110 | o, lse = self.sparse_to_dense(o, lse, dr)
111 | o = rearrange(o, '(b n) h g d -> (b h) (n g) d', b=bsz)
112 | lse = rearrange(lse, '(b n) h g 1 -> (b h) (n g) 1', b=bsz)
113 | o = o[:, offset:offset+seq_len]
114 | lse = lse[:, offset:offset+seq_len]
115 |
116 | all_outs.append(o)
117 | all_lses.append(lse)
118 |
119 | with torch.no_grad():
120 | max_lse = torch.stack(all_lses, dim=0)
121 | max_lse = max_lse.max(0)[0]
122 | all_lses = [torch.exp(lse-max_lse) for lse in all_lses]
123 | lse_sum = torch.stack(all_lses, dim=0).sum(0)
124 | all_lses = [lse / lse_sum for lse in all_lses]
125 |
126 | out = 0
127 | for o, lse in zip(all_outs, all_lses):
128 | out += o * lse.type_as(o)
129 | out = rearrange(out, '(b h) l d -> b l (h d)', h=self.num_heads)
130 |
131 | return out
132 |
133 | def forward(
134 | self,
135 | query,
136 | key,
137 | value,
138 | incremental_state=None,
139 | key_padding_mask=None,
140 | attn_mask=None,
141 | rel_pos=None,
142 | is_first_step=False,
143 | is_causal=False,
144 | ):
145 | assert self.args.flash_attention
146 | assert rel_pos is None
147 | bsz, tgt_len, embed_dim = query.size()
148 | src_len = tgt_len
149 | assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
150 |
151 | key_bsz, src_len, _ = key.size()
152 | assert key_bsz == bsz, f"{query.size(), key.size()}"
153 | assert value is not None
154 |         assert (bsz, src_len) == value.shape[:2]
155 |
156 | q = self.q_proj(query)
157 | k = self.k_proj(key)
158 | v = self.v_proj(value)
159 |
160 | q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads)
161 | k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads)
162 | v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads)
163 |
164 | if incremental_state is not None and not is_first_step:
165 | offset = src_len - 1
166 | else:
167 | offset = 0
168 |
169 | if incremental_state is not None:
170 | if "prev_key" in incremental_state:
171 | prev_key = incremental_state["prev_key"].view(
172 | bsz * self.num_heads, -1, self.head_dim
173 | )
174 | prev_value = incremental_state["prev_value"].view(
175 | bsz * self.num_heads, -1, self.head_dim
176 | )
177 | k = torch.cat([prev_key, k], dim=1)
178 | v = torch.cat([prev_value, v], dim=1)
179 | incremental_state["prev_key"] = k.view(
180 | bsz, self.num_heads, -1, self.head_dim
181 | )
182 | incremental_state["prev_value"] = v.view(
183 | bsz, self.num_heads, -1, self.head_dim
184 | )
185 | src_len = k.size(1)
186 |
187 | if self.xpos is not None:
188 | if incremental_state is not None and not is_first_step:
189 | offset = src_len - 1
190 | else:
191 | offset = 0
192 | k = self.xpos(k, offset=0, downscale=True)
193 | q = self.xpos(q, offset=offset, downscale=False)
194 |
195 | q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
196 | k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads)
197 | v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads)
198 |
199 | outs, lses = [], []
200 | for sl, dr in zip(self.args.segment_length, self.args.dilated_ratio):
201 | ki = self.gathering(k, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
202 | vi = self.gathering(v, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
203 | qi = self.gathering(q, dr, sl, is_causal=is_causal, offset=offset, is_kv=False, seq_parall=self.args.seq_parallel)
204 |
205 | out, lse = self.attention_ops(qi, ki, vi, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
206 |
207 | outs.append(out)
208 | lses.append(lse)
209 |
210 | attn = self.scattering(outs, lses, tgt_len, bsz, offset=offset)
211 |
212 | if self.inner_attn_ln is not None:
213 | attn = self.inner_attn_ln(attn)
214 |
215 | attn = self.out_proj(attn)
216 |
217 | return attn, None
218 |
--------------------------------------------------------------------------------
/torchscale/component/droppath.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch.nn as nn
5 | from timm.models.layers import drop_path
6 |
7 |
8 | class DropPath(nn.Module):
9 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
10 |
11 | def __init__(self, drop_prob=None):
12 | super(DropPath, self).__init__()
13 | self.drop_prob = drop_prob
14 |
15 | def forward(self, x):
16 | return drop_path(x, self.drop_prob, self.training)
17 |
18 | def extra_repr(self):
19 | return "p={}".format(self.drop_prob)
20 |
--------------------------------------------------------------------------------
/torchscale/component/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class VisionLanguageEmbedding(nn.Module):
10 | def __init__(self, text_embed, vision_embed):
11 | super().__init__()
12 | self.text_embed = text_embed
13 | self.vision_embed = vision_embed
14 |
15 | def forward(self, textual_tokens, visual_tokens, **kwargs):
16 | if textual_tokens is None:
17 | return self.vision_embed(visual_tokens)
18 |
19 | if visual_tokens is None:
20 | return self.text_embed(textual_tokens)
21 |
22 | x1 = self.vision_embed(visual_tokens)
23 | x2 = self.text_embed(textual_tokens)
24 |
25 | return torch.cat([x1, x2], dim=1)
26 |
27 |
28 | class VisionEmbedding(nn.Module):
29 | """Image to Patch Embedding"""
30 |
31 | def __init__(
32 | self,
33 | img_size=224,
34 | patch_size=16,
35 | in_chans=3,
36 | embed_dim=768,
37 | contain_mask_token=False,
38 | prepend_cls_token=False,
39 | ):
40 | super().__init__()
41 | img_size = (img_size, img_size)
42 | patch_size = (patch_size, patch_size)
43 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
44 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
45 | self.img_size = img_size
46 | self.patch_size = patch_size
47 | self.num_patches = num_patches
48 |
49 | self.proj = nn.Conv2d(
50 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
51 | )
52 |
53 | if contain_mask_token:
54 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
55 | else:
56 | self.mask_token = None
57 |
58 | if prepend_cls_token:
59 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
60 | else:
61 | self.cls_token = None
62 |
63 | def num_position_embeddings(self):
64 | if self.cls_token is None:
65 | return self.num_patches
66 | else:
67 | return self.num_patches + 1
68 |
69 | def forward(self, x, masked_position=None, **kwargs):
70 | B, C, H, W = x.shape
71 | assert (
72 | H == self.img_size[0] and W == self.img_size[1]
73 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
74 | x = self.proj(x).flatten(2).transpose(1, 2)
75 |
76 | batch_size, seq_len, _ = x.size()
77 |
78 | if masked_position is not None:
79 | assert self.mask_token is not None
80 | mask_token = self.mask_token.expand(batch_size, seq_len, -1)
81 | w = masked_position.unsqueeze(-1).type_as(mask_token)
82 | x = x * (1 - w) + mask_token * w
83 |
84 | if self.cls_token is not None:
85 | cls_tokens = self.cls_token.expand(
86 | batch_size, -1, -1
87 | ) # stole cls_tokens impl from Phil Wang, thanks
88 | x = torch.cat((cls_tokens, x), dim=1)
89 |
90 | return x
91 |
92 |
93 | class TextEmbedding(nn.Embedding):
94 | def reset_parameters(self):
95 | nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
96 | self._fill_padding_idx_with_zero()
97 |
98 |
99 | class PositionalEmbedding(nn.Embedding):
100 | def forward(
101 | self,
102 | x,
103 | positions=None,
104 | **kwargs,
105 | ):
106 | if positions is None:
107 | # being consistent with Fairseq, which starts from 2.
108 | positions = (
109 | torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
110 | )
111 | return F.embedding(
112 | positions,
113 | self.weight,
114 | self.padding_idx,
115 | self.max_norm,
116 | self.norm_type,
117 | self.scale_grad_by_freq,
118 | self.sparse,
119 | )
120 |
--------------------------------------------------------------------------------
/torchscale/component/feedforward_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | try:
8 | from apex.normalization import FusedLayerNorm as LayerNorm
9 | except ModuleNotFoundError:
10 | from torch.nn import LayerNorm
11 |
12 |
13 | from .xmoe.global_groups import get_moe_group
14 |
15 |
16 | class set_torch_seed(object):
17 | def __init__(self, seed):
18 | assert isinstance(seed, int)
19 | self.rng_state = self.get_rng_state()
20 |
21 | torch.manual_seed(seed)
22 | if torch.cuda.is_available():
23 | torch.cuda.manual_seed(seed)
24 |
25 | def get_rng_state(self):
26 | state = {"torch_rng_state": torch.get_rng_state()}
27 | if torch.cuda.is_available():
28 | state["cuda_rng_state"] = torch.cuda.get_rng_state()
29 | return state
30 |
31 | def set_rng_state(self, state):
32 | torch.set_rng_state(state["torch_rng_state"])
33 | if torch.cuda.is_available():
34 | torch.cuda.set_rng_state(state["cuda_rng_state"])
35 |
36 | def __enter__(self):
37 | return self
38 |
39 | def __exit__(self, *exc):
40 | self.set_rng_state(self.rng_state)
41 |
42 |
43 | def make_experts(args, embed_dim, expert_ffn_dim):
44 | world_size = (
45 | 1
46 | if not torch.distributed.is_initialized()
47 | else torch.distributed.get_world_size()
48 | )
49 | expert_list = []
50 | ddp_rank = args.ddp_rank
51 | start_seed = torch.randint(1000000, (1,)).item()
52 |     # at least as many experts as GPUs
53 | if args.moe_expert_count >= world_size:
54 | assert (
55 | args.moe_expert_count % world_size == 0
56 | ), f"{args.moe_expert_count}, {world_size}"
57 | local_moe_expert_count = args.moe_expert_count // world_size
58 | for i in range(local_moe_expert_count):
59 | with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
60 | expert_list.append(
61 | FeedForwardNetwork(
62 | embed_dim,
63 | expert_ffn_dim,
64 | args.activation_fn,
65 | args.dropout,
66 | args.activation_dropout,
67 | args.layernorm_eps,
68 | args.subln,
69 | )
70 | )
71 | else:
72 | assert (
73 | world_size % args.moe_expert_count == 0
74 | ), f"{world_size}, {args.moe_expert_count}"
75 |
76 | moe_idx, _ = get_moe_group(args.moe_expert_count)
77 |
78 | with set_torch_seed(start_seed + moe_idx):
79 | expert_list.append(
80 | FeedForwardNetwork(
81 | embed_dim,
82 | expert_ffn_dim,
83 | args.activation_fn,
84 | args.dropout,
85 | args.activation_dropout,
86 | args.layernorm_eps,
87 | args.subln,
88 | )
89 | )
90 | experts = nn.ModuleList(expert_list)
91 | return experts
92 |
93 |
94 | def get_activation_fn(activation):
95 | if activation == "relu":
96 | return F.relu
97 | elif activation == "gelu":
98 | return F.gelu
99 | elif activation == "swish":
100 | return F.silu
101 | else:
102 | raise NotImplementedError
103 |
104 |
105 | class FeedForwardNetwork(nn.Module):
106 | def __init__(
107 | self,
108 | embed_dim,
109 | ffn_dim,
110 | activation_fn,
111 | dropout,
112 | activation_dropout,
113 | layernorm_eps,
114 | subln=False,
115 | ):
116 | super().__init__()
117 | self.embed_dim = embed_dim
118 | self.activation_fn = get_activation_fn(activation=str(activation_fn))
119 | self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
120 | self.dropout_module = torch.nn.Dropout(dropout)
121 | self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
122 | self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
123 | self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
124 |
125 | def reset_parameters(self):
126 | self.fc1.reset_parameters()
127 | self.fc2.reset_parameters()
128 | if self.ffn_layernorm is not None:
129 | self.ffn_layernorm.reset_parameters()
130 |
131 | def forward(self, x):
132 | x_shape = x.shape
133 | x = x.reshape(-1, x.size(-1))
134 | x = self.fc1(x)
135 | x = self.activation_fn(x.float()).type_as(x)
136 | x = self.activation_dropout_module(x)
137 | if self.ffn_layernorm is not None:
138 | x = self.ffn_layernorm(x)
139 | x = self.fc2(x)
140 | x = x.view(x_shape)
141 | x = self.dropout_module(x)
142 | return x
143 |
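A short sketch of the plain FeedForwardNetwork above (the MoE path via make_experts additionally needs torch.distributed); subln=True enables the sub-LayerNorm between fc1 and fc2.

```python
import torch
from torchscale.component.feedforward_network import FeedForwardNetwork

ffn = FeedForwardNetwork(
    embed_dim=64,
    ffn_dim=256,
    activation_fn="gelu",
    dropout=0.1,
    activation_dropout=0.0,
    layernorm_eps=1e-5,
    subln=True,             # adds the LayerNorm between fc1 and fc2
)
x = torch.randn(2, 10, 64)  # (batch, seq_len, embed_dim)
print(ffn(x).shape)         # torch.Size([2, 10, 64])
```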
--------------------------------------------------------------------------------
/torchscale/component/flash_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 |
5 | from typing import Any, Optional
6 | import torch
7 |
8 | if torch.cuda.is_available():
9 | try:
10 | if torch.cuda.get_device_capability()[0] > 7:
11 | from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
12 |
13 | def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
14 | assert bias is None
15 | attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
16 | return attn, lse
17 |
18 | else:
19 | from xformers.ops.fmha import (
20 | cutlass,
21 | Inputs,
22 | Context,
23 | _memory_efficient_attention_forward_requires_grad,
24 | _memory_efficient_attention_backward,
25 | LowerTriangularMask,
26 | )
27 |
28 | class FlashAttnFunc(torch.autograd.Function):
29 | @staticmethod
30 | # type: ignore
31 | def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
32 | if is_causal:
33 | assert bias is None
34 | attn_bias = LowerTriangularMask()
35 | else:
36 | attn_bias = bias
37 |
38 | inp = Inputs(
39 | query=q,
40 | key=k,
41 | value=v,
42 | attn_bias=attn_bias,
43 | p=dropout,
44 | scale=softmax_scale,
45 | )
46 | op_fw = cutlass.FwOp
47 | op_bw = cutlass.BwOp
48 |
49 | out, op_ctx = _memory_efficient_attention_forward_requires_grad(
50 | inp=inp, op=op_fw
51 | )
52 |
53 | # Saving attn_bias is a bit complicated, as the
54 | # torch part should go in `save_for_backward`
55 | if isinstance(inp.attn_bias, torch.Tensor):
56 | attn_bias_tensor = inp.attn_bias
57 | attn_bias_ctx = None
58 | else:
59 | attn_bias_tensor = None
60 | attn_bias_ctx = inp.attn_bias
61 |
62 | ctx.save_for_backward(
63 | inp.query,
64 | inp.key,
65 | inp.value,
66 | op_ctx.out,
67 | op_ctx.lse,
68 | )
69 | ctx.rng_state = op_ctx.rng_state
70 | ctx.attn_bias_tensor = attn_bias_tensor
71 | if op_ctx.op_bw is not None:
72 | if op_bw is not None and op_bw is not op_ctx.op_bw:
73 | raise ValueError(
74 | f"Specified op_bw={op_bw.NAME}, but forward op "
75 | f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
76 | )
77 | op_bw = op_ctx.op_bw
78 | ctx.op_fw = op_fw
79 | ctx.op_bw = op_bw
80 | ctx.p = inp.p
81 |
82 | ctx.scale = inp.scale
83 | ctx.attn_bias_ctx = attn_bias_ctx
84 | return out, op_ctx.lse
85 |
86 | @staticmethod
87 | def deserialize_bias(
88 | attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
89 | ) -> Any:
90 | if attn_bias_tensor is None:
91 | return attn_bias_ctx
92 | return attn_bias_tensor
93 |
94 | @classmethod
95 | @torch.autograd.function.once_differentiable
96 | def backward(cls, ctx, grad, dlse):
97 | # Re-create context
98 | query, key, value, out, lse = ctx.saved_tensors
99 | attn_bias_tensor = ctx.attn_bias_tensor
100 | rng_state = ctx.rng_state
101 | inp = Inputs(
102 | query=query,
103 | key=key,
104 | value=value,
105 | attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
106 | p=ctx.p,
107 | scale=ctx.scale,
108 | )
109 | op_ctx = Context(
110 | lse=lse,
111 | out=out,
112 | rng_state=rng_state,
113 | )
114 | grads = _memory_efficient_attention_backward(
115 | ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
116 | )
117 | return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
118 |
119 | flash_attn_func = FlashAttnFunc.apply
120 | except ModuleNotFoundError:
121 | flash_attn_func = None
122 | else:
123 | flash_attn_func = None
124 |
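Hedged usage sketch: `flash_attn_func` stays `None` unless a CUDA device plus either flash-attn (compute capability ≥ 8.0) or xformers is available. Both wrapped backends share the positional signature `(q, k, v, dropout, bias, softmax_scale, is_causal)`, take `(batch, seq_len, heads, head_dim)` inputs, and return `(attn, lse)`.

```python
import torch
from torchscale.component.flash_attention import flash_attn_func

if flash_attn_func is not None:           # requires CUDA + flash-attn or xformers
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn, lse = flash_attn_func(q, k, v, 0.0, None, None, True)  # causal attention
    print(attn.shape)                     # torch.Size([2, 128, 8, 64])
```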
--------------------------------------------------------------------------------
/torchscale/component/gate_linear_unit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .feedforward_network import get_activation_fn
9 |
10 |
11 | class GLU(nn.Module):
12 | def __init__(
13 | self,
14 | embed_dim,
15 | ffn_dim,
16 | activation_fn,
17 | dropout,
18 | activation_dropout,
19 | ):
20 | super().__init__()
21 | self.embed_dim = embed_dim
22 | self.activation_fn = get_activation_fn(activation=str(activation_fn))
23 | self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
24 | self.dropout_module = torch.nn.Dropout(dropout)
25 | self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
26 | self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
27 | self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
28 |
29 | def reset_parameters(self):
30 | self.fc1.reset_parameters()
31 | self.fc2.reset_parameters()
32 | self.gate.reset_parameters()
33 |
34 | def forward(self, x):
35 | x_shape = x.shape
36 | x = x.reshape(-1, x.size(-1))
37 | g = self.gate(x)
38 | x = self.fc1(x)
39 | x = self.activation_fn(x.float()).type_as(x) * g
40 | x = self.activation_dropout_module(x)
41 | x = self.fc2(x)
42 | x = x.view(x_shape)
43 | x = self.dropout_module(x)
44 | return x
45 |
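A minimal sketch of the gated feed-forward above; GLU computes `fc2(act(fc1(x)) * gate(x))` with bias-free projections.

```python
import torch
from torchscale.component.gate_linear_unit import GLU

glu = GLU(embed_dim=64, ffn_dim=256, activation_fn="swish",
          dropout=0.0, activation_dropout=0.0)
x = torch.randn(2, 10, 64)   # (batch, seq_len, embed_dim)
print(glu(x).shape)          # torch.Size([2, 10, 64])
```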
--------------------------------------------------------------------------------
/torchscale/component/multihead_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | from einops import rearrange
10 | try:
11 | from apex.normalization import FusedLayerNorm as LayerNorm
12 | except ModuleNotFoundError:
13 | from torch.nn import LayerNorm
14 |
15 | from .multiway_network import MultiwayWrapper
16 | from .xpos_relative_position import XPOS
17 | from .flash_attention import flash_attn_func
18 |
19 |
20 | class MultiheadAttention(nn.Module):
21 | def __init__(
22 | self,
23 | args,
24 | embed_dim,
25 | num_heads,
26 | dropout=0.0,
27 | self_attention=False,
28 | encoder_decoder_attention=False,
29 | subln=False,
30 | ):
31 | super().__init__()
32 | self.args = args
33 | self.embed_dim = embed_dim
34 | self.num_heads = num_heads
35 | self.head_dim = embed_dim // num_heads
36 | self.scaling = self.head_dim**-0.5
37 | self.dropout = dropout
38 |
39 | self.self_attention = self_attention
40 | self.encoder_decoder_attention = encoder_decoder_attention
41 | assert self.self_attention ^ self.encoder_decoder_attention
42 |
43 | self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
44 | self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
45 | self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
46 | self.out_proj = MultiwayWrapper(
47 | args, nn.Linear(embed_dim, embed_dim, bias=True)
48 | )
49 | self.inner_attn_ln = (
50 | MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
51 | if subln and self.self_attention
52 | else None
53 | )
54 | self.dropout_module = torch.nn.Dropout(dropout)
55 | self.xpos = (
56 | XPOS(self.head_dim, args.xpos_scale_base)
57 | if args.xpos_rel_pos and self.self_attention
58 | else None
59 | )
60 |
61 | def reset_parameters(self):
62 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
63 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
64 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
65 | nn.init.xavier_uniform_(self.out_proj.weight)
66 | nn.init.constant_(self.out_proj.bias, 0.0)
67 |
68 | def attention_ops(self, q, k, v, key_padding_mask=None, attn_mask=None, rel_pos=None, is_causal=False):
69 | if not self.args.flash_attention:
70 | q *= self.scaling
71 | attn_weights = torch.bmm(q, k.transpose(1, 2))
72 |
73 | if attn_mask is not None:
74 | attn_weights = torch.nan_to_num(attn_weights)
75 | attn_mask = attn_mask.unsqueeze(0)
76 | attn_weights += attn_mask
77 |
78 | if key_padding_mask is not None:
79 | attn_weights = rearrange(attn_weights, '(b h) t s -> b h t s', h=self.num_heads)
80 | attn_weights = attn_weights.masked_fill(
81 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
82 | float("-inf"),
83 | )
84 | attn_weights = rearrange(attn_weights, 'b h t s -> (b h) t s')
85 |
86 | if rel_pos is not None:
87 | rel_pos = rel_pos.view(attn_weights.size())
88 | attn_weights = attn_weights + rel_pos
89 |
90 | attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
91 | attn_weights
92 | )
93 | attn_probs = self.dropout_module(attn_weights)
94 |
95 | attn = torch.bmm(attn_probs, v)
96 | attn = rearrange(attn, '(b h) l d -> b l (h d)', h=self.num_heads)
97 | else:
98 | assert flash_attn_func is not None
99 | assert rel_pos is None
100 | q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
101 | k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads)
102 | v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads)
103 | attn, lse = flash_attn_func(q, k, v, self.dropout, attn_mask, None, is_causal)
104 | attn = rearrange(attn, 'b l h d -> b l (h d)')
105 | attn_weights = lse[:, :, :attn.size(1)]
106 |
107 | return attn, attn_weights
108 |
109 | def forward(
110 | self,
111 | query,
112 | key,
113 | value,
114 | incremental_state=None,
115 | key_padding_mask=None,
116 | attn_mask=None,
117 | rel_pos=None,
118 | is_first_step=False,
119 | is_causal=False,
120 | ):
121 | bsz, tgt_len, embed_dim = query.size()
122 | src_len = tgt_len
123 | assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
124 |
125 | key_bsz, src_len, _ = key.size()
126 | assert key_bsz == bsz, f"{query.size(), key.size()}"
127 | assert value is not None
128 |         assert (bsz, src_len) == value.shape[:2]
129 |
130 | q = self.q_proj(query)
131 | k = self.k_proj(key)
132 | v = self.v_proj(value)
133 |
134 | q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads)
135 | k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads)
136 | v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads)
137 |
138 | if incremental_state is not None:
139 | if "prev_key" in incremental_state:
140 | prev_key = incremental_state["prev_key"].view(
141 | bsz * self.num_heads, -1, self.head_dim
142 | )
143 | prev_value = incremental_state["prev_value"].view(
144 | bsz * self.num_heads, -1, self.head_dim
145 | )
146 | k = torch.cat([prev_key, k], dim=1)
147 | v = torch.cat([prev_value, v], dim=1)
148 | incremental_state["prev_key"] = k.view(
149 | bsz, self.num_heads, -1, self.head_dim
150 | )
151 | incremental_state["prev_value"] = v.view(
152 | bsz, self.num_heads, -1, self.head_dim
153 | )
154 | src_len = k.size(1)
155 |
156 | if self.xpos is not None:
157 | if incremental_state is not None and not is_first_step:
158 | offset = src_len - 1
159 | else:
160 | offset = 0
161 | k = self.xpos(k, offset=0, downscale=True)
162 | q = self.xpos(q, offset=offset, downscale=False)
163 |
164 | attn, attn_weights = self.attention_ops(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
165 |
166 | if self.inner_attn_ln is not None:
167 | attn = self.inner_attn_ln(attn)
168 |
169 | attn = self.out_proj(attn)
170 |
171 | return attn, attn_weights
172 |
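A usage sketch with a hand-rolled args namespace; the fields below are only the ones this module reads, and in practice they come with defaults from the configs in torchscale.architecture.config.

```python
from types import SimpleNamespace

import torch
from torchscale.component.multihead_attention import MultiheadAttention

args = SimpleNamespace(
    multiway=False,         # no vision/language branching
    flash_attention=False,  # use the plain bmm/softmax path
    xpos_rel_pos=False,     # no xPos rotation
    xpos_scale_base=512,
    layernorm_eps=1e-5,
)
attn = MultiheadAttention(args, embed_dim=64, num_heads=8,
                          dropout=0.0, self_attention=True, subln=True)
x = torch.randn(2, 10, 64)
out, weights = attn(x, x, x)
print(out.shape, weights.shape)  # torch.Size([2, 10, 64]) torch.Size([16, 10, 10])
```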
--------------------------------------------------------------------------------
/torchscale/component/multiscale_retention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | from .rms_norm import RMSNorm
9 |
10 | from .multiway_network import MultiwayWrapper
11 |
12 | def rotate_every_two(x):
13 | x1 = x[:, :, :, ::2]
14 | x2 = x[:, :, :, 1::2]
15 | x = torch.stack((-x2, x1), dim=-1)
16 |     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
17 |
18 | def duplicate_interleave(m):
19 | """
20 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
21 | """
22 | dim0 = m.shape[0]
23 | m = m.view(-1, 1) # flatten the matrix
24 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
25 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
26 | return m
27 |
28 | def theta_shift(x, sin, cos):
29 | return (x * cos) + (rotate_every_two(x) * sin)
30 |
31 | def get_activation_fn(activation):
32 | if activation == "swish":
33 | return F.silu
34 | elif activation == "gelu":
35 | return F.gelu
36 | else:
37 | raise NotImplementedError
38 |
39 | class MultiScaleRetention(nn.Module):
40 | def __init__(
41 | self,
42 | args,
43 | embed_dim,
44 | value_dim,
45 | num_heads,
46 | gate_fn="swish",
47 | ):
48 | super().__init__()
49 | self.args = args
50 | self.embed_dim = embed_dim
51 | self.value_dim = value_dim
52 | self.num_heads = num_heads
53 | self.head_dim = self.value_dim // num_heads
54 | self.key_dim = self.embed_dim // num_heads
55 | self.scaling = self.key_dim ** -0.5
56 |
57 | self.gate_fn = get_activation_fn(activation=str(gate_fn))
58 |
59 | self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
60 | self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
61 | self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
62 | self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
63 |
64 | self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
65 |
66 | self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
67 | self.reset_parameters()
68 |
69 | def reset_parameters(self):
70 | nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
71 | nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
72 | nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
73 | nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
74 | nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -1)
75 |
76 | def parallel_forward(self, qr, kr, v, mask):
77 | bsz, tgt_len, embed_dim = v.size()
78 |
79 | vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
80 |
81 | qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
82 | qk_mat = qk_mat * mask
83 | # invariant after normalization
84 | qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
85 | output = torch.matmul(qk_mat, vr)
86 | output = output.transpose(1, 2)
87 | return output
88 |
89 | def recurrent_forward(
90 | self,
91 | qr, kr, v,
92 | decay,
93 | incremental_state
94 | ):
95 | bsz = v.size(0)
96 |
97 | v = v.view(bsz, self.num_heads, self.head_dim, 1)
98 | kv = kr * v
99 | if "prev_key_value" in incremental_state:
100 | prev_kv = incremental_state["prev_key_value"]
101 | prev_scale = incremental_state["scale"]
102 | scale = prev_scale * decay + 1
103 | kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
104 | # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
105 | else:
106 | scale = torch.ones_like(decay)
107 |
108 | incremental_state["prev_key_value"] = kv
109 | incremental_state["scale"] = scale
110 |
111 | output = torch.sum(qr * kv, dim=3)
112 | return output
113 |
114 | def chunk_recurrent_forward(
115 | self,
116 | qr, kr, v,
117 | inner_mask
118 | ):
119 | mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
120 | bsz, tgt_len, embed_dim = v.size()
121 | chunk_len = mask.size(1)
122 | num_chunks = tgt_len // chunk_len
123 |
124 | assert tgt_len % chunk_len == 0
125 |
126 | qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
127 | kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
128 | v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)
129 |
130 | kr_t = kr.transpose(-1, -2)
131 |
132 | qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
133 | qk_mat = qk_mat * mask
134 | inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
135 | qk_mat = qk_mat / inner_scale
136 | inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
137 |
138 | # reduce kv in one chunk
139 | kv = kr_t @ (v * value_inner_decay)
140 |
141 | kv_recurrent = []
142 | cross_scale = []
143 | kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
144 | kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
145 |
146 | # accumulate kv by loop
147 | for i in range(num_chunks):
148 | kv_recurrent.append(kv_state / kv_scale)
149 | cross_scale.append(kv_scale)
150 | kv_state = kv_state * cross_decay + kv[:, i]
151 | kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
152 |
153 | kv_recurrent = torch.stack(kv_recurrent, dim=1)
154 | cross_scale = torch.stack(cross_scale, dim=1)
155 |
156 | all_scale = torch.maximum(inner_scale, cross_scale)
157 | align_inner_scale = all_scale / inner_scale
158 | align_cross_scale = all_scale / cross_scale
159 |
160 | cross_output = (qr * query_inner_decay) @ kv_recurrent
161 | output = inner_output / align_inner_scale + cross_output / align_cross_scale
162 | # output = inner_output / cross_scale + cross_output / inner_scale
163 |
164 | output = output.transpose(2, 3)
165 | return output
166 |
167 | def forward(
168 | self,
169 | x,
170 | rel_pos,
171 | chunkwise_recurrent=False,
172 | incremental_state=None
173 | ):
174 | bsz, tgt_len, _ = x.size()
175 | (sin, cos), inner_mask = rel_pos
176 |
177 | q = self.q_proj(x)
178 | k = self.k_proj(x)
179 | v = self.v_proj(x)
180 | g = self.g_proj(x)
181 |
182 | k *= self.scaling
183 | q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
184 | k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
185 |
186 | qr = theta_shift(q, sin, cos)
187 | kr = theta_shift(k, sin, cos)
188 |
189 | if incremental_state is not None:
190 | output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
191 | elif chunkwise_recurrent:
192 | output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
193 | else:
194 | output = self.parallel_forward(qr, kr, v, inner_mask)
195 |
196 | output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
197 |
198 | output = self.gate_fn(g) * output
199 |
200 | output = self.out_proj(output)
201 |
202 | return output
203 |
204 |
205 |
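A small numeric sketch of the rotary helpers defined at the top of this file (the full retention forward additionally needs the decay masks produced by the RetNet relative-position module): `theta_shift` applies `x * cos + rotate_every_two(x) * sin` on interleaved pairs.

```python
import torch
from torchscale.component.multiscale_retention import duplicate_interleave, theta_shift

x = torch.randn(1, 2, 4, 8)                    # (bsz, num_heads, seq_len, key_dim)
angle = torch.rand(4, 4)                       # one angle per (position, key_dim // 2)
sin = duplicate_interleave(torch.sin(angle))   # (4, 8): columns repeated pairwise
cos = duplicate_interleave(torch.cos(angle))
print(theta_shift(x, sin, cos).shape)          # torch.Size([1, 2, 4, 8])
```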
--------------------------------------------------------------------------------
/torchscale/component/multiway_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import copy
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def MultiwayWrapper(args, module, dim=1):
11 | if args.multiway:
12 | return MultiwayNetwork(module, dim=dim)
13 | return module
14 |
15 |
16 | def set_split_position(position):
17 | def apply_fn(module):
18 | if hasattr(module, "split_position"):
19 | module.split_position = position
20 |
21 | return apply_fn
22 |
23 |
24 | class MultiwayNetwork(nn.Module):
25 | def __init__(self, module, dim=1):
26 | super().__init__()
27 | self.dim = dim
28 | self.A = module
29 | self.B = copy.deepcopy(module)
30 | self.B.reset_parameters()
31 | self.split_position = -1
32 |
33 | def forward(self, x, **kwargs):
34 | if self.split_position == -1:
35 | return self.A(x, **kwargs)
36 | if self.split_position == 0:
37 | return self.B(x, **kwargs)
38 | x1, x2 = torch.split(
39 | x,
40 | [self.split_position, x.size(self.dim) - self.split_position],
41 | dim=self.dim,
42 | )
43 | # x1, x2 = x[:self.split_position], x[self.split_position:]
44 | y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
45 | return torch.cat([y1, y2], dim=self.dim)
46 |
47 |
48 | class MutliwayEmbedding(MultiwayNetwork):
49 | def __init__(self, modules, dim=1):
50 | super(MultiwayNetwork, self).__init__()
51 | self.dim = dim
52 | assert len(modules) == 2
53 | self.A = modules[0]
54 | self.B = modules[1]
55 | self.split_position = -1
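A quick sketch: MultiwayNetwork routes the first `split_position` entries along `dim` through branch A and the rest through branch B; `set_split_position` is meant to be broadcast over a model with `Module.apply`.

```python
import torch
import torch.nn as nn
from torchscale.component.multiway_network import MultiwayNetwork, set_split_position

net = MultiwayNetwork(nn.Linear(8, 8), dim=1)
net.apply(set_split_position(3))   # first 3 positions -> A, remaining 7 -> B
x = torch.randn(2, 10, 8)
print(net(x).shape)                # torch.Size([2, 10, 8])
```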
--------------------------------------------------------------------------------
/torchscale/component/relative_position_bias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class RelativePositionBias(nn.Module):
11 | def __init__(
12 | self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
13 | ):
14 | super().__init__()
15 | self.bidirectional = bidirectional
16 | self.num_buckets = num_buckets
17 | self.max_distance = max_distance
18 | self.n_heads = n_heads
19 | self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
20 |
21 | @staticmethod
22 | def _relative_position_bucket(
23 | relative_position, bidirectional=True, num_buckets=32, max_distance=128
24 | ):
25 | ret = 0
26 | n = -relative_position
27 | if bidirectional:
28 | num_buckets //= 2
29 | ret += (n < 0).to(torch.long) * num_buckets
30 | n = torch.abs(n)
31 | else:
32 | n = torch.max(n, torch.zeros_like(n))
33 |
34 | max_exact = num_buckets // 2
35 | is_small = n < max_exact
36 |
37 | val_if_large = max_exact + (
38 | torch.log(n.float() / max_exact)
39 | / math.log(max_distance / max_exact)
40 | * (num_buckets - max_exact)
41 | ).to(torch.long)
42 | val_if_large = torch.min(
43 | val_if_large, torch.full_like(val_if_large, num_buckets - 1)
44 | )
45 |
46 | ret += torch.where(is_small, n, val_if_large)
47 | return ret
48 |
49 | def compute_bias(self, qlen, klen, step=None):
50 | step = 0 if step is None else step
51 | context_position = torch.arange(
52 | step,
53 | step + qlen,
54 | dtype=torch.long,
55 | device=self.relative_attention_bias.weight.device,
56 | )[:, None]
57 | memory_position = torch.arange(
58 | klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
59 | )[None, :]
60 | relative_position = memory_position - context_position # shape (qlen, klen)
61 |
62 | rp_bucket = self._relative_position_bucket(
63 | relative_position, # shape (qlen, klen)
64 | bidirectional=self.bidirectional,
65 | num_buckets=self.num_buckets,
66 | max_distance=self.max_distance,
67 | )
68 | rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
69 | values = self.relative_attention_bias(
70 | rp_bucket
71 | ) # shape (qlen, klen, num_heads)
72 | values = values.permute([2, 0, 1]).unsqueeze(
73 | 0
74 | ) # shape (1, num_heads, qlen, klen)
75 | return values
76 |
77 | def forward(self, batch_size, qlen, klen, step=None):
78 | # shape (batch * num_heads, qlen, klen)
79 | return (
80 | self.compute_bias(qlen, klen, step)
81 | .repeat(batch_size, 1, 1, 1)
82 | .view(-1, qlen, klen)
83 | )
84 |
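Sketch: the module produces an additive attention bias of shape (batch * n_heads, qlen, klen), which matches what MultiheadAttention expects as rel_pos.

```python
import torch
from torchscale.component.relative_position_bias import RelativePositionBias

bias = RelativePositionBias(bidirectional=True, num_buckets=32,
                            max_distance=128, n_heads=12)
rel_pos = bias(batch_size=2, qlen=10, klen=10)
print(rel_pos.shape)   # torch.Size([24, 10, 10])
```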
--------------------------------------------------------------------------------
/torchscale/component/rms_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | class RMSNorm(nn.Module):
8 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
9 | super().__init__()
10 | self.eps = eps
11 | self.elementwise_affine = elementwise_affine
12 | if self.elementwise_affine:
13 | self.weight = nn.Parameter(torch.ones(dim))
14 | else:
15 | self.register_parameter('weight', None)
16 |
17 | def _norm(self, x):
18 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
19 |
20 | def forward(self, x):
21 | output = self._norm(x.float()).type_as(x)
22 | if self.weight is not None:
23 | output = output * self.weight
24 | return output
25 |
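Sketch: RMSNorm rescales each vector by the root mean square of its last dimension; with the default elementwise_affine=True a learned per-dimension gain (initialized to ones) is applied on top.

```python
import torch
from torchscale.component.rms_norm import RMSNorm

norm = RMSNorm(dim=64, eps=1e-6)
x = torch.randn(2, 10, 64)
y = norm(x)
print(y.pow(2).mean(-1)[0, 0])   # ~1.0 at init, since the gain starts as all ones
```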
--------------------------------------------------------------------------------
/torchscale/component/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | def padding_to_multiple_of(n, mult):
8 | remainder = n % mult
9 | if remainder == 0:
10 | return 0
11 | return mult - remainder
12 |
13 | def get_data_parallel_group():
14 | if torch.distributed.is_initialized():
15 | if not hasattr(get_data_parallel_group, "_global_group"):
16 | get_data_parallel_group._global_group = dist.new_group()
17 | return get_data_parallel_group._global_group
18 | else:
19 | return None
20 |
21 | def get_rank(group):
22 | return dist.get_rank(group=group)
23 |
24 | def get_world_size(group):
25 | if torch.distributed.is_initialized():
26 | return dist.get_world_size(group=group)
27 | else:
28 | return 1
29 |
30 | def get_data_parallel_rank():
31 | return get_rank(get_data_parallel_group())
32 |
33 | def get_data_parallel_world_size():
34 | return get_world_size(get_data_parallel_group())
35 |
36 |
37 | class Allgather(torch.autograd.Function):
38 |
39 | @staticmethod
40 | def forward(ctx, input_):
41 | world_size = get_data_parallel_world_size()
42 | dim_size = list(input_.size())
43 | dim_size[0] = dim_size[0] * world_size
44 |
45 | output = torch.empty(dim_size, dtype=input_.dtype,
46 | device=torch.cuda.current_device())
47 | torch.distributed._all_gather_base(output, input_.contiguous(),
48 | group=get_data_parallel_group())
49 |
50 | return output
51 |
52 | @staticmethod
53 | def backward(ctx, grad_output):
54 | world_size = get_data_parallel_world_size()
55 |
56 | dim_size = list(grad_output.size())
57 | assert dim_size[0] % world_size == 0, \
58 | "First dimension of the tensor should be divisible by tensor parallel size"
59 |
60 | dim_size[0] = dim_size[0] // world_size
61 |
62 | output = torch.empty(dim_size, dtype=grad_output.dtype,
63 | device=torch.cuda.current_device())
64 |
65 | torch.distributed._reduce_scatter_base(output, grad_output.contiguous(),
66 | group=get_data_parallel_group())
67 |
68 | return output
69 |
70 | all_gather_func = Allgather.apply
71 |
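The distributed helpers above need an initialized process group and CUDA; padding_to_multiple_of is the one piece that can be sketched standalone. It returns how much padding brings a length up to a multiple of some block size (for example a dilated-attention segment length).

```python
from torchscale.component.utils import padding_to_multiple_of

print(padding_to_multiple_of(10, 4))   # 2 -> pad 10 tokens up to 12
print(padding_to_multiple_of(12, 4))   # 0 -> already a multiple
```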
--------------------------------------------------------------------------------
/torchscale/component/xmoe/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------
/torchscale/component/xmoe/global_groups.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 |
3 |
4 | def _find_my_group_index(grouped_ranks):
5 | my_rank = dist.get_rank()
6 | for i, group in enumerate(grouped_ranks):
7 | if my_rank in group:
8 | return i
9 | raise RuntimeError
10 |
11 | def get_moe_group(moe_expert_count=None):
12 | if dist.is_initialized():
13 | if not hasattr(get_moe_group, "_moe_groups"):
14 | world_size = dist.get_world_size()
15 |
16 | if world_size <= moe_expert_count:
17 | assert moe_expert_count % world_size == 0
18 | moe_groups = [[i] for i in range(world_size)]
19 |
20 | else:
21 | assert world_size % moe_expert_count == 0
22 | ranks_per_group = world_size // moe_expert_count
23 | moe_groups = [
24 | [i + j * moe_expert_count for j in range(ranks_per_group)]
25 | for i in range(moe_expert_count)
26 | ]
27 |
28 | get_moe_group._moe_expert_count = moe_expert_count
29 | get_moe_group._moe_group_idx = moe_groups
30 | get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
31 |
32 | my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
33 | return my_group_idx, get_moe_group._moe_groups[my_group_idx]
34 |
35 |
36 | def get_all2all_group(moe_expert_count):
37 | if dist.is_initialized():
38 | if not hasattr(get_all2all_group, "_all2all_groups"):
39 | world_size = dist.get_world_size()
40 |
41 |             # at least as many experts as ranks
42 | if world_size <= moe_expert_count:
43 | assert moe_expert_count % world_size == 0
44 | all2all_groups = [[i for i in range(world_size)]]
45 |
46 |             # more ranks than experts
47 | else:
48 | assert world_size % moe_expert_count == 0
49 | ranks_per_group = world_size // moe_expert_count
50 | all2all_groups = [
51 | [i * moe_expert_count + j for j in range(moe_expert_count)]
52 | for i in range(ranks_per_group)
53 | ]
54 |
55 | get_all2all_group._all2all_group_idx = all2all_groups
56 | get_all2all_group._all2all_groups = [
57 | dist.new_group(g) for g in all2all_groups
58 | ]
59 |
60 | my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
61 | return get_all2all_group._all2all_groups[my_group_idx]
62 |
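An illustration of the group layouts only (no process groups are created here): with world_size=4 and moe_expert_count=2, get_moe_group ties ranks {0, 2} and {1, 3} together, while get_all2all_group builds [[0, 1], [2, 3]].

```python
# Pure-Python illustration of the "more ranks than experts" branch above.
world_size, moe_expert_count = 4, 2
ranks_per_group = world_size // moe_expert_count

moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
              for i in range(moe_expert_count)]
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
                  for i in range(ranks_per_group)]

print(moe_groups)       # [[0, 2], [1, 3]]
print(all2all_groups)   # [[0, 1], [2, 3]]
```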
--------------------------------------------------------------------------------
/torchscale/component/xpos_relative_position.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | def fixed_pos_embedding(x):
9 | seq_len, dim = x.shape
10 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
11 | sinusoid_inp = (
12 | torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
13 | )
14 | return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
15 |
16 | def rotate_every_two(x):
17 | x1 = x[:, :, ::2]
18 | x2 = x[:, :, 1::2]
19 | x = torch.stack((-x2, x1), dim=-1)
20 |     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
21 |
22 | def duplicate_interleave(m):
23 | """
24 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
25 | """
26 | dim0 = m.shape[0]
27 | m = m.view(-1, 1) # flatten the matrix
28 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
29 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
30 | return m
31 |
32 | def apply_rotary_pos_emb(x, sin, cos, scale=1):
33 | sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
34 | # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
35 | return (x * cos) + (rotate_every_two(x) * sin)
36 |
37 |
38 | class XPOS(nn.Module):
39 | def __init__(
40 | self, head_dim, scale_base=512
41 | ):
42 | super().__init__()
43 | self.head_dim = head_dim
44 | self.scale_base = scale_base
45 | self.register_buffer(
46 | "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)
47 | )
48 |
49 | def forward(self, x, offset=0, downscale=False):
50 | length = x.shape[1]
51 | min_pos = -(length + offset) // 2
52 | max_pos = length + offset + min_pos
53 | scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
54 | sin, cos = fixed_pos_embedding(scale)
55 |
56 | if scale.shape[0] > length:
57 | scale = scale[-length:]
58 | sin = sin[-length:]
59 | cos = cos[-length:]
60 |
61 | if downscale:
62 | scale = 1 / scale
63 |
64 | x = apply_rotary_pos_emb(x, sin, cos, scale)
65 | return x
66 |
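Usage sketch mirroring the calls in MultiheadAttention: keys are rotated with downscale=True and queries with downscale=False, on (batch * num_heads, seq_len, head_dim) tensors.

```python
import torch
from torchscale.component.xpos_relative_position import XPOS

xpos = XPOS(head_dim=64, scale_base=512)
q = torch.randn(16, 10, 64)   # (batch * num_heads, seq_len, head_dim)
k = torch.randn(16, 10, 64)
q = xpos(q, offset=0, downscale=False)
k = xpos(k, offset=0, downscale=True)
print(q.shape, k.shape)       # torch.Size([16, 10, 64]) torch.Size([16, 10, 64])
```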
--------------------------------------------------------------------------------
/torchscale/model/BEiT3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torchscale.architecture.encoder import Encoder
8 | from torchscale.component.embedding import (
9 | PositionalEmbedding,
10 | TextEmbedding,
11 | VisionEmbedding,
12 | )
13 | from torchscale.component.multiway_network import MutliwayEmbedding
14 |
15 |
16 | class BEiT3(nn.Module):
17 | def __init__(self, args, **kwargs):
18 | super().__init__()
19 | self.args = args
20 | assert args.multiway
21 | assert args.vocab_size > 0
22 | assert not args.share_encoder_input_output_embed
23 | self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
24 | self.vision_embed = VisionEmbedding(
25 | args.img_size,
26 | args.patch_size,
27 | args.in_chans,
28 | args.encoder_embed_dim,
29 | contain_mask_token=True,
30 | prepend_cls_token=True,
31 | )
32 | # being consistent with Fairseq, which starts from 2 for position embedding
33 | embed_positions = MutliwayEmbedding(
34 | modules=[
35 | PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
36 | PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
37 | ],
38 | dim=1,
39 | )
40 | self.encoder = Encoder(
41 | args,
42 | embed_tokens=None,
43 | embed_positions=embed_positions,
44 | output_projection=None,
45 | is_encoder_decoder=False,
46 | )
47 |
48 | def forward(
49 | self,
50 | textual_tokens=None,
51 | visual_tokens=None,
52 | text_padding_position=None,
53 | attn_mask=None,
54 | vision_masked_position=None,
55 | incremental_state=None,
56 | positions=None,
57 | ):
58 | assert textual_tokens is not None or visual_tokens is not None
59 |
60 | if textual_tokens is None:
61 | x = self.vision_embed(visual_tokens, vision_masked_position)
62 | encoder_padding_mask = None
63 | multiway_split_position = -1
64 | elif visual_tokens is None:
65 | x = self.text_embed(textual_tokens)
66 | encoder_padding_mask = text_padding_position
67 | multiway_split_position = 0
68 | else:
69 | x1 = self.vision_embed(visual_tokens, vision_masked_position)
70 | multiway_split_position = x1.size(1)
71 | x2 = self.text_embed(textual_tokens)
72 | x = torch.cat([x1, x2], dim=1)
73 |
74 | if text_padding_position is not None:
75 | encoder_padding_mask = torch.cat(
76 | [
77 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
78 | text_padding_position,
79 | ],
80 | dim=1,
81 | )
82 | else:
83 | encoder_padding_mask = None
84 |
85 | encoder_out = self.encoder(
86 | src_tokens=None,
87 | encoder_padding_mask=encoder_padding_mask,
88 | attn_mask=attn_mask,
89 | token_embeddings=x,
90 | multiway_split_position=multiway_split_position,
91 | incremental_state=incremental_state,
92 | positions=positions,
93 | )
94 | encoder_out["multiway_split_position"] = multiway_split_position
95 |
96 | return encoder_out
97 |
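A rough usage sketch. The config fields are assumed to come from torchscale.architecture.config.EncoderConfig with its defaults (224x224 images, 16x16 patches, 768-dim encoder); BEiT3 itself only requires multiway=True and a positive vocab_size.

```python
import torch
from torchscale.architecture.config import EncoderConfig
from torchscale.model.BEiT3 import BEiT3

# no_output_layer=True keeps encoder_out at encoder_embed_dim (assumed config flag)
args = EncoderConfig(multiway=True, vocab_size=64010, no_output_layer=True)
model = BEiT3(args)

visual_tokens = torch.randn(1, 3, 224, 224)
textual_tokens = torch.randint(0, 64010, (1, 32))
out = model(textual_tokens=textual_tokens, visual_tokens=visual_tokens)
# 14*14 patches + 1 CLS + 32 text tokens = 229 positions, encoder_embed_dim channels
print(out["encoder_out"].shape)
```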
--------------------------------------------------------------------------------
/torchscale/model/LongNet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
4 | from torchscale.architecture.decoder import Decoder, DecoderLayer
5 | from torchscale.architecture.encoder import Encoder, EncoderLayer
6 | from torchscale.component.dilated_attention import DilatedAttention
7 | from fairscale.nn import checkpoint_wrapper, wrap
8 |
9 |
10 | class LongNetDecoderLayer(DecoderLayer):
11 |
12 | def build_self_attention(self, embed_dim, args):
13 | return DilatedAttention(
14 | args,
15 | embed_dim,
16 | args.decoder_attention_heads,
17 | dropout=args.attention_dropout,
18 | self_attention=True,
19 | encoder_decoder_attention=False,
20 | subln=args.subln,
21 | )
22 |
23 | class LongNetDecoder(Decoder):
24 |
25 | def build_decoder_layer(
26 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False
27 | ):
28 | layer = LongNetDecoderLayer(
29 | args,
30 | depth,
31 | is_moe_layer=is_moe_layer,
32 | is_encoder_decoder=is_encoder_decoder,
33 | )
34 | if args.checkpoint_activations:
35 | layer = checkpoint_wrapper(layer)
36 | if args.fsdp:
37 | layer = wrap(layer)
38 | return layer
39 |
40 | class LongNetEncoderLayer(EncoderLayer):
41 |
42 | def build_self_attention(self, embed_dim, args):
43 | return DilatedAttention(
44 | args,
45 | embed_dim,
46 | args.encoder_attention_heads,
47 | dropout=args.attention_dropout,
48 | self_attention=True,
49 | encoder_decoder_attention=False,
50 | subln=args.subln,
51 | )
52 |
53 | class LongNetEncoder(Encoder):
54 |
55 | def build_encoder_layer(
56 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False
57 | ):
58 | layer = LongNetEncoderLayer(
59 | args,
60 | depth,
61 | is_moe_layer=is_moe_layer,
62 | is_encoder_decoder=is_encoder_decoder,
63 | )
64 | if args.checkpoint_activations:
65 | layer = checkpoint_wrapper(layer)
66 | if args.fsdp:
67 | layer = wrap(layer)
68 | return layer
69 |
--------------------------------------------------------------------------------
/torchscale/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 |
--------------------------------------------------------------------------------