├── .github └── workflows │ └── test.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── examples ├── __init__.py ├── fairseq │ ├── README.md │ ├── __init__.py │ ├── criterions │ │ ├── __init__.py │ │ └── masked_lm_moe.py │ ├── generate.py │ ├── interactive.py │ ├── models │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── language_modeling.py │ │ ├── machine_translation.py │ │ └── retnet.py │ ├── tasks │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── basic_loader.py │ │ │ ├── mlm_loader.py │ │ │ └── utils.py │ │ └── pretraining.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ └── sparse_clip.py └── longvit │ ├── README.md │ ├── data_preprocessing │ ├── cache_transformed_images.py │ ├── convert_wsi_to_images.py │ ├── create_tcga_subtyping_index.py │ ├── create_tcga_survival_index.py │ ├── generate_1024_crops.py │ └── split_to_small_images.py │ ├── datasets.py │ ├── engine_for_finetuning.py │ ├── get_started │ ├── get_started_for_tcga_pretraining.md │ ├── get_started_for_tcga_subtyping.md │ └── get_started_for_tcga_survival_prediction.md │ ├── longvit.py │ ├── modeling_finetune.py │ ├── optim_factory.py │ ├── pretraining │ └── vision_transformer.py │ ├── requirements.txt │ ├── run_longvit_finetuning.py │ └── utils.py ├── setup.py ├── tests ├── __init__.py ├── test_decoder.py ├── test_encoder.py └── test_encoder_decoder.py └── torchscale ├── __init__.py ├── architecture ├── __init__.py ├── config.py ├── decoder.py ├── encoder.py ├── encoder_decoder.py ├── retnet.py └── utils.py ├── component ├── __init__.py ├── dilated_attention.py ├── droppath.py ├── embedding.py ├── feedforward_network.py ├── flash_attention.py ├── gate_linear_unit.py ├── multihead_attention.py ├── multiscale_retention.py ├── multiway_network.py ├── relative_position_bias.py ├── rms_norm.py ├── utils.py ├── xmoe │ ├── __init__.py │ ├── global_groups.py │ ├── moe_layer.py │ └── routing.py └── xpos_relative_position.py └── model ├── BEiT3.py ├── LongNet.py └── __init__.py /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.10 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: "3.10" 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 20 | if [ -f setup.py ]; then pip install .; fi 21 | - name: Install pytest 22 | run: | 23 | pip install pytest 24 | - name: Run tests 25 | run: | 26 | pytest tests/ 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | 352 | 353 | # Byte-compiled / optimized / DLL files 354 | __pycache__/ 355 | *.py[cod] 356 | *$py.class 357 | 358 | # C extensions 359 | *.so 360 | 361 | # Distribution / packaging 362 | .Python 363 | build/ 364 | develop-eggs/ 365 | dist/ 366 | downloads/ 367 | eggs/ 368 | .eggs/ 369 | lib/ 370 | lib64/ 371 | parts/ 372 | sdist/ 373 | var/ 374 | wheels/ 375 | share/python-wheels/ 376 | *.egg-info/ 377 | .installed.cfg 378 | *.egg 379 | MANIFEST 380 | 381 | # PyInstaller 382 | # Usually these files are written by a python script from a template 383 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 384 | *.manifest 385 | *.spec 386 | 387 | # Installer logs 388 | pip-log.txt 389 | pip-delete-this-directory.txt 390 | 391 | # Unit test / coverage reports 392 | htmlcov/ 393 | .tox/ 394 | .nox/ 395 | .coverage 396 | .coverage.* 397 | .cache 398 | nosetests.xml 399 | coverage.xml 400 | *.cover 401 | *.py,cover 402 | .hypothesis/ 403 | .pytest_cache/ 404 | cover/ 405 | 406 | # Translations 407 | *.mo 408 | *.pot 409 | 410 | # Django stuff: 411 | *.log 412 | local_settings.py 413 | db.sqlite3 414 | db.sqlite3-journal 415 | 416 | # Flask stuff: 417 | instance/ 418 | .webassets-cache 419 | 420 | # Scrapy stuff: 421 | .scrapy 422 | 423 | # Sphinx documentation 424 | docs/_build/ 425 | 426 | # PyBuilder 427 | .pybuilder/ 428 | target/ 429 | 430 | # Jupyter Notebook 431 | .ipynb_checkpoints 432 | 433 | # IPython 434 | profile_default/ 435 | ipython_config.py 436 | 437 | # pyenv 438 | # For a library or package, you might want to ignore these files since the code is 439 | # intended to run in multiple environments; otherwise, check them in: 440 | # .python-version 441 | 442 | # pipenv 443 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
444 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 445 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 446 | # install all needed dependencies. 447 | #Pipfile.lock 448 | 449 | # poetry 450 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 451 | # This is especially recommended for binary packages to ensure reproducibility, and is more 452 | # commonly ignored for libraries. 453 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 454 | #poetry.lock 455 | 456 | # pdm 457 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 458 | #pdm.lock 459 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 460 | # in version control. 461 | # https://pdm.fming.dev/#use-with-ide 462 | .pdm.toml 463 | 464 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 465 | __pypackages__/ 466 | 467 | # Celery stuff 468 | celerybeat-schedule 469 | celerybeat.pid 470 | 471 | # SageMath parsed files 472 | *.sage.py 473 | 474 | # Environments 475 | .env 476 | .venv 477 | env/ 478 | venv/ 479 | ENV/ 480 | env.bak/ 481 | venv.bak/ 482 | 483 | # Spyder project settings 484 | .spyderproject 485 | .spyproject 486 | 487 | # Rope project settings 488 | .ropeproject 489 | 490 | # mkdocs documentation 491 | /site 492 | 493 | # mypy 494 | .mypy_cache/ 495 | .dmypy.json 496 | dmypy.json 497 | 498 | # Pyre type checker 499 | .pyre/ 500 | 501 | # pytype static type analyzer 502 | .pytype/ 503 | 504 | # Cython debug symbols 505 | cython_debug/ 506 | 507 | # PyCharm 508 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 509 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 510 | # and can be added to the global gitignore or merged into this file. For a more nuclear 511 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 512 | #.idea/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchScale - A Library of Foundation Architectures 2 | 3 |

4 | [MIT License badge] 5 | [PyPI badge] 6 |

7 | 8 | TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively. 9 | 10 | Fundamental research to develop new architectures for foundation models and A(G)I, focusing on modeling generality and capability, as well as training stability and efficiency. 11 | - Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond 12 | - Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal) 13 | - Capability - A [**Length-Extrapolatable**](https://arxiv.org/abs/2212.10554) Transformer 14 | - Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE) 15 | 16 | ### The Revolution of Model Architecture 17 | - [**BitNet**](https://arxiv.org/abs/2310.11453): 1-bit Transformers for Large Language Models 18 | - [**RetNet**](https://arxiv.org/abs/2307.08621): Retentive Network: A Successor to Transformer for Large Language Models 19 | - [**LongNet**](https://arxiv.org/abs/2307.02486): Scaling Transformers to 1,000,000,000 Tokens 20 | 21 | ## News 22 | 23 | - December, 2023: [LongNet](torchscale/model/LongNet.py) and [LongViT](examples/longvit/README.md) released 24 | - October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet 25 | - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)] 26 | 27 | ## Installation 28 | 29 | To install: 30 | ``` 31 | pip install torchscale 32 | ``` 33 | 34 | Alternatively, you can develop it locally: 35 | ``` 36 | git clone https://github.com/microsoft/torchscale.git 37 | cd torchscale 38 | pip install -e . 39 | ``` 40 | 41 | For faster training install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs: 42 | ``` 43 | pip install flash-attn 44 | ``` 45 | or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs: 46 | ``` 47 | # cuda 11.8 version 48 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118 49 | # cuda 12.1 version 50 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121 51 | ``` 52 | 53 | ## Getting Started 54 | 55 | It takes only several lines of code to create a model with the above fundamental research features enabled. 
Here is how to quickly obtain a BERT-like encoder: 56 | 57 | ```python 58 | >>> from torchscale.architecture.config import EncoderConfig 59 | >>> from torchscale.architecture.encoder import Encoder 60 | 61 | >>> config = EncoderConfig(vocab_size=64000) 62 | >>> model = Encoder(config) 63 | 64 | >>> print(model) 65 | ``` 66 | 67 | We also support the `Decoder` architecture and the `EncoderDecoder` architecture: 68 | 69 | ```python 70 | # Creating a decoder model 71 | >>> from torchscale.architecture.config import DecoderConfig 72 | >>> from torchscale.architecture.decoder import Decoder 73 | 74 | >>> config = DecoderConfig(vocab_size=64000) 75 | >>> decoder = Decoder(config) 76 | >>> print(decoder) 77 | 78 | # Creating an encoder-decoder model 79 | >>> from torchscale.architecture.config import EncoderDecoderConfig 80 | >>> from torchscale.architecture.encoder_decoder import EncoderDecoder 81 | 82 | >>> config = EncoderDecoderConfig(vocab_size=64000) 83 | >>> encdec = EncoderDecoder(config) 84 | >>> print(encdec) 85 | ``` 86 | 87 | It takes only a few lines of code to create a RetNet model: 88 | 89 | ```python 90 | # Creating a RetNet model 91 | >>> import torch 92 | >>> from torchscale.architecture.config import RetNetConfig 93 | >>> from torchscale.architecture.retnet import RetNetDecoder 94 | 95 | >>> config = RetNetConfig(vocab_size=64000) 96 | >>> retnet = RetNetDecoder(config) 97 | 98 | >>> print(retnet) 99 | ``` 100 | 101 | For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required): 102 | ```python 103 | >>> import torch 104 | >>> from torchscale.architecture.config import EncoderConfig, DecoderConfig 105 | >>> from torchscale.model.LongNet import LongNetEncoder, LongNetDecoder 106 | 107 | # Creating a LongNet encoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2] 108 | >>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True) 109 | >>> longnet = LongNetEncoder(config) 110 | 111 | # Creating a LongNet decoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2] 112 | >>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True) 113 | >>> longnet = LongNetDecoder(config) 114 | ``` 115 | 116 | ## Key Features 117 | 118 | - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555) 119 | * enabled by setting *deepnorm=True* in the `Config` class. 120 | * It adjusts both the residual connection and the initialization method according to the model architecture (i.e., encoder, decoder, or encoder-decoder). 121 | 122 | - [SubLN for the model generality and the training stability](https://arxiv.org/abs/2210.06423) 123 | * enabled by *subln=True*. This is enabled by default. 124 | * It introduces another LayerNorm to each sublayer and adjusts the initialization according to the model architecture. 125 | * Note that SubLN and DeepNorm cannot be used in one single model. 126 | 127 | - [X-MoE: efficient and finetunable sparse MoE modeling](https://arxiv.org/abs/2204.09179) 128 | * enabled by *use_xmoe=True*. 129 | * It replaces every *moe_freq*-th `FeedForwardNetwork` layer with an X-MoE layer. 130 | 131 | - [Multiway architecture for multimodality](https://arxiv.org/abs/2208.10442) 132 | * enabled by *multiway=True*. 133 | * It provides a pool of Transformer parameters used for different modalities. 
134 | 135 | - [Extrapolatable position embedding (Xpos)](https://arxiv.org/abs/2212.10554) 136 | * enabled by *xpos_rel_pos=True*. 137 | 138 | - [Relative position bias](https://arxiv.org/abs/1910.10683) 139 | * enabled by adjusting *rel_pos_buckets* and *max_rel_pos*. 140 | 141 | - [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184) 142 | * we provide a [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq (or other) repo. 143 | 144 | - [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621) 145 | * created by `config = RetNetConfig(vocab_size=64000)` and `retnet = RetNetDecoder(config)`. 146 | 147 | - [LongNet: Scaling Transformers to 1,000,000,000 Tokens](https://arxiv.org/abs/2307.02486) 148 | 149 | Most of the features above can be used by simply passing the corresponding parameters to the config. For example: 150 | 151 | ```python 152 | >>> from torchscale.architecture.config import EncoderConfig 153 | >>> from torchscale.architecture.encoder import Encoder 154 | 155 | >>> config = EncoderConfig(vocab_size=64000, deepnorm=True, multiway=True) 156 | >>> model = Encoder(config) 157 | 158 | >>> print(model) 159 | ``` 160 | 161 | ## Examples 162 | 163 | We have examples of how to use TorchScale in the following scenarios/tasks: 164 | 165 | - Language 166 | 167 | * [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining) 168 | 169 | * [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation) 170 | 171 | * [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining) 172 | 173 | - Vision 174 | 175 | * [LongViT](examples/longvit/README.md) 176 | 177 | * ViT/BEiT [In progress] 178 | 179 | - Speech 180 | 181 | - Multimodal 182 | 183 | * [Multiway Transformers/BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3) 184 | 185 | We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome! 186 | 187 | 188 | ## Acknowledgments 189 | 190 | Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository. 
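As a small supplement to the Key Features list above, the sketch below turns on the Xpos and relative position bias options for a decoder. This is only an illustrative sketch: it assumes that `xpos_rel_pos`, `rel_pos_buckets`, and `max_rel_pos` are accepted as `DecoderConfig` keyword arguments (the same parameter names used in the list above), and the bucket values shown are arbitrary rather than a tuned recommendation.

```python
>>> from torchscale.architecture.config import DecoderConfig
>>> from torchscale.architecture.decoder import Decoder

# Decoder with Xpos and a bucketed relative position bias enabled
# (keyword names assumed to mirror the flags described under Key Features)
>>> config = DecoderConfig(vocab_size=64000, xpos_rel_pos=True, rel_pos_buckets=32, max_rel_pos=128)
>>> decoder = Decoder(config)
>>> print(decoder)
```

As in the earlier encoder example, every feature is just another constructor argument, so different combinations can be tried without changing any model code.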
191 | 192 | ## Citations 193 | 194 | If you find this repository useful, please consider citing our work: 195 | 196 | ``` 197 | @article{torchscale, 198 | author = {Shuming Ma and Hongyu Wang and Shaohan Huang and Wenhui Wang and Zewen Chi and Li Dong and Alon Benhaim and Barun Patra and Vishrav Chaudhary and Xia Song and Furu Wei}, 199 | title = {{TorchScale}: {Transformers} at Scale}, 200 | journal = {CoRR}, 201 | volume = {abs/2211.13184}, 202 | year = {2022} 203 | } 204 | ``` 205 | 206 | ``` 207 | @article{deepnet, 208 | author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei}, 209 | title = {{DeepNet}: Scaling {Transformers} to 1,000 Layers}, 210 | journal = {CoRR}, 211 | volume = {abs/2203.00555}, 212 | year = {2022}, 213 | } 214 | ``` 215 | 216 | ``` 217 | @article{magneto, 218 | author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei}, 219 | title = {Foundation {Transformers}}, 220 | journal = {CoRR}, 221 | volume = {abs/2210.06423}, 222 | year = {2022} 223 | } 224 | ``` 225 | 226 | ``` 227 | @inproceedings{xmoe, 228 | title={On the Representation Collapse of Sparse Mixture of Experts}, 229 | author={Zewen Chi and Li Dong and Shaohan Huang and Damai Dai and Shuming Ma and Barun Patra and Saksham Singhal and Payal Bajaj and Xia Song and Xian-Ling Mao and Heyan Huang and Furu Wei}, 230 | booktitle={Advances in Neural Information Processing Systems}, 231 | year={2022}, 232 | url={https://openreview.net/forum?id=mWaYC6CZf5} 233 | } 234 | ``` 235 | 236 | ``` 237 | @article{retnet, 238 | author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei}, 239 | title = {Retentive Network: A Successor to {Transformer} for Large Language Models}, 240 | journal = {ArXiv}, 241 | volume = {abs/2307.08621}, 242 | year = {2023} 243 | } 244 | ``` 245 | 246 | ``` 247 | @article{longnet, 248 | author={Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei}, 249 | title = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens}, 250 | journal = {ArXiv}, 251 | volume = {abs/2307.02486}, 252 | year = {2023} 253 | } 254 | ``` 255 | 256 | ``` 257 | @article{longvit, 258 | title = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology}, 259 | author = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei}, 260 | journal = {ArXiv}, 261 | volume = {abs/2312.03558}, 262 | year = {2023} 263 | } 264 | ``` 265 | 266 | ## Contributing 267 | 268 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 269 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 270 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 271 | 272 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 273 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 274 | provided by the bot. You will only need to do this once across all repos using our CLA. 275 | 276 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
277 | For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 278 | contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments. 279 | 280 | ## Trademarks 281 | 282 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 283 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 284 | Any use of third-party trademarks or logos is subject to those third-party's policies. 285 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 
Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/README.md: -------------------------------------------------------------------------------- 1 | # Example: Integration with FairSeq 2 | 3 | ## Setup 4 | 5 | ```bash 6 | # Install the repo as a package: 7 | git clone https://github.com/microsoft/torchscale.git 8 | cd torchscale 9 | pip install -e . 10 | pip install git+https://github.com/shumingma/fairseq.git@moe 11 | pip install git+https://github.com/shumingma/infinibatch.git 12 | pip install iopath 13 | pip install numpy==1.23.0 14 | ``` 15 | 16 | ## Example: BERT Pretraining 17 | 18 | ### Data Format 19 | 20 | We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on-the-fly from the disk. It requires the data sharded into multiple small files (e.g. 10K lines per file), as well as a JSON file to contain some meta data and the paths to these files. 21 | 22 | The overall data directory should be organized as follows: 23 | ``` 24 | Data/ 25 | ├── json/ 26 | │ ├── train.json 27 | │ └── valid.json 28 | ├── shard/ 29 | │ ├── train/ 30 | │ │ ├── 00000.txt 31 | │ │ ├── 00001.txt 32 | │ │ └── ... 33 | │ └── valid/ 34 | │ ├── 00000.txt 35 | │ ├── 00001.txt 36 | │ └── ... 
37 | ├── dict.txt 38 | └── sentencepiece.bpe.model 39 | ``` 40 | 41 | We recommend that each sharded data files contains no more than 10K lines with one sentence per line, and two documents should be separated with an empty line. 42 | ``` 43 | Document 1 Line 1 44 | Document 1 Line 2 45 | Document 1 Line 3 46 | 47 | Document 2 Line 1 48 | Document 2 Line 2 49 | 50 | ... 51 | ``` 52 | 53 | Also, the JSON file should be in the format like this: 54 | ``` 55 | [ 56 | { 57 | "source": [ 58 | "shard/train/00000.txt", 59 | "shard/train/00001.txt", 60 | ... 61 | ], 62 | "source_lang": "en", 63 | "weight": 1.0 64 | } 65 | ] 66 | ``` 67 | 68 | You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model] and [dict.txt]. Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentecepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by 69 | ``` 70 | spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt 71 | ``` 72 | 73 | ### Dense Model 74 | ```bash 75 | cd examples/fairseq/ 76 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \ 77 | --task pretraining \ 78 | --tokens-per-sample 512 \ 79 | --mask-prob 0.15 \ 80 | --span-length 3.0 \ 81 | --leave-unmasked-prob 0.0 \ 82 | --random-token-prob 0.0 \ 83 | --criterion masked_lm \ 84 | --arch mlm_base \ 85 | --share-encoder-input-output-embed \ 86 | --required-batch-size-multiple 8 \ 87 | --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \ 88 | --dict-file ${PATH_TO_DATA}/dict.txt \ 89 | --optimizer adam \ 90 | --adam-betas '(0.9,0.98)' \ 91 | --adam-eps 1e-6 \ 92 | --clip-norm 2.0 \ 93 | --lr-scheduler polynomial_decay \ 94 | --lr 0.0005 \ 95 | --warmup-updates 10000 \ 96 | --total-num-update 125000 \ 97 | --max-update 125000 \ 98 | --max-sentences 32 \ 99 | --update-freq 1 \ 100 | --log-format simple \ 101 | --log-interval 100 \ 102 | --disable-validation \ 103 | --save-interval-updates 5000 \ 104 | --no-epoch-checkpoints \ 105 | --fp16 \ 106 | --fp16-init-scale 4 \ 107 | --fp16-scale-window 256 \ 108 | --min-loss-scale 0.0001 \ 109 | --seed 1 \ 110 | --save-dir ${PATH_TO_CKPT} \ 111 | --ddp-backend=no_c10d \ 112 | --distributed-no-spawn \ 113 | --reset-dataloader \ 114 | --batch-read-ahead 10000 \ 115 | --rel-pos-buckets 32 \ 116 | --max-rel-pos 128 \ 117 | --deepnorm 118 | ``` 119 | 120 | ### Sparse (MoE) Model 121 | ```bash 122 | cd examples/fairseq/ 123 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \ 124 | --task pretraining \ 125 | --tokens-per-sample 512 \ 126 | --mask-prob 0.15 \ 127 | --span-length 3.0 \ 128 | --leave-unmasked-prob 0.0 \ 129 | --random-token-prob 0.0 \ 130 | --arch mlm_base \ 131 | --share-encoder-input-output-embed \ 132 | --required-batch-size-multiple 8 \ 133 | --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \ 134 | --dict-file ${PATH_TO_DATA}/dict.txt \ 135 | --optimizer adam \ 136 | --adam-betas '(0.9,0.98)' \ 137 | --adam-eps 1e-6 \ 138 | --clip-norm 2.0 \ 139 | --lr-scheduler polynomial_decay \ 140 | --lr 0.0005 \ 141 | --warmup-updates 10000 \ 142 | --total-num-update 125000 \ 143 | --max-update 125000 \ 144 | --max-sentences 32 \ 145 | --update-freq 1 \ 146 | --log-format simple \ 147 | --log-interval 100 \ 148 | --disable-validation \ 149 | --save-interval-updates 
5000 \ 150 | --no-epoch-checkpoints \ 151 | --fp16 \ 152 | --fp16-init-scale 4 \ 153 | --fp16-scale-window 256 \ 154 | --min-loss-scale 0.0001 \ 155 | --seed 1 \ 156 | --save-dir ${PATH_TO_CKPT} \ 157 | --ddp-backend=no_c10d \ 158 | --distributed-no-spawn \ 159 | --reset-dataloader \ 160 | --batch-read-ahead 10000 \ 161 | --rel-pos-buckets 32 \ 162 | --max-rel-pos 128 \ 163 | --deepnorm \ 164 | --moe-expert-count 64 --moe-freq 2 \ 165 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \ 166 | --moe-eval-capacity-token-fraction -1.0 \ 167 | --criterion masked_lm_moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ 168 | --use-xmoe --pad-to-max-length 169 | ``` 170 | 171 | ## Example: GPT Pretraining 172 | 173 | ### Data Format 174 | 175 | We use the format as in the FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data). 176 | 177 | ### Dense Model 178 | 179 | ```bash 180 | cd examples/fairseq/ 181 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 182 | ${PATH_TO_DATA} \ 183 | --num-workers 2 \ 184 | --activation-fn gelu \ 185 | --share-decoder-input-output-embed \ 186 | --validate-interval-updates 1000 \ 187 | --save-interval-updates 1000 \ 188 | --no-epoch-checkpoints \ 189 | --memory-efficient-fp16 \ 190 | --fp16-init-scale 4 \ 191 | --arch lm_base \ 192 | --task language_modeling \ 193 | --sample-break-mode none \ 194 | --tokens-per-sample 128 \ 195 | --optimizer adam --adam-betas "(0.9, 0.98)" \ 196 | --adam-eps 1e-08 \ 197 | --clip-norm 0.0 \ 198 | --lr 5e-4 \ 199 | --lr-scheduler polynomial_decay \ 200 | --warmup-updates 750 \ 201 | --dropout 0.1 \ 202 | --attention-dropout 0.1 \ 203 | --weight-decay 0.01 \ 204 | --batch-size 4 \ 205 | --update-freq 1 \ 206 | --required-batch-size-multiple 1 \ 207 | --total-num-update 50000 \ 208 | --max-update 50000 \ 209 | --seed 1 \ 210 | --ddp-backend=c10d 211 | ``` 212 | 213 | ### Sparse (MoE) Model 214 | 215 | ```bash 216 | cd examples/fairseq/ 217 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 218 | ${PATH_TO_DATA} \ 219 | --num-workers 2 \ 220 | --activation-fn gelu \ 221 | --share-decoder-input-output-embed \ 222 | --validate-interval-updates 1000 \ 223 | --save-interval-updates 1000 \ 224 | --no-epoch-checkpoints \ 225 | --memory-efficient-fp16 \ 226 | --fp16-init-scale 4 \ 227 | --arch lm_base \ 228 | --task language_modeling \ 229 | --sample-break-mode none \ 230 | --tokens-per-sample 128 \ 231 | --optimizer adam --adam-betas "(0.9, 0.98)" \ 232 | --adam-eps 1e-08 \ 233 | --clip-norm 0.0 \ 234 | --lr 5e-4 \ 235 | --lr-scheduler polynomial_decay \ 236 | --warmup-updates 750 \ 237 | --dropout 0.1 \ 238 | --attention-dropout 0.1 \ 239 | --weight-decay 0.01 \ 240 | --batch-size 4 \ 241 | --update-freq 1 \ 242 | --required-batch-size-multiple 1 \ 243 | --total-num-update 50000 \ 244 | --max-update 50000 \ 245 | --seed 1 \ 246 | --ddp-backend=no_c10d \ 247 | --moe-expert-count 2 --moe-freq 2 \ 248 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \ 249 | --moe-eval-capacity-token-fraction -1.0 \ 250 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ 251 | --use-xmoe 252 | ``` 253 | 254 | ### LongNet Model 255 | 256 | ```bash 257 | cd examples/fairseq/ 258 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 259 | 
${PATH_TO_DATA} \ 260 | --num-workers 2 \ 261 | --activation-fn gelu \ 262 | --share-decoder-input-output-embed \ 263 | --validate-interval-updates 1000 \ 264 | --save-interval-updates 1000 \ 265 | --no-epoch-checkpoints \ 266 | --memory-efficient-fp16 \ 267 | --fp16-init-scale 4 \ 268 | --arch lm_base \ 269 | --task language_modeling \ 270 | --sample-break-mode none \ 271 | --tokens-per-sample 4096 \ 272 | --optimizer adam --adam-betas "(0.9, 0.98)" \ 273 | --adam-eps 1e-08 \ 274 | --clip-norm 0.0 \ 275 | --lr 5e-4 \ 276 | --lr-scheduler polynomial_decay \ 277 | --warmup-updates 750 \ 278 | --dropout 0.1 \ 279 | --attention-dropout 0.1 \ 280 | --weight-decay 0.01 \ 281 | --batch-size 4 \ 282 | --update-freq 1 \ 283 | --required-batch-size-multiple 1 \ 284 | --total-num-update 50000 \ 285 | --max-update 50000 \ 286 | --seed 1 \ 287 | --ddp-backend=c10d \ 288 | --flash-attention \ 289 | --segment-length [2048,4096] \ 290 | --dilated-ratio [1,2] 291 | ``` 292 | 293 | ## Example: Machine Translation 294 | 295 | ### Data Format 296 | 297 | We follow the FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data. 298 | 299 | ### Dense Model 300 | 301 | ```bash 302 | cd examples/fairseq/ 303 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 304 | ${PATH_TO_DATA} \ 305 | --arch mt_base --share-decoder-input-output-embed \ 306 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 307 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 308 | --dropout 0.3 --weight-decay 0.0001 \ 309 | --max-tokens 4096 --fp16 310 | ``` 311 | 312 | ### Sparse (MoE) Model 313 | 314 | ```bash 315 | cd examples/fairseq/ 316 | python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ 317 | ${PATH_TO_DATA} \ 318 | --arch mt_base --share-decoder-input-output-embed \ 319 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 320 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 321 | --dropout 0.3 --weight-decay 0.0001 \ 322 | --moe-expert-count 2 --moe-freq 2 \ 323 | --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \ 324 | --moe-eval-capacity-token-fraction -1.0 \ 325 | --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ 326 | --use-xmoe \ 327 | --max-tokens 4096 --fp16 328 | ``` 329 | -------------------------------------------------------------------------------- /examples/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(os.listdir(os.path.dirname(__file__))): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("criterions." + file_name) -------------------------------------------------------------------------------- /examples/fairseq/criterions/masked_lm_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | from fairseq import metrics, utils 10 | from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig 11 | 12 | 13 | @register_criterion("masked_lm_moe_cross_entropy", dataclass=MoECriterionConfig) 14 | class MaskedLMMoECrossEntropyCriterion(MoECriterion): 15 | 16 | def compute_inner_loss(self, model, sample, reduce=True): 17 | masked_tokens = sample["target"].ne(self.padding_idx) 18 | sample_size = masked_tokens.int().sum() 19 | 20 | masked_tokens = torch.where( 21 | masked_tokens.any(), 22 | masked_tokens, 23 | masked_tokens.new([True]), 24 | ) 25 | 26 | net_output = model(**sample["net_input"], masked_tokens=masked_tokens) 27 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 28 | lprobs = lprobs.view(-1, lprobs.size(-1)) 29 | target = model.get_targets(sample, net_output) 30 | 31 | if masked_tokens is not None: 32 | target = target[masked_tokens] 33 | 34 | nll_loss = F.nll_loss( 35 | lprobs, 36 | target.view(-1), 37 | ignore_index=self.padding_idx, 38 | reduction="sum" if reduce else "none", 39 | ) 40 | logging_output = { 41 | "inner_loss": nll_loss.data, 42 | "ntokens": sample["ntokens"], 43 | "nsentences": sample["target"].size(0), 44 | "sample_size": sample_size, 45 | } 46 | return net_output, nll_loss, sample_size, logging_output 47 | 48 | @staticmethod 49 | def reduce_metrics(logging_outputs) -> None: 50 | """Aggregate logging outputs from data parallel training.""" 51 | MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs) 52 | 53 | loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs) 54 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) 55 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 56 | 57 | # we divide by log(2) to convert the loss from base e to base 2 58 | metrics.log_scalar( 59 | "inner_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 60 | ) 61 | if sample_size != ntokens: 62 | metrics.log_scalar( 63 | "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 64 | ) 65 | metrics.log_derived( 66 | "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) 67 | ) 68 | else: 69 | metrics.log_derived( 70 | "ppl", lambda meters: utils.get_perplexity(meters["inner_loss"].avg) 71 | ) -------------------------------------------------------------------------------- /examples/fairseq/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # flake8: noqa 5 | import models 6 | import tasks 7 | import criterions 8 | from fairseq_cli.generate import cli_main 9 | 10 | if __name__ == "__main__": 11 | cli_main() 12 | -------------------------------------------------------------------------------- /examples/fairseq/interactive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # flake8: noqa 5 | import models 6 | import tasks 7 | import criterions 8 | from fairseq_cli.interactive import cli_main 9 | 10 | if __name__ == "__main__": 11 | cli_main() 12 | -------------------------------------------------------------------------------- 
/examples/fairseq/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import argparse 5 | import importlib 6 | import os 7 | 8 | MODEL_REGISTRY = {} 9 | MODEL_DATACLASS_REGISTRY = {} 10 | ARCH_MODEL_REGISTRY = {} 11 | ARCH_MODEL_NAME_REGISTRY = {} 12 | ARCH_MODEL_INV_REGISTRY = {} 13 | ARCH_CONFIG_REGISTRY = {} 14 | 15 | # automatically import any Python files in the models/ directory 16 | models_dir = os.path.dirname(__file__) 17 | for file in os.listdir(models_dir): 18 | path = os.path.join(models_dir, file) 19 | if ( 20 | not file.startswith("_") 21 | and not file.startswith(".") 22 | and (file.endswith(".py") or os.path.isdir(path)) 23 | ): 24 | model_name = file[: file.find(".py")] if file.endswith(".py") else file 25 | module = importlib.import_module("models." + model_name) 26 | 27 | # extra `model_parser` for sphinx 28 | if model_name in MODEL_REGISTRY: 29 | parser = argparse.ArgumentParser(add_help=False) 30 | group_archs = parser.add_argument_group("Named architectures") 31 | group_archs.add_argument( 32 | "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] 33 | ) 34 | group_args = parser.add_argument_group("Additional command-line arguments") 35 | MODEL_REGISTRY[model_name].add_args(group_args) 36 | globals()[model_name + "_parser"] = parser 37 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import argparse 5 | import importlib 6 | import os 7 | 8 | # register dataclass 9 | TASK_DATACLASS_REGISTRY = {} 10 | TASK_REGISTRY = {} 11 | TASK_CLASS_NAMES = set() 12 | 13 | # automatically import any Python files in the tasks/ directory 14 | tasks_dir = os.path.dirname(__file__) 15 | for file in os.listdir(tasks_dir): 16 | path = os.path.join(tasks_dir, file) 17 | if ( 18 | not file.startswith("_") 19 | and not file.startswith(".") 20 | and (file.endswith(".py") or os.path.isdir(path)) 21 | ): 22 | task_name = file[: file.find(".py")] if file.endswith(".py") else file 23 | module = importlib.import_module("tasks." 
+ task_name) 24 | 25 | # expose `task_parser` for sphinx 26 | if task_name in TASK_REGISTRY: 27 | parser = argparse.ArgumentParser(add_help=False) 28 | group_task = parser.add_argument_group("Task name") 29 | # fmt: off 30 | group_task.add_argument('--task', metavar=task_name, 31 | help='Enable this task with: ``--task=' + task_name + '``') 32 | # fmt: on 33 | group_args = parser.add_argument_group("Additional command-line arguments") 34 | TASK_REGISTRY[task_name].add_args(group_args) 35 | globals()[task_name + "_parser"] = parser 36 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/basic_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | from infinibatch.iterators import CheckpointableIterator 6 | 7 | from . import utils 8 | 9 | 10 | class BaseBatchGen(CheckpointableIterator): 11 | """ 12 | This is a base class for batch generators that use infinibatch 13 | """ 14 | 15 | def __init__(self): 16 | self._iter = None 17 | self.epoch = 1 18 | self.next_epoch_idx = 1 19 | self.sharded_checkpoint = True 20 | self.should_close_after_finished = True 21 | 22 | def _build_iter(self): 23 | """ 24 | Build infinibatch iterator and assign to self._iter 25 | """ 26 | raise NotImplementedError() 27 | 28 | def _move_to_tensor(self, batch): 29 | def to_tensor(x): 30 | return torch.tensor(x) 31 | 32 | return utils.apply_to_sample(to_tensor, batch) 33 | 34 | @property 35 | def iterator(self): 36 | if self._iter is None: 37 | raise NotImplementedError("_build_iter() must called first") 38 | return self._iter 39 | 40 | def __iter__(self): 41 | if self._iter is None: 42 | raise NotImplementedError("_build_iter() must called first") 43 | return self._iter 44 | 45 | def __next__(self): 46 | return next(self._iter) 47 | 48 | def setstate(self, value): 49 | self._iter.setstate(value) 50 | 51 | def getstate(self): 52 | return self._iter.getstate() 53 | 54 | def close(self): 55 | self._iter.close() 56 | 57 | def __len__(self) -> int: 58 | return 819200000 59 | 60 | def next_epoch_itr( 61 | self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True 62 | ): 63 | return self 64 | 65 | def end_of_epoch(self) -> bool: 66 | return False 67 | 68 | def state_dict(self): 69 | """Returns a dictionary containing a whole state of the iterator.""" 70 | return self.getstate() 71 | 72 | def load_state_dict(self, state_dict): 73 | """Copies the state of the iterator from the given *state_dict*.""" 74 | self.setstate(state_dict) 75 | 76 | @property 77 | def first_batch(self): 78 | return "DUMMY" 79 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/mlm_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import copy 5 | import itertools 6 | import os 7 | 8 | import numpy as np 9 | from infinibatch import iterators 10 | 11 | from .basic_loader import BaseBatchGen 12 | from .utils import 
NativeCheckpointableIterator, WeightIterator 13 | 14 | 15 | class MLMLoader(BaseBatchGen): 16 | def __init__( 17 | self, 18 | args, 19 | dataset, 20 | dictionary, 21 | tokenizer, 22 | max_tokens=None, 23 | max_sentences=None, 24 | max_positions=None, 25 | ignore_invalid_inputs=False, 26 | required_batch_size_multiple=1, 27 | seed=1, 28 | num_shards=1, 29 | shard_id=0, 30 | ): 31 | super().__init__() 32 | self.args = args 33 | self.data = dataset.data 34 | self.data_dir = dataset.data_dir 35 | self.shuffle = dataset.shuffle 36 | self.dictionary = dictionary 37 | self.tokenizer = tokenizer 38 | 39 | self.max_tokens = max_tokens 40 | self.max_sentences = max_sentences 41 | self.max_positions = max_positions 42 | self.tokens_per_sample = args.tokens_per_sample 43 | self.sample_break_mode = args.sample_break_mode 44 | self.ignore_invalid_inputs = ignore_invalid_inputs 45 | self.required_batch_size_multiple = required_batch_size_multiple 46 | self.seed = str(seed) 47 | self.num_shards = num_shards 48 | self.shard_id = shard_id 49 | 50 | self.batch_read_ahead = args.batch_read_ahead 51 | 52 | self._build_iter() 53 | 54 | def _build_iter(self): 55 | tokenized_lines = self._multilingual_tokenize() 56 | self.padded_batches = self._batchify(tokenized_lines) 57 | 58 | prefetch_batches = iterators.PrefetchIterator( 59 | self.padded_batches, 60 | buffer_size=10000, 61 | buffer_in_main_process=True, 62 | log_empty_buffer_warning=True and self.shard_id == 0, 63 | ) 64 | 65 | prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor) 66 | 67 | self._iter = prefetch_batches 68 | 69 | def _multilingual_tokenize(self): 70 | multilingual_iters = [] 71 | weights = [] 72 | 73 | for data in self.data: 74 | multilingual_iters.append(self._tokenize(data)) 75 | if "weight" in data: 76 | weights.append(float(data["weight"])) 77 | else: 78 | weights.append(int(data["count"])) 79 | 80 | if len(multilingual_iters) == 1: 81 | return multilingual_iters[0] 82 | 83 | sampling_iterator = WeightIterator(weights) 84 | control_iterator = NativeCheckpointableIterator(sampling_iterator) 85 | tokenized_lines = iterators.MultiplexIterator( 86 | control_iterator, multilingual_iters 87 | ) 88 | 89 | return tokenized_lines 90 | 91 | def _tokenize(self, data): 92 | """ 93 | data: 94 | { 95 | 'source': list[Path], 96 | 'source_lang': str, 97 | 'count': int, 98 | 'weight': float, 99 | 'name': str, 100 | } 101 | """ 102 | dataset = list( 103 | zip( 104 | data["source"], 105 | itertools.repeat(data["source_lang"]), 106 | ) 107 | ) 108 | 109 | if self.shuffle: 110 | chunk_files = iterators.InfinitePermutationSourceIterator( 111 | dataset, 112 | seed=self.seed, 113 | shuffle=self.shuffle, 114 | num_instances=self.num_shards, 115 | instance_rank=self.shard_id, 116 | ) 117 | else: 118 | chunk_files = iterators.ChunkedSourceIterator( 119 | dataset, 120 | num_instances=self.num_shards, 121 | instance_rank=self.shard_id, 122 | ) 123 | 124 | tokenized_lines = iterators.SelectManyIterator( 125 | chunk_files, lambda files: self._read_from_files(*files) 126 | ) 127 | tokenized_lines = iterators.SamplingRandomMapIterator( 128 | tokenized_lines, self._prepare, self.seed 129 | ) 130 | 131 | return tokenized_lines 132 | 133 | def _batchify(self, lines): 134 | 135 | if self.max_sentences is not None: 136 | if self.batch_read_ahead > 0: 137 | lines = iterators.BlockwiseShuffleIterator( 138 | lines, self.batch_read_ahead, self.seed 139 | ) 140 | batches = iterators.FixedBatchIterator(lines, self.max_sentences) 141 | else: 142 | 143 | 
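            # When max_sentences is not given, fall back to token-budget batching:
            # each batch holds roughly `self.max_tokens` tokens, and the batch size is
            # rounded down to a multiple of `required_batch_size_multiple` (minimum 1),
            # as computed by `dynamic_batch_size` below.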
def dynamic_batch_size(sample): 144 | lengths = [len(x) for x in sample] 145 | batch_size = self.max_tokens // max(lengths) 146 | batch_size = ( 147 | batch_size 148 | // self.required_batch_size_multiple 149 | * self.required_batch_size_multiple 150 | ) 151 | return max(1, batch_size) 152 | 153 | batches = iterators.BucketedReadaheadBatchIterator( 154 | lines, 155 | read_ahead=self.batch_read_ahead, 156 | key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None, 157 | batch_size=dynamic_batch_size, 158 | shuffle=self.shuffle, 159 | seed=self.seed, 160 | ) 161 | 162 | def collate(batch): 163 | batch_size = len(batch) 164 | 165 | mlm_source_max_length = max([len(x[0]) for x in batch]) 166 | mlm_target_max_length = max([len(x[1]) for x in batch]) 167 | s2s_source_max_length = max([len(x[2]) for x in batch]) 168 | s2s_target_max_length = max([len(x[3]) for x in batch]) 169 | if self.args.pad_to_max_length: 170 | mlm_source_max_length = self.args.tokens_per_sample 171 | mlm_target_max_length = self.args.tokens_per_sample 172 | 173 | mlm_source_ids = np.full( 174 | shape=(batch_size, mlm_source_max_length), 175 | dtype=np.int32, 176 | fill_value=self.dictionary.pad(), 177 | ) 178 | mlm_target_ids = np.full( 179 | shape=(batch_size, mlm_target_max_length), 180 | dtype=np.int32, 181 | fill_value=self.dictionary.pad(), 182 | ) 183 | s2s_source_ids = np.full( 184 | shape=(batch_size, s2s_source_max_length), 185 | dtype=np.int32, 186 | fill_value=self.dictionary.pad(), 187 | ) 188 | s2s_target_ids = np.full( 189 | shape=(batch_size, s2s_target_max_length - 1), 190 | dtype=np.int32, 191 | fill_value=self.dictionary.pad(), 192 | ) 193 | s2s_prev_input_ids = np.full( 194 | shape=(batch_size, s2s_target_max_length - 1), 195 | dtype=np.int32, 196 | fill_value=self.dictionary.pad(), 197 | ) 198 | 199 | for i, ( 200 | mlm_input_ids, 201 | mlm_label_ids, 202 | s2s_input_ids, 203 | s2s_label_ids, 204 | ) in enumerate(batch): 205 | mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids 206 | mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids 207 | s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids 208 | s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:] 209 | s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1] 210 | 211 | ret_batch = { 212 | "net_input": { 213 | "src_tokens": mlm_source_ids.astype(np.int64), 214 | }, 215 | "target": mlm_target_ids.astype(np.int64), 216 | "nsentences": batch_size, 217 | "ntokens": sum([len(x[0]) for x in batch]), 218 | } 219 | 220 | return ret_batch 221 | 222 | padded_batches = iterators.MapIterator(batches, collate) 223 | 224 | return padded_batches 225 | 226 | def _prepare(self, _random, doc): 227 | nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc) 228 | nonnoise_spans, noise_spans = self._span_corruption(_random, doc) 229 | return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans 230 | 231 | def _mask_lm(self, _random, doc): 232 | def mask_tokens(): 233 | return "<mask>" 234 | 235 | length = len(doc) 236 | mask_tokens_num = int(length * self.args.mask_prob) 237 | mask_tokens_num = min(max(mask_tokens_num, 1), length - 1) 238 | possible_mask_positions = _random.sample(range(length), k=mask_tokens_num) 239 | possible_mask_positions = sorted(possible_mask_positions) 240 | 241 | nonmasked_tokens = copy.deepcopy(doc) 242 | masked_tokens = [self.dictionary.pad() for _ in range(len(doc))] 243 | 244 | for position in possible_mask_positions: 245 | # masked_tokens.append(nonmasked_tokens[position]) 246 | masked_tokens[position] = nonmasked_tokens[position] 247 | nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()] 248 | 249 | return nonmasked_tokens, masked_tokens 250 | 251 | def _span_corruption(self, _random, doc): 252 | def mask_tokens(i): 253 | return f"<mask_{i}>" 254 | 255 | length = len(doc) 256 | noise_tokens_num = int(length * self.args.mask_prob) 257 | noise_tokens_num = min(max(noise_tokens_num, 1), length - 1) 258 | noise_spans_num = int(noise_tokens_num / self.args.span_length) 259 | noise_spans_num = max(noise_spans_num, 1) 260 | nonnoise_tokens_num = length - noise_tokens_num 261 | 262 | if noise_spans_num == 1: 263 | noise_split_positions = [0, noise_tokens_num] 264 | else: 265 | possible_split_positions = list(range(1, noise_tokens_num)) 266 | _random.shuffle(possible_split_positions) 267 | noise_split_positions = sorted( 268 | possible_split_positions[: noise_spans_num - 1] 269 | ) 270 | noise_split_positions = [0] + noise_split_positions + [noise_tokens_num] 271 | 272 | possible_insert_positions = list(range(nonnoise_tokens_num)) 273 | _random.shuffle(possible_insert_positions) 274 | noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num]) 275 | 276 | nonnoise_spans, noise_spans = [], [] 277 | last_end = 0 278 | for i in range(noise_spans_num): 279 | start_pos = noise_insert_positions[i] + noise_split_positions[i] 280 | end_pos = noise_insert_positions[i] + noise_split_positions[i + 1] 281 | mask_id = self.dictionary.indices[mask_tokens(i)] 282 | 283 | if getattr(self.args, "remove_target_sentinel", False): 284 | noise_spans.append(doc[start_pos:end_pos]) 285 | else: 286 | noise_spans.append([mask_id] + doc[start_pos:end_pos]) 287 | 288 | if getattr(self.args, "remove_source_sentinel", False): 289 | nonnoise_spans.extend(doc[last_end:start_pos]) 290 | else: 291 | nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id]) 292 | 293 | last_end = end_pos 294 | 295 | nonnoise_spans.extend(doc[last_end:]) 296 | noise_spans = sum(noise_spans, []) 297 | 298 | return nonnoise_spans, noise_spans 299 | 300 | def _read_from_files(self, source_file, source_lang): 301 | # data = [] 302 | file_path = os.path.join(self.data_dir, source_file) 303 | 304 | if not os.path.exists(file_path): 305 | print("| file {} does not exist".format(file_path), flush=True) 306 | return iter([]) # skip bad file 307 | 308 | with open(file_path, "r", encoding="utf8") as f: 309 | lines = f.read().strip().split("\n") 310 | 311 | doc = [self.dictionary.bos()] 312 | for line in lines: 313 | if line == "": 314 | if self.sample_break_mode == "complete_doc": 315 | # data.append(doc) 316 | yield doc 317 | doc = [self.dictionary.bos()] 318 | continue 319 | 320 | tokenized_line = self.tokenizer.EncodeAsPieces(line) 321 | tokenized_id = [ 322 | self.dictionary.index(token) for token in tokenized_line 323 | ] + [self.dictionary.eos_index] 324 | 325 | if len(tokenized_id) > self.tokens_per_sample: 326 | continue 327 | if len(doc) + len(tokenized_id) > self.tokens_per_sample: 328 | # data.append(doc) 329 | yield doc 330 | doc = [self.dictionary.bos()] 331 | doc.extend(tokenized_id) 332 | 333 | if len(doc) > 1 and len(doc) <= self.tokens_per_sample: 334 | # data.append(doc) 335 | yield doc 336 | 337 | # return data 338 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see
LICENSE for details] 3 | 4 | import collections 5 | from random import Random 6 | from typing import Dict, Iterable, Optional 7 | 8 | import numpy as np 9 | from infinibatch import iterators 10 | 11 | 12 | def apply_to_sample(f, sample): 13 | if hasattr(sample, "__len__") and len(sample) == 0: 14 | return {} 15 | 16 | def _apply(x): 17 | if isinstance(x, np.ndarray): 18 | return f(x) 19 | elif isinstance(x, collections.OrderedDict): 20 | # OrderedDict has attributes that needs to be preserved 21 | od = collections.OrderedDict( 22 | (key, _apply(value)) for key, value in x.items() 23 | ) 24 | od.__dict__ = x.__dict__ 25 | return od 26 | elif isinstance(x, dict): 27 | return {key: _apply(value) for key, value in x.items()} 28 | elif isinstance(x, list): 29 | return [_apply(x) for x in x] 30 | elif isinstance(x, tuple): 31 | return tuple(_apply(x) for x in x) 32 | elif isinstance(x, set): 33 | return {_apply(x) for x in x} 34 | else: 35 | return x 36 | 37 | return _apply(sample) 38 | 39 | 40 | class NativeCheckpointableIterator(iterators.CheckpointableIterator): 41 | def __init__(self, iterable: Iterable): 42 | self._input_iterable = iterable 43 | self.setstate(None) 44 | 45 | def getstate(self) -> Dict: 46 | return {"num_items_yielded": self._num_items_yielded} 47 | 48 | def setstate(self, checkpoint: Optional[Dict]): 49 | self._iterator = iter(self._input_iterable) 50 | self._num_items_yielded = ( 51 | iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"]) 52 | if checkpoint is not None 53 | else 0 54 | ) 55 | 56 | def __next__(self): 57 | item = next(self._iterator) 58 | self._num_items_yielded += 1 59 | return item 60 | 61 | def close(self): 62 | pass 63 | 64 | 65 | class WeightIterator(object): 66 | def __init__(self, weights, seed): 67 | self.weights = weights 68 | self.seed = seed 69 | self.control_index = list(range(len(weights))) 70 | self.setstate(None) 71 | 72 | def __iter__(self): 73 | return self 74 | 75 | def getstate(self): 76 | return {"random_state": self._random_state} 77 | 78 | def setstate(self, checkpoint): 79 | self._random_state = checkpoint["random_state"] if checkpoint else None 80 | self._random = ( 81 | None # this will trigger the lazy initialization in self.__next__ 82 | ) 83 | 84 | def __next__(self): 85 | if self._random is None: 86 | self._random = Random(self.seed) 87 | if self._random_state is not None: 88 | self._random.setstate(self._random_state) 89 | idx = self._random.choices(self.control_index, self.weights)[0] 90 | self._random_state = self._random.getstate() 91 | return idx 92 | 93 | def close(self): 94 | pass 95 | -------------------------------------------------------------------------------- /examples/fairseq/tasks/pretraining.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import json 5 | import logging 6 | import os 7 | from argparse import Namespace 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # 11 | # This source code is licensed under the MIT license found in the 12 | # LICENSE file in the root directory of this source tree. 
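# Overview: setup_task() builds the fairseq Dictionary (adding the <mask> and
# numbered <mask_i> sentinel symbols) and loads the sentencepiece tokenizer,
# load_dataset() reads the per-split json index under {data}/json/, and
# get_batch_iterator() wraps everything in the infinibatch-based MLMLoader.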
13 | from dataclasses import dataclass, field 14 | 15 | import sentencepiece as spm 16 | from fairseq import utils 17 | from fairseq.data import Dictionary 18 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 19 | from fairseq.tasks import FairseqTask, register_task 20 | from omegaconf import II, MISSING 21 | 22 | from .data.mlm_loader import MLMLoader 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) 27 | SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) 28 | 29 | 30 | @dataclass 31 | class PretrainingConfig(FairseqDataclass): 32 | data: str = field( 33 | default=MISSING, 34 | metadata={ 35 | "help": "colon separated path to data directories list, \ 36 | will be iterated upon during epochs in round-robin manner" 37 | }, 38 | ) 39 | sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( 40 | default="complete", 41 | metadata={ 42 | "help": 'If omitted or "none", fills each sample with tokens-per-sample ' 43 | 'tokens. If set to "complete", splits samples only at the end ' 44 | "of sentence, but may include multiple sentences per sample. " 45 | '"complete_doc" is similar but respects doc boundaries. ' 46 | 'If set to "eos", includes only one sentence per sample.' 47 | }, 48 | ) 49 | tokens_per_sample: int = field( 50 | default=1024, 51 | metadata={"help": "max number of tokens per sample for LM dataset"}, 52 | ) 53 | mask_prob: float = field( 54 | default=0.15, 55 | metadata={"help": "probability of replacing a token with mask"}, 56 | ) 57 | leave_unmasked_prob: float = field( 58 | default=0.1, 59 | metadata={"help": "probability that a masked token is unmasked"}, 60 | ) 61 | random_token_prob: float = field( 62 | default=0.1, 63 | metadata={"help": "probability of replacing a token with a random token"}, 64 | ) 65 | freq_weighted_replacement: bool = field( 66 | default=False, 67 | metadata={"help": "sample random replacement words based on word frequencies"}, 68 | ) 69 | mask_whole_words: bool = field( 70 | default=False, 71 | metadata={"help": "mask whole words; you may also want to set --bpe"}, 72 | ) 73 | mask_multiple_length: int = field( 74 | default=1, 75 | metadata={"help": "repeat the mask indices multiple times"}, 76 | ) 77 | mask_stdev: float = field( 78 | default=0.0, 79 | metadata={"help": "stdev of the mask length"}, 80 | ) 81 | shorten_method: SHORTEN_METHOD_CHOICES = field( 82 | default="none", 83 | metadata={ 84 | "help": "if not none, shorten sequences that exceed --tokens-per-sample" 85 | }, 86 | ) 87 | shorten_data_split_list: str = field( 88 | default="", 89 | metadata={ 90 | "help": "comma-separated list of dataset splits to apply shortening to, " 91 | 'e.g., "train,valid" (default: all dataset splits)' 92 | }, 93 | ) 94 | seed: int = II("common.seed") 95 | span_length: float = field( 96 | default=3.0, 97 | metadata={"help": "average span length for masking"}, 98 | ) 99 | remove_source_sentinel: bool = field( 100 | default=False, 101 | metadata={"help": "remove the source sentinel for the span corruption task"}, 102 | ) 103 | remove_target_sentinel: bool = field( 104 | default=False, 105 | metadata={"help": "remove the target sentinel for the span corruption task"}, 106 | ) 107 | batch_read_ahead: int = field( 108 | default=100000, 109 | metadata={"help": "batch read ahead size for infinibatch"}, 110 | ) 111 | required_batch_size_multiple: int = II("dataset.required_batch_size_multiple") 112 | spm_model: str = field( 113 | default="", 114 | 
metadata={"help": "sentencepice model to tokenize the data"}, 115 | ) 116 | dict_file: str = field( 117 | default="", 118 | metadata={"help": ""}, 119 | ) 120 | pad_to_max_length: bool = field( 121 | default=False, 122 | ) 123 | 124 | 125 | @register_task("pretraining", dataclass=PretrainingConfig) 126 | class PLMTask(FairseqTask): 127 | def __init__(self, cfg, dictionary, tokenizer): 128 | super().__init__(cfg) 129 | self.cfg = cfg 130 | self.dictionary = dictionary 131 | self.tokenizer = tokenizer 132 | self.seed = cfg.seed 133 | self.mask_idx = dictionary.index("") 134 | 135 | @classmethod 136 | def setup_task(cls, cfg, **kwargs): 137 | paths = utils.split_paths(cfg.data) 138 | assert len(paths) > 0 139 | if cfg.dict_file != "": 140 | dictionary = Dictionary.load(cfg.dict_file) 141 | else: 142 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 143 | 144 | # add mask token 145 | dictionary.add_symbol("") 146 | for i in range(100): 147 | dictionary.add_symbol(f"") 148 | 149 | dictionary.pad_to_multiple_(cfg.required_batch_size_multiple) 150 | logger.info("dictionary: {} types".format(len(dictionary))) 151 | 152 | # tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=cfg.spm_model)) 153 | tokenizer = spm.SentencePieceProcessor() 154 | tokenizer.Load(cfg.spm_model) 155 | return cls(cfg, dictionary, tokenizer) 156 | 157 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 158 | self.datasets[split] = { 159 | "data": json.load(open(f"{self.cfg.data}/json/{split}.json")), 160 | "data_dir": self.cfg.data, 161 | "shuffle": True if split == "train" else False, 162 | } 163 | self.datasets[split] = Namespace(**self.datasets[split]) 164 | 165 | def dataset(self, split): 166 | if split not in self.datasets: 167 | raise KeyError("Dataset not loaded: " + split) 168 | 169 | return self.datasets[split] 170 | 171 | def get_batch_iterator( 172 | self, 173 | dataset, 174 | max_tokens=None, 175 | max_sentences=None, 176 | max_positions=None, 177 | ignore_invalid_inputs=False, 178 | required_batch_size_multiple=1, 179 | seed=1, 180 | num_shards=1, 181 | shard_id=0, 182 | num_workers=0, 183 | epoch=1, 184 | data_buffer_size=0, 185 | disable_iterator_cache=False, 186 | **kwargs, 187 | ): 188 | return MLMLoader( 189 | self.cfg, 190 | dataset, 191 | self.dictionary, 192 | self.tokenizer, 193 | max_tokens=max_tokens, 194 | max_sentences=max_sentences, 195 | max_positions=max_positions, 196 | ignore_invalid_inputs=ignore_invalid_inputs, 197 | required_batch_size_multiple=required_batch_size_multiple, 198 | seed=seed, 199 | num_shards=num_shards, 200 | shard_id=shard_id, 201 | ) 202 | 203 | @property 204 | def source_dictionary(self): 205 | return self.dictionary 206 | 207 | @property 208 | def target_dictionary(self): 209 | return self.dictionary 210 | -------------------------------------------------------------------------------- /examples/fairseq/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | # flake8: noqa 5 | import models 6 | import tasks 7 | import criterions 8 | from fairseq_cli.train import cli_main 9 | 10 | if __name__ == "__main__": 11 | cli_main() 12 | -------------------------------------------------------------------------------- /examples/fairseq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License 
[see LICENSE for details] 3 | -------------------------------------------------------------------------------- /examples/fairseq/utils/sparse_clip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | import warnings 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm 10 | 11 | 12 | @torch.no_grad() 13 | def clip_grad_norm_( 14 | params, max_norm, moe_expert_count, aggregate_norm_fn=None 15 | ) -> torch.Tensor: 16 | def grad_exists(p): 17 | return p is not None and getattr(p, "grad", None) is not None 18 | 19 | if isinstance(params, torch.Tensor): 20 | params = [params] 21 | params = list(params) 22 | params = list(filter(grad_exists, params)) 23 | grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], [] 24 | denom = math.sqrt(max(dist.get_global_world_size(), moe_expert_count)) 25 | for p in params: 26 | if hasattr(p, "expert"): 27 | expert_grads.append(p.grad.detach() / denom) 28 | elif hasattr(p, "base_expert"): 29 | base_expert_grads.append(p.grad.detach()) 30 | elif hasattr(p, "_is_sharded"): 31 | sharded_grads.append(p.grad.detach()) 32 | else: 33 | grads.append(p.grad.detach()) 34 | if len(grads) == 0: 35 | if len(params) > 0: 36 | total_norm = params[0].new_tensor(0.0) 37 | else: 38 | total_norm = torch.tensor(0.0) 39 | elif len(grads) == 1: 40 | total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) 41 | else: 42 | if multi_tensor_l2norm_available: 43 | total_norm = multi_tensor_total_norm(grads) 44 | else: 45 | if torch.cuda.is_available(): 46 | warnings.warn( 47 | "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " 48 | "you may get better performance by installing NVIDIA's apex library" 49 | ) 50 | device = torch.cuda.current_device() 51 | elif grads[0].device.type == "xla": 52 | device = grads[0].device 53 | else: 54 | device = torch.device("cpu") 55 | total_norm = torch.norm( 56 | torch.stack( 57 | [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] 58 | ) 59 | ) 60 | 61 | # calculate split_norm and all_reduce with other workers 62 | norms = [total_norm] 63 | for split_grads in [expert_grads, sharded_grads]: 64 | if len(split_grads) == 0: 65 | continue 66 | split_norm = torch.norm( 67 | torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]) 68 | ) 69 | if dist.is_initialized(): 70 | split_norm.pow_(2) 71 | dist.all_reduce(split_norm) 72 | split_norm.sqrt_() 73 | norms.append(split_norm) 74 | if len(norms) > 1: 75 | total_norm = torch.norm(torch.stack(norms)) 76 | 77 | if aggregate_norm_fn is not None: 78 | total_norm = aggregate_norm_fn(total_norm) 79 | 80 | if max_norm > 0: 81 | max_norm = float(max_norm) 82 | clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) 83 | for g in grads + expert_grads + sharded_grads + base_expert_grads: 84 | g.mul_(clip_coef) 85 | return total_norm 86 | -------------------------------------------------------------------------------- /examples/longvit/README.md: -------------------------------------------------------------------------------- 1 | # [(LongViT) When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology](https://arxiv.org/abs/2312.03558) 2 | 3 | **LongViT** is a vision Transformer that can process gigapixel images (e.g., 32,768x32,768 images) in an end-to-end manner. 
We split the image into millions of patches and employ [LongNet](https://arxiv.org/abs/2307.02486) to directly model the extremely long sequence. We apply LongViT in the field of computational pathology and achieve remarkable performance on cancer subtyping and survival prediction tasks. 4 | 5 | 6 | ## Setup 7 | ``` 8 | pip install -r requirements.txt 9 | pip install git+https://github.com/shumingma/fairseq.git@moe 10 | pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.20#egg=xformers 11 | ``` 12 | 13 | 14 | ## Pretraining 15 | 16 | We perform self-supervised pretraining on TCGA diagnostic slides using [DINO](https://arxiv.org/abs/2104.14294) objective. The detailed instructions can be found at [`get_started_for_tcga_pretraining.md`](get_started/get_started_for_tcga_pretraining.md). 17 | 18 | The link to the pretrained LongViT model on TCGA diagnostic slides: 19 | - [`LongViT`](https://github.com/wenhui0924/longvit_ckpts/releases/download/longvit/longvit_small_patch32_1024.pth): #layer=12; hidden=384; FFN factor=4x; #head=16; patch=32x32 20 | 21 | 22 | ## Fine-tuning on Subtyping Classification 23 | 24 | We perform finetuning on cancer subtyping on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_subtyping.md`](get_started/get_started_for_tcga_subtyping.md). 25 | 26 | 27 | ## Fine-tuning on Survival Prediction 28 | 29 | We perform finetuning on survival prediction on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_survival_prediction.md`](get_started/get_started_for_tcga_survival_prediction.md). 30 | 31 | 32 | ## Citation 33 | 34 | If you find this repository useful, please consider citing our work: 35 | ``` 36 | @article{longvit, 37 | title={When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology}, 38 | author={Wang, Wenhui and Ma, Shuming and Xu, Hanwen and Usuyama, Naoto and Ding, Jiayu and Poon, Hoifung and Wei, Furu}, 39 | journal={arXiv preprint arXiv:2312.03558}, 40 | year={2023} 41 | } 42 | 43 | @article{longnet, 44 | title={LongNet: Scaling transformers to 1,000,000,000 tokens}, 45 | author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Zheng, Nanning and Wei, Furu}, 46 | journal={arXiv preprint arXiv:2307.02486}, 47 | year={2023} 48 | } 49 | 50 | @article{torchscale, 51 | title={TorchScale: Transformers at scale}, 52 | author={Ma, Shuming and Wang, Hongyu and Huang, Shaohan and Wang, Wenhui and Chi, Zewen and Dong, Li and Benhaim, Alon and Patra, Barun and Chaudhary, Vishrav and Song, Xia and others}, 53 | journal={arXiv preprint arXiv:2211.13184}, 54 | year={2022} 55 | } 56 | ``` 57 | 58 | 59 | ## Acknowledgement 60 | 61 | This repository is built using the [BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3), the [MCAT](https://github.com/mahmoodlab/MCAT), the [DINO](https://github.com/facebookresearch/dino), the [HIPT](https://github.com/mahmoodlab/HIPT) repository and the [timm](https://github.com/rwightman/pytorch-image-models) library. 62 | 63 | 64 | ## License 65 | This project is licensed under the license found in the LICENSE file in the root directory of this source tree. 66 | 67 | [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct) 68 | 69 | ### Contact Information 70 | 71 | For help or issues using LongViT models, please submit a GitHub issue. 
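For a quick sanity check of the pretrained checkpoint listed above, the snippet below sketches how it can be loaded into the `LongViT` backbone defined in `longvit.py`. It is only a sketch: the `teacher` entry and the prefix stripping are assumptions based on the DINO-style pretraining and the `--model_key teacher` flag used by the fine-tuning scripts.

```
import torch
from longvit import LongViT

# longvit_small_patch32_1024: 12 layers, hidden size 384, 16 heads, 32x32 patches
model = LongViT(img_size=1024, patch_size=32, embed_dim=384, depth=12, num_heads=16)

ckpt = torch.load("longvit_small_patch32_1024.pth", map_location="cpu")
state_dict = ckpt.get("teacher", ckpt)  # assumption: DINO-style checkpoint with a "teacher" entry
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```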
72 | -------------------------------------------------------------------------------- /examples/longvit/data_preprocessing/cache_transformed_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import random 5 | import argparse 6 | from PIL import Image, ImageFilter, ImageOps 7 | from multiprocessing import Pool, cpu_count 8 | from timm.data.transforms import RandomResizedCropAndInterpolation 9 | import torchvision.transforms as transforms 10 | 11 | Image.MAX_IMAGE_PIXELS = 6400000000 12 | 13 | 14 | def build_transform(input_size): 15 | train_interpolation = "bicubic" 16 | t = [ 17 | RandomResizedCropAndInterpolation(input_size, scale=(0.5, 1.0), interpolation=train_interpolation), 18 | transforms.RandomHorizontalFlip(), 19 | ] 20 | t = transforms.Compose(t) 21 | 22 | return t 23 | 24 | 25 | def pil_loader(path): 26 | with open(path, "rb") as f: 27 | img = Image.open(f) 28 | return img.convert("RGB") 29 | 30 | 31 | def save_image(transformed_img, output_image_path): 32 | if isinstance(transformed_img, torch.Tensor): 33 | transformed_img = transforms.ToPILImage()(transformed_img) 34 | transformed_img.save(output_image_path) 35 | 36 | 37 | def get_image_files(input_dir): 38 | for root, _, files in os.walk(input_dir): 39 | for file in files: 40 | if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')): 41 | yield os.path.join(root, file) 42 | 43 | 44 | def transform_and_save_crops(args): 45 | input_path, input_dir, output_dir, transform = args 46 | print(input_path) 47 | file_basename = os.path.basename(input_path) 48 | 49 | img = pil_loader(input_path) 50 | transformed_img = transform(img) 51 | output_image_path = os.path.join(output_dir, file_basename) 52 | save_image(transformed_img, output_image_path) 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='Save transformed images in a directory.') 57 | parser.add_argument('input_dir', help='Path to the input directory.') 58 | parser.add_argument('output_dir', help='Path to the output directory.') 59 | parser.add_argument('-p', '--processes', type=int, default=cpu_count(), help='Number of processes to use. 
Default: number of CPU cores') 60 | parser.add_argument('--input_size', type=int, default=16384, help='input image size') 61 | args = parser.parse_args() 62 | 63 | input_dir = args.input_dir 64 | output_dir = args.output_dir 65 | num_processes = args.processes 66 | input_size = args.input_size 67 | print("num_processes: {}".format(num_processes)) 68 | print("input_size: {}".format(input_size)) 69 | 70 | transform = build_transform(input_size=input_size) 71 | 72 | image_files = list(get_image_files(input_dir)) 73 | task_args = [(file, input_dir, output_dir, transform) for file in image_files] 74 | 75 | os.makedirs(output_dir, exist_ok=True) 76 | 77 | with Pool(processes=num_processes) as pool: 78 | pool.map(transform_and_save_crops, task_args) 79 | -------------------------------------------------------------------------------- /examples/longvit/data_preprocessing/convert_wsi_to_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | import openslide 5 | 6 | from PIL import Image 7 | from concurrent.futures import ProcessPoolExecutor 8 | 9 | 10 | def convert_wsi_to_images(slide_path, image_path, target_size, level=0): 11 | slide = openslide.open_slide(slide_path) 12 | level_dims = slide.level_dimensions 13 | region = slide.read_region((0,0), level, level_dims[level]) 14 | region = region.convert("RGB") 15 | print("convert: {}({}) -> {}".format(slide_path, region.size, image_path)) 16 | resized_img = region.resize((target_size, target_size), Image.BICUBIC) 17 | resized_img.save(image_path) 18 | 19 | 20 | def process_slides(input_folder, output_folder, target_size, level=0): 21 | if not os.path.exists(output_folder): 22 | os.makedirs(output_folder) 23 | 24 | slide_paths = glob.glob(os.path.join(input_folder, "*.svs")) 25 | 26 | with ProcessPoolExecutor(max_workers=1) as executor: 27 | for slide_path in slide_paths: 28 | image_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0] + ".jpg") 29 | executor.submit(convert_wsi_to_images, slide_path, image_path, target_size, level=level) 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser(description="Convert slides into images") 34 | parser.add_argument("input_folder", type=str, help="") 35 | parser.add_argument("output_folder", type=str, help="") 36 | parser.add_argument("target_size", type=int, help="") 37 | parser.add_argument("level", type=int, help="") 38 | 39 | args = parser.parse_args() 40 | input_folder = args.input_folder 41 | output_folder = args.output_folder 42 | target_size = args.target_size 43 | level = args.level 44 | 45 | process_slides(input_folder, output_folder, target_size, level=level) 46 | -------------------------------------------------------------------------------- /examples/longvit/data_preprocessing/create_tcga_subtyping_index.py: -------------------------------------------------------------------------------- 1 | from datasets import TCGASubtypingDataset 2 | 3 | tcga_task = "tcga_brca" 4 | for k_fold in range(10): 5 | TCGASubtypingDataset.make_dataset_index( 6 | task=tcga_task, 7 | csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task), 8 | csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold), 9 | k_fold=k_fold, 10 | index_path="./subtyping_split_index/{}".format(tcga_task), 11 | ignore=['MDLC', 'PD', 'ACBC', 'IMMC', 'BRCNOS', 'BRCA', 'SPC', 'MBC', 'MPT'], 12 | label_dict = {'IDC':0, 'ILC':1}, 13 | ) 14 | 15 | 
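# Each make_dataset_index call writes the 10-fold split index json files for one
# TCGA task under index_path; label_dict maps subtype codes to class ids, and
# slides whose code is listed in `ignore` are excluded.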
tcga_task = "tcga_lung" 16 | for k_fold in range(10): 17 | TCGASubtypingDataset.make_dataset_index( 18 | task=tcga_task, 19 | csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task), 20 | csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold), 21 | k_fold=k_fold, 22 | index_path="./subtyping_split_index/{}".format(tcga_task), 23 | ignore=[], 24 | label_dict = {'LUAD':0, 'LUSC':1}, 25 | ) 26 | 27 | tcga_task = "tcga_kidney" 28 | for k_fold in range(10): 29 | TCGASubtypingDataset.make_dataset_index( 30 | task=tcga_task, 31 | csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task), 32 | csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold), 33 | k_fold=k_fold, 34 | index_path="./subtyping_split_index/{}".format(tcga_task), 35 | ignore=[], 36 | label_dict = {'CCRCC':0, 'PRCC':1, 'CHRCC':2}, 37 | ) 38 | -------------------------------------------------------------------------------- /examples/longvit/data_preprocessing/create_tcga_survival_index.py: -------------------------------------------------------------------------------- 1 | from datasets import TCGASurvivalDataset 2 | 3 | for tcga_task in ["tcga_ucec", "tcga_luad", "tcga_brca"]: 4 | for k_fold in range(5): 5 | TCGASurvivalDataset.make_dataset_index( 6 | task=tcga_task, 7 | csv_path="./survival_dataset_csv/{}_all_clean.csv.zip".format(tcga_task), 8 | csv_split_path="./survival_splits/5foldcv/{}/splits_{}.csv".format(tcga_task, k_fold), 9 | k_fold=k_fold, 10 | index_path="./survival_split_index/{}".format(tcga_task), 11 | ) -------------------------------------------------------------------------------- /examples/longvit/data_preprocessing/generate_1024_crops.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import json 5 | import numpy as np 6 | import openslide 7 | import time 8 | import torch 9 | import openslide 10 | import argparse 11 | import random 12 | import shutil 13 | 14 | import glob 15 | from concurrent.futures import ProcessPoolExecutor 16 | 17 | from PIL import Image 18 | from torchvision import transforms 19 | from torchvision.transforms import InterpolationMode 20 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 21 | 22 | 23 | def is_similar_pixel(pixel1, pixel2, threshold=30): 24 | return np.linalg.norm(pixel1 - pixel2) < threshold 25 | 26 | 27 | def should_discard_image(image_path, target_pixel=np.array([243, 243, 243]), threshold=30, similarity_ratio=0.99): 28 | image = cv2.imread(image_path) 29 | height, width, _ = image.shape 30 | 31 | similar_pixels = 0 32 | total_pixels = height * width 33 | 34 | for y in range(height): 35 | for x in range(width): 36 | pixel = image[y, x] 37 | 38 | if is_similar_pixel(pixel, target_pixel, threshold): 39 | similar_pixels += 1 40 | 41 | ratio = similar_pixels / total_pixels 42 | return ratio > similarity_ratio 43 | 44 | 45 | def random_crop(slide_path, output_path, min_crop_size, max_crop_size, level=0): 46 | slide = openslide.open_slide(slide_path) 47 | level_dim = slide.level_dimensions 48 | slide_width, slide_height = slide.dimensions 49 | 50 | crop_width = random.randint(min_crop_size, max_crop_size) 51 | crop_height = random.randint(min_crop_size, max_crop_size) 52 | 53 | x = random.randint(0, slide_width - crop_width) 54 | y = random.randint(0, slide_height - crop_height) 55 | 56 | region = 
slide.read_region((x, y), level, (crop_width, crop_height)) 57 | region = region.convert("RGB") 58 | region.save(output_path) 59 | 60 | 61 | def get_crops(slide_path, output_folder, crop_number, min_crop_size, max_crop_size): 62 | print(slide_path) 63 | 64 | index = 0 65 | while index < crop_number: 66 | output_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0], f"{str(index).zfill(8)}.JPEG") 67 | 68 | dir_path = os.path.dirname(output_path) 69 | if not os.path.exists(dir_path): 70 | os.makedirs(dir_path) 71 | 72 | random_crop(slide_path, output_path, min_crop_size, max_crop_size) 73 | if not should_discard_image(output_path): 74 | index += 1 75 | 76 | 77 | def process_slides(input_folder, output_folder, crop_number=100, min_crop_size=1024, max_crop_size=1536): 78 | if not os.path.exists(output_folder): 79 | os.makedirs(output_folder) 80 | 81 | slide_paths = glob.glob(f"{input_folder}/**/*.svs", recursive=True) 82 | 83 | with ProcessPoolExecutor(max_workers=4) as executor: 84 | for slide_path in slide_paths: 85 | executor.submit(get_crops, slide_path, output_folder, crop_number, min_crop_size, max_crop_size) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser(description="Generate crops from slides") 90 | parser.add_argument("input_folder", type=str, help="") 91 | parser.add_argument("output_folder", type=str, help="") 92 | parser.add_argument("crop_number", type=int, help="") 93 | 94 | args = parser.parse_args() 95 | input_folder = args.input_folder 96 | output_folder = args.output_folder 97 | crop_number = args.crop_number 98 | 99 | process_slides(input_folder, output_folder, crop_number=crop_number) 100 | -------------------------------------------------------------------------------- /examples/longvit/data_preprocessing/split_to_small_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import argparse 5 | from PIL import Image 6 | from concurrent.futures import ProcessPoolExecutor 7 | 8 | Image.MAX_IMAGE_PIXELS = 6400000000 9 | 10 | 11 | def split_image(image_path, input_folder, output_folder, num_splits): 12 | print(image_path) 13 | file_name, file_ext = os.path.splitext(os.path.basename(image_path)) 14 | 15 | img = Image.open(image_path) 16 | width, height = img.size 17 | 18 | block_width = width 19 | block_height = height // num_splits 20 | 21 | for i in range(num_splits): 22 | left = 0 23 | upper = i * block_height 24 | right = block_width 25 | lower = (i + 1) * block_height 26 | cropped_img = img.crop((left, upper, right, lower)) 27 | cropped_img.save(f"{output_folder}/{file_name}_{i}{file_ext}") 28 | 29 | 30 | def find_images(input_folder): 31 | image_files = [] 32 | for root, _, files in os.walk(input_folder): 33 | for f in files: 34 | if f.lower().endswith(('.png', '.jpg', '.jpeg')): 35 | image_files.append(os.path.join(root, f)) 36 | return image_files 37 | 38 | 39 | def process_images(image_files, input_folder, output_folder, num_splits, num_processes): 40 | with ProcessPoolExecutor(max_workers=num_processes) as executor: 41 | for image_file in image_files: 42 | executor.submit(split_image, image_file, input_folder, output_folder, num_splits) 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser(description='Split images into smaller tiles') 47 | parser.add_argument('--input', type=str, required=True, help='Path to the input folder containing images') 48 | parser.add_argument('--output', type=str, required=True, 
help='Path to the output folder for saving the tiles') 49 | parser.add_argument('--num_splits', type=int, default=16, help='Size of the tiles (default: 4096)') 50 | parser.add_argument('--processes', type=int, default=1, help='Number of processes (default: number of CPU cores)') 51 | args = parser.parse_args() 52 | 53 | input_folder = args.input 54 | output_folder = args.output 55 | num_splits = args.num_splits 56 | num_processes = args.processes 57 | 58 | if not os.path.exists(output_folder): 59 | os.makedirs(output_folder) 60 | 61 | image_files = find_images(input_folder) 62 | process_images(image_files, input_folder, output_folder, num_splits, num_processes) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | 68 | -------------------------------------------------------------------------------- /examples/longvit/get_started/get_started_for_tcga_pretraining.md: -------------------------------------------------------------------------------- 1 | # Pretraining LongViT on TCGA using DINO 2 | 3 | ## Setup 4 | 5 | 1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/). 6 | 7 | 2. Generate 1,024x1,024 regions from WSIs: 8 | ``` 9 | # we randomly generate 100 small regions for each whole slide image 10 | python data_preprocessing/generate_1024_crops.py /path/to/your_WSIs /path/to/your_crops 100 11 | ``` 12 | 13 | ## Pretraining LongViT 14 | 15 | Replace the `vision_transformer.py` in [DINO](https://github.com/facebookresearch/dino) with [LongViT vision_transformer.py](../pretraining/vision_transformer.py), and modify the `global crop size` to 1024 and `local crop size` to 512 to preform LongViT pretraining using DINO framework. -------------------------------------------------------------------------------- /examples/longvit/get_started/get_started_for_tcga_subtyping.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning LongViT on TCGA Subtyping 2 | 3 | ## Setup 4 | 5 | 1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/), and organize the dataset (e.g., BRCA WSIs) as following structure: 6 | 7 | ``` 8 | /path/to/your_WSIs/ 9 | TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291.svs 10 | ... 11 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9.svs 12 | ... 13 | ``` 14 | 15 | 2. Download [dataset annotation csv](https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/dataset_csv) and [splits for cross validation](https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/splits/10foldcv_subtype) from the HIPT repository. 16 | 17 | 3. Generate the index json files of each split using the following command. 18 | ``` 19 | # Modify the `csv_path` and `csv_split_path` to your path. 20 | python data_preprocessing/create_tcga_subtyping_index.py 21 | ``` 22 | 23 | 4. Resize whole slide images to the desired size for finetuning. 24 | ``` 25 | python data_preprocessing/convert_wsi_to_images.py /path/to/your_WSIs /path/to/your_resized_WSIs ${target_size} ${wsi_level} 26 | ``` 27 | 28 | 5. (Optional) For very large images (e.g., 32,768x32,768), we suggest parallelizing the training across multiple GPU devices due to the constraints of computation and memory. We split the sequence of millions of patches along the sequence dimension. 
29 | ``` 30 | # num_splits is equal to the number of GPUs you used (e.g., 8 in our experiment) 31 | python data_preprocessing/split_to_small_images.py /path/to/your_resized_WSIs /path/to/your_splited_WSIs --num_splits ${num_splits} 32 | ``` 33 | 34 | 6. (Optional) We find performing image augmentation slightly improves the performance. For very large images (e.g., 32,768x32,768), we perform the augmentation and cache the resulted images of each epoch. 35 | ``` 36 | # Run the command 10 times (number of epochs in finetuning) using i from 0-9 37 | python data_preprocessing/cache_transformed_images.py /path/to/your_resized_WSIs /path/to/your_augmentated_WSIs/epoch_$i --input_size 32768 38 | ``` 39 | 40 | Split these cached images as in step 5 and organize the data as following structure: 41 | ``` 42 | /path/to/your_splited_WSIs/ 43 | epoch_0/ 44 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_0.jpg 45 | ... 46 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_7.jpg 47 | ... 48 | epoch_1/ 49 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_0.jpg 50 | ... 51 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_7.jpg 52 | ... 53 | ... 54 | epoch_5/ 55 | ... 56 | epoch_9/ 57 | wo_augmentation/ 58 | ``` 59 | 60 | 61 | ## Example: Fine-tuning LongViT on TCGA Subtyping 62 | 63 | The LongViT model can be fine-tuned using 8 V100-32GB. For images with a size less than or equal to 16,384x16,384, we can directly perform finetuning without using sequence parallel. 64 | 65 | ```bash 66 | # IMAGE_SIZE - {1024, 4096, 8192, 16384} 67 | # TASK - {"brca", "kidney", "lung"} 68 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 69 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ 70 | --input_size ${IMAGE_SIZE} \ 71 | --model longvit_small_patch32_${IMAGE_SIZE} \ 72 | --task tcga_${TASK}_subtyping \ 73 | --batch_size 1 \ 74 | --layer_decay 1.0 \ 75 | --lr 5e-5 \ 76 | --update_freq 1 \ 77 | --epochs 10 \ 78 | --warmup_epochs 1 \ 79 | --drop_path 0.1 \ 80 | --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth 81 | --data_path ./subtyping_split_index/tcga_${TASK} \ 82 | --image_dir /path/to/your_resized_WSIs \ 83 | --output_dir /path/to/save/your_model \ 84 | --log_dir /path/to/save/your_model/log \ 85 | --weight_decay 0.05 \ 86 | --seed 42 \ 87 | --save_ckpt_freq 5 \ 88 | --k_fold ${K_FOLD} \ 89 | --num_workers 1 \ 90 | --enable_deepspeed \ 91 | --model_key teacher \ 92 | --randaug 93 | ``` 94 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). 95 | - `--randaug`: perform image augmentation. 
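For example, to fine-tune on BRCA subtyping at 8,192x8,192 for fold 0, set the placeholders above as follows (illustrative values; any combination listed in the comments works), and repeat the command with `K_FOLD` from 0 to 9 for the full cross validation:

```bash
IMAGE_SIZE=8192
TASK=brca
K_FOLD=0
```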
96 | 97 | 98 | Sequence parallel of training on 32,768x32,768 images: 99 | 100 | ```bash 101 | # TASK - {"brca", "kidney", "lung"} 102 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 103 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ 104 | --input_size 32768 \ 105 | --model longvit_small_patch32_32768 \ 106 | --task tcga_${TASK}_subtyping \ 107 | --batch_size 2 \ 108 | --layer_decay 1.0 \ 109 | --lr 5e-5 \ 110 | --update_freq 4 \ 111 | --epochs 10 \ 112 | --warmup_epochs 1 \ 113 | --drop_path 0.1 \ 114 | --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth 115 | --data_path ./subtyping_split_index/tcga_${TASK} \ 116 | --image_dir /path/to/your_splited_WSIs \ 117 | --output_dir /path/to/save/your_model \ 118 | --log_dir /path/to/save/your_model/log \ 119 | --weight_decay 0.05 \ 120 | --seed 42 \ 121 | --save_ckpt_freq 5 \ 122 | --k_fold ${K_FOLD} \ 123 | --num_workers 1 \ 124 | --enable_deepspeed \ 125 | --model_key teacher \ 126 | --seq_parallel \ 127 | --cached_randaug 128 | ``` 129 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). 130 | - `--seq_parallel`: parallelize the training for very large images. 131 | - `--cached_randaug`: perform training on the cached augmented images. 132 | 133 | 134 | ## Example: Evaluate LongViT on TCGA Subtyping 135 | 136 | ```bash 137 | # IMAGE_SIZE - {1024, 4096, 8192, 16384, 32768} 138 | # TASK - {"brca", "kidney", "lung"} 139 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 140 | python -m torch.distributed.launch --nproc_per_node=1 run_longvit_finetuning.py \ 141 | --input_size ${IMAGE_SIZE} \ 142 | --model longvit_small_patch32_${IMAGE_SIZE} \ 143 | --task tcga_${TASK}_subtyping \ 144 | --batch_size 1 \ 145 | --layer_decay 1.0 \ 146 | --lr 5e-5 \ 147 | --update_freq 1 \ 148 | --epochs 10 \ 149 | --warmup_epochs 1 \ 150 | --drop_path 0.1 \ 151 | --finetune /path/to/save/your_model/checkpoint-best/mp_rank_00_model_states.pt \ 152 | --data_path ./subtyping_split_index/tcga_${TASK} \ 153 | --image_dir /path/to/your_resized_WSIs \ 154 | --output_dir /path/to/save/your_model \ 155 | --log_dir /path/to/save/your_model/log \ 156 | --weight_decay 0.05 \ 157 | --seed 42 \ 158 | --save_ckpt_freq 5 \ 159 | --k_fold ${K_FOLD} \ 160 | --num_workers 1 \ 161 | --enable_deepspeed \ 162 | --model_key module \ 163 | --eval \ 164 | --no_auto_resume 165 | ``` 166 | - `--eval`: performing evaluation on test set. 167 | - `--finetune`: best val model used for test. 168 | 169 | For the model trained with sequence parallel, add `--seq_parallel` and use the same number of GPUs as training to perform evaluation. -------------------------------------------------------------------------------- /examples/longvit/get_started/get_started_for_tcga_survival_prediction.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning LongViT on TCGA Survival Prediction 2 | 3 | ## Setup 4 | 5 | 1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/), and organize the dataset (e.g., BRCA WSIs) as following structure: 6 | 7 | ``` 8 | /path/to/your_WSIs/ 9 | TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291.svs 10 | ... 11 | TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9.svs 12 | ... 13 | ``` 14 | 15 | 2. 
Download [dataset annotation csv](https://github.com/mahmoodlab/MCAT/tree/master/datasets_csv_sig) and [splits for cross validation](https://github.com/mahmoodlab/MCAT/tree/master/splits/5foldcv) from the MCAT repository. 16 | 17 | 3. Generate the index json files of each split using the following command. 18 | ``` 19 | # Modify the `csv_path` and `csv_split_path` to your path. 20 | python data_preprocessing/create_tcga_survival_index.py 21 | ``` 22 | 23 | 4. Resize whole slide images to the desired size for finetuning. 24 | ``` 25 | python data_preprocessing/convert_wsi_to_images.py /path/to/your_WSIs /path/to/your_resized_WSIs ${target_size} ${wsi_level} 26 | ``` 27 | 28 | 5. (Optional) For very large images (e.g., 32,768x32,768), we suggest parallelizing the training across multiple GPU devices due to the constraints of computation and memory. We split the sequence of millions of patches along the sequence dimension. 29 | ``` 30 | # num_splits is equal to the number of GPUs you used (e.g., 8 in our experiment) 31 | python data_preprocessing/split_to_small_images.py /path/to/your_resized_WSIs /path/to/your_splited_WSIs --num_splits ${num_splits} 32 | ``` 33 | 34 | 35 | ## Example: Fine-tuning LongViT on TCGA Survival Prediction 36 | 37 | The LongViT model can be fine-tuned using 8 V100-32GB. For images with a size less than or equal to 16,384x16,384, we can directly perform finetuning without using sequence parallel. 38 | 39 | ```bash 40 | # IMAGE_SIZE - {1024, 4096, 8192, 16384} 41 | # TASK - {"brca", "kidney", "lung"} 42 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 43 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ 44 | --input_size ${IMAGE_SIZE} \ 45 | --model longvit_small_patch32_${IMAGE_SIZE} \ 46 | --task tcga_${TASK}_survival \ 47 | --batch_size 1 \ 48 | --layer_decay 1.0 \ 49 | --lr 5e-5 \ 50 | --update_freq 1 \ 51 | --epochs 10 \ 52 | --warmup_epochs 1 \ 53 | --drop_path 0.1 \ 54 | --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth 55 | --data_path ./survival_split_index/tcga_${TASK} \ 56 | --image_dir /path/to/your_resized_WSIs \ 57 | --output_dir /path/to/save/your_model \ 58 | --log_dir /path/to/save/your_model/log \ 59 | --weight_decay 0.05 \ 60 | --seed 42 \ 61 | --save_ckpt_freq 5 \ 62 | --k_fold ${K_FOLD} \ 63 | --num_workers 1 \ 64 | --enable_deepspeed \ 65 | --model_key teacher \ 66 | --randaug 67 | ``` 68 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). 69 | - `--randaug`: perform image augmentation. 
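Note that the survival split index is generated with 5-fold cross validation (`create_tcga_survival_index.py` uses the `5foldcv` splits), so a full sweep for one task repeats the command above over folds 0-4, for example (illustrative wrapper):

```bash
IMAGE_SIZE=16384   # illustrative values
TASK=brca
for K_FOLD in 0 1 2 3 4; do
  echo "fine-tuning tcga_${TASK} survival, ${IMAGE_SIZE}x${IMAGE_SIZE}, fold ${K_FOLD}"
  # ...launch the torch.distributed command above with --k_fold ${K_FOLD}...
done
```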
70 | 71 | 72 | Parallelize the training of 32,768x32,768 images: 73 | 74 | ```bash 75 | # TASK - {"brca", "kidney", "lung"} 76 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 77 | python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ 78 | --input_size 32768 \ 79 | --model longvit_small_patch32_32768 \ 80 | --task tcga_${TASK}_survival \ 81 | --batch_size 2 \ 82 | --layer_decay 1.0 \ 83 | --lr 5e-5 \ 84 | --update_freq 4 \ 85 | --epochs 10 \ 86 | --warmup_epochs 1 \ 87 | --drop_path 0.1 \ 88 | --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth 89 | --data_path ./subtyping_split_index/tcga_${TASK} \ 90 | --image_dir /path/to/your_splited_WSIs \ 91 | --output_dir /path/to/save/your_model \ 92 | --log_dir /path/to/save/your_model/log \ 93 | --weight_decay 0.05 \ 94 | --seed 42 \ 95 | --save_ckpt_freq 5 \ 96 | --k_fold ${K_FOLD} \ 97 | --num_workers 1 \ 98 | --enable_deepspeed \ 99 | --model_key teacher \ 100 | --seq_parallel 101 | ``` 102 | - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). 103 | - `--seq_parallel`: parallelize the training for very large images. 104 | 105 | 106 | ## Example: Evaluate LongViT on TCGA Subtyping 107 | 108 | ```bash 109 | # IMAGE_SIZE - {1024, 4096, 8192, 16384, 32768} 110 | # TASK - {"brca", "kidney", "lung"} 111 | # K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 112 | python -m torch.distributed.launch --nproc_per_node=1 run_longvit_finetuning.py \ 113 | --input_size ${IMAGE_SIZE} \ 114 | --model longvit_small_patch32_${IMAGE_SIZE} \ 115 | --task tcga_${TASK}_survival \ 116 | --batch_size 1 \ 117 | --layer_decay 1.0 \ 118 | --lr 5e-5 \ 119 | --update_freq 1 \ 120 | --epochs 10 \ 121 | --warmup_epochs 1 \ 122 | --drop_path 0.1 \ 123 | --finetune /path/to/save/your_model/checkpoint-best/mp_rank_00_model_states.pt \ 124 | --data_path ./survival_split_index/tcga_${TASK} \ 125 | --image_dir /path/to/your_resized_WSIs \ 126 | --output_dir /path/to/save/your_model \ 127 | --log_dir /path/to/save/your_model/log \ 128 | --weight_decay 0.05 \ 129 | --seed 42 \ 130 | --save_ckpt_freq 5 \ 131 | --k_fold ${K_FOLD} \ 132 | --num_workers 1 \ 133 | --enable_deepspeed \ 134 | --model_key module \ 135 | --eval \ 136 | --no_auto_resume 137 | ``` 138 | - `--eval`: performing evaluation. 139 | - `--finetune`: best val model. 140 | 141 | For the model trained with sequence parallel, add `--seq_parallel` and use the same number of GPUs as training to perform evaluation. -------------------------------------------------------------------------------- /examples/longvit/longvit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import utils 22 | import torch 23 | import torch.nn as nn 24 | 25 | from torchscale.architecture.encoder import Encoder 26 | from torchscale.model.LongNet import LongNetEncoder 27 | from torchscale.architecture.config import EncoderConfig 28 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 29 | 30 | 31 | def trunc_normal_(tensor, mean=0., std=1.): 32 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 33 | 34 | 35 | def drop_path(x, drop_prob: float = 0., training: bool = False): 36 | if drop_prob == 0. or not training: 37 | return x 38 | keep_prob = 1 - drop_prob 39 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 40 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 41 | random_tensor.floor_() # binarize 42 | output = x.div(keep_prob) * random_tensor 43 | return output 44 | 45 | 46 | class DropPath(nn.Module): 47 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 48 | """ 49 | def __init__(self, drop_prob=None): 50 | super(DropPath, self).__init__() 51 | self.drop_prob = drop_prob 52 | 53 | def forward(self, x): 54 | return drop_path(x, self.drop_prob, self.training) 55 | 56 | 57 | class Mlp(nn.Module): 58 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 59 | super().__init__() 60 | out_features = out_features or in_features 61 | hidden_features = hidden_features or in_features 62 | self.fc1 = nn.Linear(in_features, hidden_features) 63 | self.act = act_layer() 64 | self.fc2 = nn.Linear(hidden_features, out_features) 65 | self.drop = nn.Dropout(drop) 66 | 67 | def forward(self, x): 68 | x = self.fc1(x) 69 | x = self.act(x) 70 | x = self.drop(x) 71 | x = self.fc2(x) 72 | x = self.drop(x) 73 | return x 74 | 75 | 76 | class Attention(nn.Module): 77 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 78 | super().__init__() 79 | self.num_heads = num_heads 80 | head_dim = dim // num_heads 81 | self.scale = qk_scale or head_dim ** -0.5 82 | 83 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 84 | self.attn_drop = nn.Dropout(attn_drop) 85 | self.proj = nn.Linear(dim, dim) 86 | self.proj_drop = nn.Dropout(proj_drop) 87 | 88 | def forward(self, x): 89 | B, N, C = x.shape 90 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 91 | q, k, v = qkv[0], qkv[1], qkv[2] 92 | 93 | attn = (q @ k.transpose(-2, -1)) * self.scale 94 | attn = attn.softmax(dim=-1) 95 | attn = self.attn_drop(attn) 96 | 97 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 98 | x = self.proj(x) 99 | x = self.proj_drop(x) 100 | return x, attn 101 | 102 | 103 | class Block(nn.Module): 104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 105 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 106 | super().__init__() 107 | self.norm1 = norm_layer(dim) 108 | self.attn = Attention( 109 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 110 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 111 | self.norm2 = norm_layer(dim) 112 | mlp_hidden_dim = int(dim * mlp_ratio) 113 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 114 | 115 | def forward(self, x, return_attention=False): 116 | y, attn = self.attn(self.norm1(x)) 117 | if return_attention: 118 | return attn 119 | x = x + self.drop_path(y) 120 | x = x + self.drop_path(self.mlp(self.norm2(x))) 121 | return x 122 | 123 | 124 | class PatchEmbed(nn.Module): 125 | """ Image to Patch Embedding 126 | """ 127 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 128 | super().__init__() 129 | num_patches = (img_size // patch_size) * (img_size // patch_size) 130 | self.img_size = img_size 131 | self.patch_size = patch_size 132 | self.num_patches = num_patches 133 | 134 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 135 | 136 | def forward(self, x): 137 | B, C, H, W = x.shape 138 | x = self.proj(x).flatten(2).transpose(1, 2) 139 | return x 140 | 141 | 142 | class LongViT(nn.Module): 143 | """ Vision Transformer """ 144 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 145 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 146 | drop_path_rate=0., norm_layer=nn.LayerNorm, flash_attention=True, dilated_ratio="[1,2,4,8,16]", segment_length="[64,128,256,512,1024]", checkpoint_activations=False, seq_parallel=False, **kwargs): 147 | super().__init__() 148 | self.num_features = self.embed_dim = embed_dim 149 | 150 | self.patch_embed = PatchEmbed( 151 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 152 | num_patches = self.patch_embed.num_patches 153 | 154 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 155 | self.pos_drop = nn.Dropout(p=drop_rate) 156 | 157 | if img_size == 4096: 158 | segment_length = "[1024,2048,4096,8192,16384]" 159 | elif img_size == 8192: 160 | segment_length = "[1024,4096,8192,16384,65536]" 161 | elif img_size == 16384: 162 | segment_length = "[1024,4096,16384,65536,262144]" 163 | elif img_size == 32768: 164 | segment_length = "[1024,4096,32768,262144,1048576]" 165 | 166 | self.seq_parallel = seq_parallel 167 | encoder_config = EncoderConfig( 168 | img_size=img_size, patch_size=patch_size, vocab_size=64010, multiway=False, 169 | layernorm_embedding=False, normalize_output=False, no_output_layer=True, 170 | drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim, encoder_attention_heads=num_heads, 171 | encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=depth, 172 | checkpoint_activations=checkpoint_activations, flash_attention=flash_attention, 173 | dilated_ratio=dilated_ratio, segment_length=segment_length, seq_parallel=seq_parallel, 174 | ) 175 | if flash_attention: 176 | print("Using Torchscale LoneNetEncoder") 177 | print("segment_length: {}".format(encoder_config.segment_length)) 178 | print("dilated_ratio: {}".format(encoder_config.dilated_ratio)) 179 | print("checkpoint_activations: {}".format(encoder_config.checkpoint_activations)) 180 | print("drop_path_rate: {}".format(encoder_config.drop_path_rate)) 181 | self.encoder = LongNetEncoder(encoder_config, embed_tokens=None, embed_positions=None, 182 | output_projection=None, is_encoder_decoder=False,) 183 | else: 184 | print("Using Torchscale Encoder") 185 | self.encoder = Encoder(encoder_config, embed_tokens=None, embed_positions=None, 186 | 
output_projection=None, is_encoder_decoder=False,) 187 | 188 | trunc_normal_(self.pos_embed, std=.02) 189 | self.apply(self._init_weights) 190 | 191 | def _init_weights(self, m): 192 | if isinstance(m, nn.Linear): 193 | trunc_normal_(m.weight, std=.02) 194 | if isinstance(m, nn.Linear) and m.bias is not None: 195 | nn.init.constant_(m.bias, 0) 196 | elif isinstance(m, nn.LayerNorm): 197 | nn.init.constant_(m.bias, 0) 198 | nn.init.constant_(m.weight, 1.0) 199 | 200 | def interpolate_pos_encoding(self, x, w, h): 201 | npatch = x.shape[1] 202 | N = self.pos_embed.shape[1] 203 | if npatch == N and w == h: 204 | return self.pos_embed 205 | patch_pos_embed = self.pos_embed 206 | dim = x.shape[-1] 207 | w0 = w // self.patch_embed.patch_size 208 | h0 = h // self.patch_embed.patch_size 209 | # we add a small number to avoid floating point error in the interpolation 210 | # see discussion at https://github.com/facebookresearch/dino/issues/8 211 | w0, h0 = w0 + 0.1, h0 + 0.1 212 | patch_pos_embed = nn.functional.interpolate( 213 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 214 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 215 | mode='bicubic', 216 | ) 217 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 218 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 219 | return patch_pos_embed 220 | 221 | def prepare_tokens(self, x): 222 | B, nc, w, h = x.shape 223 | x = self.patch_embed(x) # patch linear embedding 224 | 225 | # add positional encoding to each token 226 | if self.seq_parallel: 227 | rank_seq_len = x.size(1) 228 | cur_rank = utils.get_rank() 229 | start_idx = cur_rank * rank_seq_len 230 | end_idx = (cur_rank + 1) * rank_seq_len 231 | x = x + self.pos_embed[:, start_idx:end_idx, :] 232 | else: 233 | x = x + self.interpolate_pos_encoding(x, w, h) 234 | 235 | return self.pos_drop(x) 236 | 237 | def forward(self, x): 238 | x = self.prepare_tokens(x) 239 | x = self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"] 240 | return x 241 | -------------------------------------------------------------------------------- /examples/longvit/modeling_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3 4 | # Copyright (c) 2023 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # --------------------------------------------------------' 7 | 8 | import utils 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 14 | from timm.models.registry import register_model 15 | from functools import partial 16 | from longvit import LongViT 17 | from torchscale.architecture.config import EncoderConfig 18 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 19 | 20 | 21 | def _get_small_config( 22 | img_size=1024, patch_size=32, drop_path_rate=0, 23 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 24 | ): 25 | return EncoderConfig( 26 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=False, 27 | layernorm_embedding=False, normalize_output=False, no_output_layer=True, 28 | drop_path_rate=drop_path_rate, encoder_embed_dim=384, encoder_attention_heads=16, 
29 | encoder_ffn_embed_dim=int(384 * mlp_ratio), encoder_layers=12, 30 | checkpoint_activations=checkpoint_activations, 31 | ) 32 | 33 | 34 | def trunc_normal_(tensor, mean=0., std=1.): 35 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 36 | 37 | 38 | class LongViTForTCGAClassification(nn.Module): 39 | def __init__( 40 | self, 41 | args, 42 | num_classes, 43 | norm_layer=nn.LayerNorm, 44 | seq_parallel=False, 45 | **kwargs 46 | ): 47 | super().__init__() 48 | self.model = LongViT( 49 | img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, 50 | depth=args.encoder_layers, num_heads=args.encoder_attention_heads, 51 | mlp_ratio=4, drop_path_rate=args.drop_path_rate, 52 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 53 | checkpoint_activations=args.checkpoint_activations, seq_parallel=seq_parallel 54 | ) 55 | embed_dim = args.encoder_embed_dim 56 | self.depth = args.encoder_layers 57 | self.fc_norm = norm_layer(embed_dim) 58 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 59 | 60 | self.fc_norm.apply(self._init_weights) 61 | self.head.apply(self._init_weights) 62 | 63 | def _init_weights(self, m): 64 | if isinstance(m, nn.Linear): 65 | trunc_normal_(m.weight, std=.02) 66 | if isinstance(m, nn.Linear) and m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.LayerNorm): 69 | nn.init.constant_(m.bias, 0) 70 | nn.init.constant_(m.weight, 1.0) 71 | 72 | def get_num_layers(self): 73 | return self.depth 74 | 75 | @torch.jit.ignore 76 | def no_weight_decay(self): 77 | return {'model.pos_embed'} 78 | 79 | def forward(self, image, **kwargs): 80 | x = self.model(image) 81 | t = x[:, :, :] 82 | cls_x = self.fc_norm(t.mean(1)) 83 | return self.head(cls_x) 84 | 85 | 86 | class LongViTForTCGAClassificationSeqParallel(nn.Module): 87 | def __init__( 88 | self, 89 | args, 90 | num_classes, 91 | norm_layer=nn.LayerNorm, 92 | seq_parallel=False, 93 | **kwargs 94 | ): 95 | super().__init__() 96 | self.model = LongViT( 97 | img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, 98 | depth=args.encoder_layers, num_heads=args.encoder_attention_heads, 99 | mlp_ratio=4, drop_path_rate=args.drop_path_rate, 100 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 101 | checkpoint_activations=args.checkpoint_activations, seq_parallel=seq_parallel, 102 | ) 103 | embed_dim = args.encoder_embed_dim 104 | self.depth = args.encoder_layers 105 | self.fc_norm = norm_layer(embed_dim) 106 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 107 | 108 | self.fc_norm.apply(self._init_weights) 109 | self.head.apply(self._init_weights) 110 | 111 | def _init_weights(self, m): 112 | if isinstance(m, nn.Linear): 113 | trunc_normal_(m.weight, std=.02) 114 | if isinstance(m, nn.Linear) and m.bias is not None: 115 | nn.init.constant_(m.bias, 0) 116 | elif isinstance(m, nn.LayerNorm): 117 | nn.init.constant_(m.bias, 0) 118 | nn.init.constant_(m.weight, 1.0) 119 | 120 | def get_num_layers(self): 121 | return self.depth 122 | 123 | @torch.jit.ignore 124 | def no_weight_decay(self): 125 | return {'model.pos_embed'} 126 | 127 | def forward(self, image, **kwargs): 128 | x = self.model(image) 129 | t = x[:, :, :].contiguous() 130 | gatherd_t = utils.gather_tcga_features(t) 131 | cls_x = self.fc_norm(gatherd_t.mean(1)) 132 | return self.head(cls_x) 133 | 134 | 135 | @register_model 136 | def longvit_small_patch32_1024_tcga_subtyping(pretrained=False, 
task=None, **kwargs): 137 | args = _get_small_config(img_size=1024, patch_size=32, **kwargs) 138 | if task == "tcga_kidney_subtyping": 139 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) 140 | else: 141 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) 142 | return model 143 | 144 | 145 | @register_model 146 | def longvit_small_patch32_4096_tcga_subtyping(pretrained=False, task=None, **kwargs): 147 | args = _get_small_config(img_size=4096, patch_size=32, **kwargs) 148 | if task == "tcga_kidney_subtyping": 149 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) 150 | else: 151 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) 152 | return model 153 | 154 | 155 | @register_model 156 | def longvit_small_patch32_8192_tcga_subtyping(pretrained=False, task=None, **kwargs): 157 | args = _get_small_config(img_size=8192, patch_size=32, **kwargs) 158 | args.checkpoint_activations = True 159 | if task == "tcga_kidney_subtyping": 160 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) 161 | else: 162 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) 163 | return model 164 | 165 | 166 | @register_model 167 | def longvit_small_patch32_16384_tcga_subtyping(pretrained=False, task=None, **kwargs): 168 | args = _get_small_config(img_size=16384, patch_size=32, **kwargs) 169 | args.checkpoint_activations = True 170 | if task == "tcga_kidney_subtyping": 171 | model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) 172 | else: 173 | model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) 174 | return model 175 | 176 | 177 | @register_model 178 | def longvit_small_patch32_32768_tcga_subtyping(pretrained=False, task=None, seq_parallel=False, **kwargs): 179 | args = _get_small_config(img_size=32768, patch_size=32, **kwargs) 180 | args.checkpoint_activations = True 181 | if task == "tcga_kidney_subtyping": 182 | model = LongViTForTCGAClassificationSeqParallel(args, num_classes=3, seq_parallel=seq_parallel, **kwargs) 183 | else: 184 | model = LongViTForTCGAClassificationSeqParallel(args, num_classes=2, seq_parallel=seq_parallel, **kwargs) 185 | return model 186 | 187 | 188 | @register_model 189 | def longvit_small_patch32_1024_tcga_survival(pretrained=False, task=None, **kwargs): 190 | args = _get_small_config(img_size=1024, patch_size=32, **kwargs) 191 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) 192 | return model 193 | 194 | 195 | @register_model 196 | def longvit_small_patch32_4096_tcga_survival(pretrained=False, task=None, **kwargs): 197 | args = _get_small_config(img_size=4096, patch_size=32, **kwargs) 198 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) 199 | return model 200 | 201 | 202 | @register_model 203 | def longvit_small_patch32_8192_tcga_survival(pretrained=False, task=None, **kwargs): 204 | args = _get_small_config(img_size=8192, patch_size=32, **kwargs) 205 | args.checkpoint_activations = True 206 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) 207 | return model 208 | 209 | 210 | @register_model 211 | def longvit_small_patch32_16384_tcga_survival(pretrained=False, task=None, **kwargs): 212 | args = _get_small_config(img_size=16384, patch_size=32, **kwargs) 213 | args.checkpoint_activations = True 214 | model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) 215 | return model 216 | 217 | 218 | @register_model 219 | def longvit_small_patch32_32768_tcga_survival(pretrained=False, 
task=None, seq_parallel=False, **kwargs): 220 | args = _get_small_config(img_size=32768, patch_size=32, **kwargs) 221 | args.checkpoint_activations = True 222 | model = LongViTForTCGAClassificationSeqParallel(args, num_classes=4, seq_parallel=seq_parallel, **kwargs) 223 | return model 224 | -------------------------------------------------------------------------------- /examples/longvit/optim_factory.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3 4 | # Copyright (c) 2023 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # --------------------------------------------------------' 7 | 8 | from torch import optim as optim 9 | from timm.optim.lookahead import Lookahead 10 | 11 | import json 12 | 13 | 14 | def get_num_layer_for_vit(var_name, num_max_layer): 15 | if "embed" in var_name: 16 | return 0 17 | elif var_name in ( 18 | "cls_token", "mask_token", "pos_embed", "model.pos_embed", "language_pos_embed", 19 | "word_embeddings.weight", "vision_cls_token", "vision_pos_embed" 20 | ): 21 | return 0 22 | elif var_name.startswith("patch_embed"): 23 | return 0 24 | elif var_name.startswith("rel_pos_bias"): 25 | return num_max_layer - 1 26 | elif "layers." in var_name: 27 | layer_id = int(var_name.split('layers.')[1].split('.')[0]) 28 | return layer_id + 1 29 | else: 30 | return num_max_layer - 1 31 | 32 | 33 | def get_is_head_flag_for_vit(var_name, num_max_layer): 34 | if var_name.startswith("head"): 35 | return 1 36 | # elif var_name.startswith("pooler"): 37 | # return 1 38 | else: 39 | return 0 40 | 41 | 42 | class LayerDecayValueAssigner(object): 43 | def __init__(self, values, scale_handler=None): 44 | self.scale_handler = scale_handler or get_num_layer_for_vit 45 | self.values = values 46 | 47 | def get_scale(self, layer_id): 48 | return self.values[layer_id] 49 | 50 | def get_layer_id(self, var_name): 51 | return self.scale_handler(var_name, len(self.values)) 52 | 53 | 54 | # The implementation code is modified from Timm (https://github.com/huggingface/pytorch-image-models/tree/main/timm 55 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 56 | parameter_group_names = {} 57 | parameter_group_vars = {} 58 | 59 | for name, param in model.named_parameters(): 60 | if not param.requires_grad: 61 | continue # frozen weights 62 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 63 | group_name = "no_decay" 64 | this_weight_decay = 0. 65 | else: 66 | group_name = "decay" 67 | this_weight_decay = weight_decay 68 | if get_num_layer is not None: 69 | layer_id = get_num_layer(name) 70 | group_name = "layer_%d_%s" % (layer_id, group_name) 71 | else: 72 | layer_id = None 73 | 74 | if group_name not in parameter_group_names: 75 | if get_layer_scale is not None: 76 | scale = get_layer_scale(layer_id) 77 | else: 78 | scale = 1. 
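# Descriptive note on the grouping logic below (added comment, not in the upstream file):
# each group is mirrored in two dicts -- parameter_group_names keeps only the parameter
# names and is used for the json.dumps summary printed at the end, while
# parameter_group_vars keeps the actual tensors that are returned to the optimizer.
# Both carry the group's weight_decay and lr_scale, which is how layer-wise
# learning-rate decay is applied per group.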
79 | 80 | parameter_group_names[group_name] = { 81 | "weight_decay": this_weight_decay, 82 | "params": [], 83 | "lr_scale": scale 84 | } 85 | parameter_group_vars[group_name] = { 86 | "weight_decay": this_weight_decay, 87 | "params": [], 88 | "lr_scale": scale 89 | } 90 | 91 | parameter_group_vars[group_name]["params"].append(param) 92 | parameter_group_names[group_name]["params"].append(name) 93 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 94 | return list(parameter_group_vars.values()) 95 | 96 | 97 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 98 | opt_lower = args.opt.lower() 99 | weight_decay = args.weight_decay 100 | if weight_decay and filter_bias_and_bn: 101 | skip = {} 102 | if skip_list is not None: 103 | skip = skip_list 104 | elif hasattr(model, 'no_weight_decay'): 105 | skip = model.no_weight_decay() 106 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 107 | weight_decay = 0. 108 | else: 109 | parameters = model.parameters() 110 | 111 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 112 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 113 | opt_args['eps'] = args.opt_eps 114 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 115 | opt_args['betas'] = args.opt_betas 116 | 117 | opt_split = opt_lower.split('_') 118 | opt_lower = opt_split[-1] 119 | if opt_lower == 'adamw': 120 | optimizer = optim.AdamW(parameters, **opt_args) 121 | else: 122 | raise ValueError("Invalid optimizer") 123 | 124 | if len(opt_split) > 1: 125 | if opt_split[0] == 'lookahead': 126 | optimizer = Lookahead(optimizer) 127 | 128 | return optimizer 129 | -------------------------------------------------------------------------------- /examples/longvit/pretraining/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from utils import trunc_normal_ 25 | from torchscale.architecture.encoder import Encoder 26 | from torchscale.model.LongNet import LongNetEncoder 27 | from torchscale.architecture.config import EncoderConfig 28 | 29 | 30 | def drop_path(x, drop_prob: float = 0., training: bool = False): 31 | if drop_prob == 0. 
or not training: 32 | return x 33 | keep_prob = 1 - drop_prob 34 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 35 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 36 | random_tensor.floor_() # binarize 37 | output = x.div(keep_prob) * random_tensor 38 | return output 39 | 40 | 41 | class DropPath(nn.Module): 42 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 43 | """ 44 | def __init__(self, drop_prob=None): 45 | super(DropPath, self).__init__() 46 | self.drop_prob = drop_prob 47 | 48 | def forward(self, x): 49 | return drop_path(x, self.drop_prob, self.training) 50 | 51 | 52 | class Mlp(nn.Module): 53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 54 | super().__init__() 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | self.fc1 = nn.Linear(in_features, hidden_features) 58 | self.act = act_layer() 59 | self.fc2 = nn.Linear(hidden_features, out_features) 60 | self.drop = nn.Dropout(drop) 61 | 62 | def forward(self, x): 63 | x = self.fc1(x) 64 | x = self.act(x) 65 | x = self.drop(x) 66 | x = self.fc2(x) 67 | x = self.drop(x) 68 | return x 69 | 70 | 71 | class Attention(nn.Module): 72 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 73 | super().__init__() 74 | self.num_heads = num_heads 75 | head_dim = dim // num_heads 76 | self.scale = qk_scale or head_dim ** -0.5 77 | 78 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 79 | self.attn_drop = nn.Dropout(attn_drop) 80 | self.proj = nn.Linear(dim, dim) 81 | self.proj_drop = nn.Dropout(proj_drop) 82 | 83 | def forward(self, x): 84 | B, N, C = x.shape 85 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | q, k, v = qkv[0], qkv[1], qkv[2] 87 | 88 | attn = (q @ k.transpose(-2, -1)) * self.scale 89 | attn = attn.softmax(dim=-1) 90 | attn = self.attn_drop(attn) 91 | 92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 93 | x = self.proj(x) 94 | x = self.proj_drop(x) 95 | return x, attn 96 | 97 | 98 | class Block(nn.Module): 99 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 100 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 101 | super().__init__() 102 | self.norm1 = norm_layer(dim) 103 | self.attn = Attention( 104 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 105 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 106 | self.norm2 = norm_layer(dim) 107 | mlp_hidden_dim = int(dim * mlp_ratio) 108 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 109 | 110 | def forward(self, x, return_attention=False): 111 | y, attn = self.attn(self.norm1(x)) 112 | if return_attention: 113 | return attn 114 | x = x + self.drop_path(y) 115 | x = x + self.drop_path(self.mlp(self.norm2(x))) 116 | return x 117 | 118 | 119 | class PatchEmbed(nn.Module): 120 | """ Image to Patch Embedding 121 | """ 122 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 123 | super().__init__() 124 | num_patches = (img_size // patch_size) * (img_size // patch_size) 125 | self.img_size = img_size 126 | self.patch_size = patch_size 127 | self.num_patches = num_patches 128 | 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | x = self.proj(x).flatten(2).transpose(1, 2) 134 | return x 135 | 136 | 137 | class VisionTransformer(nn.Module): 138 | """ Vision Transformer """ 139 | def __init__(self, img_size=1024, patch_size=32, in_chans=3, num_classes=0, embed_dim=768, depth=12, 140 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 141 | drop_path_rate=0., norm_layer=nn.LayerNorm, flash_attention=True, dilated_ratio="[1,2,4,8,16]", segment_length="[64,128,256,512,1024]", checkpoint_activations=False, **kwargs): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim 144 | 145 | self.patch_embed = PatchEmbed( 146 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 150 | self.pos_drop = nn.Dropout(p=drop_rate) 151 | 152 | encoder_config = EncoderConfig( 153 | img_size=img_size, patch_size=patch_size, vocab_size=64010, multiway=False, 154 | layernorm_embedding=False, normalize_output=False, no_output_layer=True, 155 | drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim, encoder_attention_heads=num_heads, 156 | encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=depth, 157 | checkpoint_activations=checkpoint_activations, flash_attention=flash_attention, 158 | dilated_ratio=dilated_ratio, segment_length=segment_length, seq_parallel=False, 159 | ) 160 | if flash_attention: 161 | print("Using Torchscale LoneNetEncoder") 162 | self.encoder = LongNetEncoder(encoder_config, embed_tokens=None, embed_positions=None, 163 | output_projection=None, is_encoder_decoder=False,) 164 | else: 165 | print("Using Torchscale Encoder") 166 | self.encoder = Encoder(encoder_config, embed_tokens=None, embed_positions=None, 167 | output_projection=None, is_encoder_decoder=False,) 168 | 169 | self.norm = norm_layer(embed_dim) 170 | 171 | # Classifier head 172 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 173 | 174 | trunc_normal_(self.pos_embed, std=.02) 175 | self.apply(self._init_weights) 176 | 177 | def _init_weights(self, m): 178 | if isinstance(m, nn.Linear): 179 | trunc_normal_(m.weight, std=.02) 180 | if isinstance(m, nn.Linear) and m.bias is not None: 181 | nn.init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.LayerNorm): 183 | nn.init.constant_(m.bias, 0) 184 | nn.init.constant_(m.weight, 1.0) 185 | 186 | def interpolate_pos_encoding(self, x, w, h): 187 | npatch = x.shape[1] 188 | N = 
self.pos_embed.shape[1] 189 | if npatch == N and w == h: 190 | return self.pos_embed 191 | patch_pos_embed = self.pos_embed 192 | dim = x.shape[-1] 193 | w0 = w // self.patch_embed.patch_size 194 | h0 = h // self.patch_embed.patch_size 195 | # we add a small number to avoid floating point error in the interpolation 196 | # see discussion at https://github.com/facebookresearch/dino/issues/8 197 | w0, h0 = w0 + 0.1, h0 + 0.1 198 | patch_pos_embed = nn.functional.interpolate( 199 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 200 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 201 | mode='bicubic', 202 | ) 203 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 204 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 205 | return patch_pos_embed 206 | 207 | def prepare_tokens(self, x): 208 | B, nc, w, h = x.shape 209 | x = self.patch_embed(x) # patch linear embedding 210 | 211 | # add positional encoding to each token 212 | x = x + self.interpolate_pos_encoding(x, w, h) 213 | 214 | return self.pos_drop(x) 215 | 216 | def forward(self, x): 217 | x = self.prepare_tokens(x) 218 | x = self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"] 219 | x = self.norm(x) 220 | t = x[:, :, :] 221 | cls_x = t.mean(1) 222 | return cls_x 223 | 224 | 225 | def vit_small(patch_size=32, **kwargs): 226 | model = VisionTransformer( 227 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=16, mlp_ratio=4, 228 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 229 | return model 230 | 231 | 232 | class DINOHead(nn.Module): 233 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 234 | super().__init__() 235 | nlayers = max(nlayers, 1) 236 | if nlayers == 1: 237 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 238 | else: 239 | layers = [nn.Linear(in_dim, hidden_dim)] 240 | if use_bn: 241 | layers.append(nn.BatchNorm1d(hidden_dim)) 242 | layers.append(nn.GELU()) 243 | for _ in range(nlayers - 2): 244 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 245 | if use_bn: 246 | layers.append(nn.BatchNorm1d(hidden_dim)) 247 | layers.append(nn.GELU()) 248 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 249 | self.mlp = nn.Sequential(*layers) 250 | self.apply(self._init_weights) 251 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 252 | self.last_layer.weight_g.data.fill_(1) 253 | if norm_last_layer: 254 | self.last_layer.weight_g.requires_grad = False 255 | 256 | def _init_weights(self, m): 257 | if isinstance(m, nn.Linear): 258 | trunc_normal_(m.weight, std=.02) 259 | if isinstance(m, nn.Linear) and m.bias is not None: 260 | nn.init.constant_(m.bias, 0) 261 | 262 | def forward(self, x): 263 | x = self.mlp(x) 264 | x = nn.functional.normalize(x, dim=-1, p=2) 265 | x = self.last_layer(x) 266 | return x 267 | -------------------------------------------------------------------------------- /examples/longvit/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | timm==0.6.13 3 | Pillow==10.0.0 4 | blobfile==2.0.2 5 | mypy==1.4.1 6 | numpy==1.22.4 7 | pytest==7.2.2 8 | requests==2.31.0 9 | einops==0.6.1 10 | tensorboardX==1.8 11 | scipy==1.6.3 12 | ftfy==6.1.1 13 | opencv-python==4.8.0.74 14 | pyarrow==9.0.0 15 | transformers==4.8.1 16 | deepspeed==0.4.0 17 | scikit-survival==0.22.1 18 | 
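# Note (added comment): the flash-attention path (torchscale/component/flash_attention.py,
# enabled by default via flash_attention=True in LongViT) additionally expects either
# flash-attn (GPUs with compute capability > 7) or xformers at runtime; neither package
# is pinned in this file.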
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | from io import open 5 | 6 | from setuptools import find_packages, setup 7 | 8 | setup( 9 | name="torchscale", 10 | version="0.2.0", 11 | author="TorchScale Team", 12 | author_email="Shuming.Ma@microsoft.com", 13 | description="Transformers at any scale", 14 | long_description=open("README.md", "r", encoding="utf-8").read(), 15 | long_description_content_type="text/markdown", 16 | keywords="Transformers at any scale", 17 | license="MIT", 18 | url="https://github.com/microsoft/torchscale", 19 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 20 | install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13", "einops"], 21 | python_requires=">=3.8.0", 22 | classifiers=[ 23 | "Programming Language :: Python :: 3", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /tests/test_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import pytest 5 | import torch 6 | 7 | from torchscale.architecture.config import DecoderConfig 8 | from torchscale.architecture.decoder import Decoder 9 | 10 | testcases = [ 11 | {}, 12 | {"vocab_size": 64000}, 13 | {"activation_fn": "relu"}, 14 | {"drop_path_rate": 0.1}, 15 | {"decoder_normalize_before": False}, 16 | {"no_scale_embedding": False}, 17 | {"layernorm_embedding": True}, 18 | {"rel_pos_buckets": 32, "max_rel_pos": 256}, 19 | {"deepnorm": True, "subln": False, "decoder_normalize_before": False}, 20 | {"bert_init": True}, 21 | {"multiway": True}, 22 | {"share_decoder_input_output_embed": True}, 23 | {"checkpoint_activations": True}, 24 | {"fsdp": True}, 25 | ] 26 | 27 | 28 | @pytest.mark.parametrize("args", testcases) 29 | def test_decoder(args): 30 | config = DecoderConfig(**args) 31 | model = Decoder(config) 32 | prev_output_tokens = torch.ones(2, 10) 33 | token_embeddings = torch.rand(2, 10, config.decoder_embed_dim) 34 | model( 35 | prev_output_tokens=prev_output_tokens, 36 | token_embeddings=token_embeddings, 37 | features_only=True, 38 | ) 39 | -------------------------------------------------------------------------------- /tests/test_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import pytest 5 | import torch 6 | 7 | from torchscale.architecture.config import EncoderConfig 8 | from torchscale.architecture.encoder import Encoder 9 | 10 | testcases = [ 11 | {}, 12 | {"vocab_size": 64000}, 13 | {"activation_fn": "relu"}, 14 | {"drop_path_rate": 0.1}, 15 | {"encoder_normalize_before": False}, 16 | {"no_scale_embedding": False}, 17 | {"layernorm_embedding": True}, 18 | {"rel_pos_buckets": 32, "max_rel_pos": 256}, 19 | {"deepnorm": True, "subln": False, "encoder_normalize_before": False}, 20 | {"bert_init": True}, 21 | 
{"multiway": True}, 22 | {"share_encoder_input_output_embed": True}, 23 | {"checkpoint_activations": True}, 24 | {"fsdp": True}, 25 | ] 26 | 27 | 28 | @pytest.mark.parametrize("args", testcases) 29 | def test_encoder(args): 30 | config = EncoderConfig(**args) 31 | model = Encoder(config) 32 | token_embeddings = torch.rand(2, 10, config.encoder_embed_dim) 33 | model(src_tokens=None, token_embeddings=token_embeddings) 34 | -------------------------------------------------------------------------------- /tests/test_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import pytest 5 | import torch 6 | 7 | from torchscale.architecture.config import EncoderDecoderConfig 8 | from torchscale.architecture.encoder_decoder import EncoderDecoder 9 | from torchscale.component.embedding import PositionalEmbedding, TextEmbedding 10 | 11 | testcases = [ 12 | {}, 13 | {"vocab_size": 64000}, 14 | {"activation_fn": "relu"}, 15 | {"drop_path_rate": 0.1}, 16 | {"encoder_normalize_before": False, "decoder_normalize_before": False}, 17 | {"no_scale_embedding": False}, 18 | {"layernorm_embedding": True}, 19 | {"rel_pos_buckets": 32, "max_rel_pos": 256}, 20 | { 21 | "deepnorm": True, 22 | "subln": False, 23 | "encoder_normalize_before": False, 24 | "decoder_normalize_before": False, 25 | }, 26 | {"bert_init": True}, 27 | {"multiway": True}, 28 | {"share_decoder_input_output_embed": True}, 29 | {"share_all_embeddings": True}, 30 | {"checkpoint_activations": True}, 31 | {"fsdp": True}, 32 | ] 33 | 34 | 35 | @pytest.mark.parametrize("args", testcases) 36 | def test_decoder(args): 37 | config = EncoderDecoderConfig(**args) 38 | model = EncoderDecoder( 39 | config, 40 | encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim), 41 | decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim), 42 | encoder_embed_positions=PositionalEmbedding( 43 | config.max_source_positions, config.encoder_embed_dim 44 | ), 45 | decoder_embed_positions=PositionalEmbedding( 46 | config.max_target_positions, config.decoder_embed_dim 47 | ), 48 | ) 49 | 50 | src_tokens = torch.ones(2, 20).long() 51 | prev_output_tokens = torch.ones(2, 10).long() 52 | 53 | model( 54 | src_tokens=src_tokens, 55 | prev_output_tokens=prev_output_tokens, 56 | features_only=True, 57 | ) 58 | -------------------------------------------------------------------------------- /torchscale/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/architecture/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/architecture/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch.nn as nn 5 | 6 | from torchscale.architecture.decoder import Decoder 7 | from torchscale.architecture.encoder import Encoder 8 | 9 | 10 | class EncoderDecoder(nn.Module): 11 | def __init__( 12 | self, 13 
| args, 14 | encoder_embed_tokens=None, 15 | encoder_embed_positions=None, 16 | decoder_embed_tokens=None, 17 | decoder_embed_positions=None, 18 | output_projection=None, 19 | **kwargs 20 | ): 21 | super().__init__() 22 | self.args = args 23 | if args.share_all_embeddings: 24 | args.share_decoder_input_output_embed = True 25 | 26 | self.encoder = Encoder( 27 | args, 28 | encoder_embed_tokens, 29 | encoder_embed_positions, 30 | is_encoder_decoder=True, 31 | **kwargs 32 | ) 33 | 34 | if args.share_all_embeddings and decoder_embed_tokens is None: 35 | decoder_embed_tokens = self.encoder.embed_tokens 36 | 37 | self.decoder = Decoder( 38 | args, 39 | decoder_embed_tokens, 40 | decoder_embed_positions, 41 | output_projection, 42 | is_encoder_decoder=True, 43 | **kwargs 44 | ) 45 | 46 | def forward( 47 | self, 48 | src_tokens, 49 | prev_output_tokens, 50 | return_all_hiddens=False, 51 | features_only=False, 52 | **kwargs 53 | ): 54 | encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens) 55 | decoder_out = self.decoder( 56 | prev_output_tokens, 57 | encoder_out=encoder_out, 58 | features_only=features_only, 59 | return_all_hiddens=return_all_hiddens, 60 | ) 61 | return decoder_out 62 | -------------------------------------------------------------------------------- /torchscale/architecture/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch.nn as nn 5 | 6 | from torchscale.component.multihead_attention import MultiheadAttention 7 | from torchscale.component.multiway_network import MultiwayNetwork 8 | 9 | 10 | def init_bert_params(module): 11 | def normal_(data): 12 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 13 | 14 | if isinstance(module, nn.Linear): 15 | normal_(module.weight.data) 16 | if module.bias is not None: 17 | module.bias.data.zero_() 18 | if isinstance(module, nn.Embedding): 19 | normal_(module.weight.data) 20 | if module.padding_idx is not None: 21 | module.weight.data[module.padding_idx].zero_() 22 | if isinstance(module, MultiheadAttention): 23 | if isinstance(module.q_proj, MultiwayNetwork): 24 | normal_(module.q_proj.A.weight.data) 25 | normal_(module.q_proj.B.weight.data) 26 | normal_(module.k_proj.A.weight.data) 27 | normal_(module.k_proj.B.weight.data) 28 | normal_(module.v_proj.A.weight.data) 29 | normal_(module.v_proj.B.weight.data) 30 | else: 31 | normal_(module.q_proj.weight.data) 32 | normal_(module.k_proj.weight.data) 33 | normal_(module.v_proj.weight.data) 34 | -------------------------------------------------------------------------------- /torchscale/component/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/component/dilated_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | 10 | from .multihead_attention import MultiheadAttention 11 | from .utils import padding_to_multiple_of, all_gather_func, get_data_parallel_rank, get_data_parallel_world_size 12 | 13 | 14 | class 
DilatedAttention(MultiheadAttention): 15 | 16 | def dense_to_sparse(self, x, ratio): 17 | length = x.size(1) 18 | padding = padding_to_multiple_of(length, ratio) 19 | head_padding = padding_to_multiple_of(self.num_heads, ratio) 20 | 21 | if padding > 0 or head_padding > 0: 22 | x = F.pad(x, (0, 0, 0, head_padding, 0, padding), value = 0.) 23 | 24 | x = rearrange(x, 'b (l r1) (r2 h) d -> b l h d r1 r2', r1=ratio, r2=ratio) 25 | x = torch.diagonal(x, offset=0, dim1=4, dim2=5) 26 | x = rearrange(x, 'b l h d r -> b l (r h) d') 27 | 28 | if head_padding > 0: 29 | x = x[:, :, :self.num_heads] 30 | 31 | return x 32 | 33 | def sparse_to_dense(self, out, lse, ratio): 34 | head_padding = padding_to_multiple_of(self.num_heads, ratio) 35 | 36 | if head_padding > 0: 37 | out = F.pad(out, (0, 0, 0, head_padding), value = 0.) 38 | lse = F.pad(lse, (0, 0, 0, head_padding), value = -1e8) 39 | 40 | out = rearrange(out, 'b l (r h) d -> b l h d r', r=ratio) 41 | out = torch.diag_embed(out, offset=0, dim1=4, dim2=5) 42 | out = rearrange(out, 'b l h d r1 r2 -> b (r2 h) (l r1) d', r1=ratio, r2=ratio) 43 | 44 | lse = rearrange(lse, 'b (r h) l -> b l h r', r=ratio) 45 | lse = torch.diag_embed(lse, offset=0, dim1=3, dim2=4) 46 | lse = lse.masked_fill_(lse==0, -1e8) 47 | lse = rearrange(lse, 'b l h r1 r2 -> b (r2 h) (l r1) 1', r1=ratio, r2=ratio) 48 | 49 | if head_padding > 0: 50 | out = out[:, :self.num_heads] 51 | lse = lse[:, :self.num_heads] 52 | 53 | return out, lse 54 | 55 | def gather_kv(self, x, sl, seq_len, is_causal=True): 56 | bsz = x.size(0) 57 | assert sl % seq_len == 0 58 | num_rank_per_segment = sl // seq_len 59 | 60 | x = all_gather_func(x) 61 | current_rank = get_data_parallel_rank() 62 | x = rearrange(x, '(w b) l h d -> w b l h d', b=bsz) 63 | 64 | if is_causal: 65 | if current_rank > 0: 66 | x = x[:current_rank] 67 | else: 68 | x = x[:1] * 0 69 | 70 | current_segment = current_rank // num_rank_per_segment * num_rank_per_segment 71 | x = x[current_segment:current_segment+num_rank_per_segment] 72 | 73 | x = rearrange(x, 'w b l h d -> b (w l) h d') 74 | return x 75 | 76 | def gathering(self, x, dr, sl, is_causal=True, offset=0, is_kv=False, seq_parall=True): 77 | 78 | curr_x = x 79 | if offset > 0: 80 | curr_x = F.pad(curr_x, (0, 0, 0, 0, offset % sl, 0), value=0.) 81 | seq_len = curr_x.size(1) 82 | should_gather_kv = is_kv and (get_data_parallel_world_size() > 1) and (sl > seq_len) and seq_parall 83 | _sl = sl 84 | sl = min(sl, seq_len) 85 | padding = padding_to_multiple_of(seq_len, sl) 86 | 87 | if padding > 0: 88 | curr_x = F.pad(curr_x, (0, 0, 0, 0, 0, padding), value = 0.) 
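# Descriptive note on the steps below (added comment, not in the upstream file):
# the padded sequence is first split into segments of length sl, with the segment index
# folded into the batch dimension; dense_to_sparse then keeps every dr-th token inside
# each segment, using a different offset per head group (the diagonal over the r1/r2
# axes). Under sequence parallelism, K/V for segments longer than the local shard are
# gathered from the other ranks, and finally heads are folded into the batch dimension
# for the attention kernel.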
89 | 90 | curr_x = rearrange(curr_x, 'b (n g) h d -> (b n) g h d', g=sl) 91 | curr_x = self.dense_to_sparse(curr_x, dr) 92 | 93 | if should_gather_kv: 94 | curr_x = self.gather_kv(curr_x, _sl, seq_len, is_causal) 95 | 96 | curr_x = rearrange(curr_x, 'b l h d -> (b h) l d') 97 | 98 | return curr_x 99 | 100 | def scattering(self, outs, lses, seq_len, bsz, offset=0): 101 | assert len(outs) == len(lses) 102 | assert len(outs) % len(self.args.dilated_ratio) == 0 103 | all_outs, all_lses = [], [] 104 | drs = self.args.dilated_ratio 105 | if len(outs) > len(drs): 106 | drs = drs * (len(outs) // len(drs)) 107 | 108 | for dr, o, lse in zip(drs, outs, lses): 109 | o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads) 110 | o, lse = self.sparse_to_dense(o, lse, dr) 111 | o = rearrange(o, '(b n) h g d -> (b h) (n g) d', b=bsz) 112 | lse = rearrange(lse, '(b n) h g 1 -> (b h) (n g) 1', b=bsz) 113 | o = o[:, offset:offset+seq_len] 114 | lse = lse[:, offset:offset+seq_len] 115 | 116 | all_outs.append(o) 117 | all_lses.append(lse) 118 | 119 | with torch.no_grad(): 120 | max_lse = torch.stack(all_lses, dim=0) 121 | max_lse = max_lse.max(0)[0] 122 | all_lses = [torch.exp(lse-max_lse) for lse in all_lses] 123 | lse_sum = torch.stack(all_lses, dim=0).sum(0) 124 | all_lses = [lse / lse_sum for lse in all_lses] 125 | 126 | out = 0 127 | for o, lse in zip(all_outs, all_lses): 128 | out += o * lse.type_as(o) 129 | out = rearrange(out, '(b h) l d -> b l (h d)', h=self.num_heads) 130 | 131 | return out 132 | 133 | def forward( 134 | self, 135 | query, 136 | key, 137 | value, 138 | incremental_state=None, 139 | key_padding_mask=None, 140 | attn_mask=None, 141 | rel_pos=None, 142 | is_first_step=False, 143 | is_causal=False, 144 | ): 145 | assert self.args.flash_attention 146 | assert rel_pos is None 147 | bsz, tgt_len, embed_dim = query.size() 148 | src_len = tgt_len 149 | assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" 150 | 151 | key_bsz, src_len, _ = key.size() 152 | assert key_bsz == bsz, f"{query.size(), key.size()}" 153 | assert value is not None 154 | assert bsz, src_len == value.shape[:2] 155 | 156 | q = self.q_proj(query) 157 | k = self.k_proj(key) 158 | v = self.v_proj(value) 159 | 160 | q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads) 161 | k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads) 162 | v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads) 163 | 164 | if incremental_state is not None and not is_first_step: 165 | offset = src_len - 1 166 | else: 167 | offset = 0 168 | 169 | if incremental_state is not None: 170 | if "prev_key" in incremental_state: 171 | prev_key = incremental_state["prev_key"].view( 172 | bsz * self.num_heads, -1, self.head_dim 173 | ) 174 | prev_value = incremental_state["prev_value"].view( 175 | bsz * self.num_heads, -1, self.head_dim 176 | ) 177 | k = torch.cat([prev_key, k], dim=1) 178 | v = torch.cat([prev_value, v], dim=1) 179 | incremental_state["prev_key"] = k.view( 180 | bsz, self.num_heads, -1, self.head_dim 181 | ) 182 | incremental_state["prev_value"] = v.view( 183 | bsz, self.num_heads, -1, self.head_dim 184 | ) 185 | src_len = k.size(1) 186 | 187 | if self.xpos is not None: 188 | if incremental_state is not None and not is_first_step: 189 | offset = src_len - 1 190 | else: 191 | offset = 0 192 | k = self.xpos(k, offset=0, downscale=True) 193 | q = self.xpos(q, offset=offset, downscale=False) 194 | 195 | q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads) 196 | k = rearrange(k, '(b h) l d -> b 
l h d', h=self.num_heads) 197 | v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads) 198 | 199 | outs, lses = [], [] 200 | for sl, dr in zip(self.args.segment_length, self.args.dilated_ratio): 201 | ki = self.gathering(k, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel) 202 | vi = self.gathering(v, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel) 203 | qi = self.gathering(q, dr, sl, is_causal=is_causal, offset=offset, is_kv=False, seq_parall=self.args.seq_parallel) 204 | 205 | out, lse = self.attention_ops(qi, ki, vi, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal) 206 | 207 | outs.append(out) 208 | lses.append(lse) 209 | 210 | attn = self.scattering(outs, lses, tgt_len, bsz, offset=offset) 211 | 212 | if self.inner_attn_ln is not None: 213 | attn = self.inner_attn_ln(attn) 214 | 215 | attn = self.out_proj(attn) 216 | 217 | return attn, None 218 | -------------------------------------------------------------------------------- /torchscale/component/droppath.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch.nn as nn 5 | from timm.models.layers import drop_path 6 | 7 | 8 | class DropPath(nn.Module): 9 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 10 | 11 | def __init__(self, drop_prob=None): 12 | super(DropPath, self).__init__() 13 | self.drop_prob = drop_prob 14 | 15 | def forward(self, x): 16 | return drop_path(x, self.drop_prob, self.training) 17 | 18 | def extra_repr(self): 19 | return "p={}".format(self.drop_prob) 20 | -------------------------------------------------------------------------------- /torchscale/component/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class VisionLanguageEmbedding(nn.Module): 10 | def __init__(self, text_embed, vision_embed): 11 | super().__init__() 12 | self.text_embed = text_embed 13 | self.vision_embed = vision_embed 14 | 15 | def forward(self, textual_tokens, visual_tokens, **kwargs): 16 | if textual_tokens is None: 17 | return self.vision_embed(visual_tokens) 18 | 19 | if visual_tokens is None: 20 | return self.text_embed(textual_tokens) 21 | 22 | x1 = self.vision_embed(visual_tokens) 23 | x2 = self.text_embed(textual_tokens) 24 | 25 | return torch.cat([x1, x2], dim=1) 26 | 27 | 28 | class VisionEmbedding(nn.Module): 29 | """Image to Patch Embedding""" 30 | 31 | def __init__( 32 | self, 33 | img_size=224, 34 | patch_size=16, 35 | in_chans=3, 36 | embed_dim=768, 37 | contain_mask_token=False, 38 | prepend_cls_token=False, 39 | ): 40 | super().__init__() 41 | img_size = (img_size, img_size) 42 | patch_size = (patch_size, patch_size) 43 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 44 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 45 | self.img_size = img_size 46 | self.patch_size = patch_size 47 | self.num_patches = num_patches 48 | 49 | self.proj = nn.Conv2d( 50 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 51 | ) 52 | 53 | if contain_mask_token: 54 | self.mask_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim)) 55 | else: 56 | self.mask_token = None 57 | 58 | if prepend_cls_token: 59 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 60 | else: 61 | self.cls_token = None 62 | 63 | def num_position_embeddings(self): 64 | if self.cls_token is None: 65 | return self.num_patches 66 | else: 67 | return self.num_patches + 1 68 | 69 | def forward(self, x, masked_position=None, **kwargs): 70 | B, C, H, W = x.shape 71 | assert ( 72 | H == self.img_size[0] and W == self.img_size[1] 73 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 74 | x = self.proj(x).flatten(2).transpose(1, 2) 75 | 76 | batch_size, seq_len, _ = x.size() 77 | 78 | if masked_position is not None: 79 | assert self.mask_token is not None 80 | mask_token = self.mask_token.expand(batch_size, seq_len, -1) 81 | w = masked_position.unsqueeze(-1).type_as(mask_token) 82 | x = x * (1 - w) + mask_token * w 83 | 84 | if self.cls_token is not None: 85 | cls_tokens = self.cls_token.expand( 86 | batch_size, -1, -1 87 | ) # stole cls_tokens impl from Phil Wang, thanks 88 | x = torch.cat((cls_tokens, x), dim=1) 89 | 90 | return x 91 | 92 | 93 | class TextEmbedding(nn.Embedding): 94 | def reset_parameters(self): 95 | nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5) 96 | self._fill_padding_idx_with_zero() 97 | 98 | 99 | class PositionalEmbedding(nn.Embedding): 100 | def forward( 101 | self, 102 | x, 103 | positions=None, 104 | **kwargs, 105 | ): 106 | if positions is None: 107 | # being consistent with Fairseq, which starts from 2. 108 | positions = ( 109 | torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0) 110 | ) 111 | return F.embedding( 112 | positions, 113 | self.weight, 114 | self.padding_idx, 115 | self.max_norm, 116 | self.norm_type, 117 | self.scale_grad_by_freq, 118 | self.sparse, 119 | ) 120 | -------------------------------------------------------------------------------- /torchscale/component/feedforward_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | try: 8 | from apex.normalization import FusedLayerNorm as LayerNorm 9 | except ModuleNotFoundError: 10 | from torch.nn import LayerNorm 11 | 12 | 13 | from .xmoe.global_groups import get_moe_group 14 | 15 | 16 | class set_torch_seed(object): 17 | def __init__(self, seed): 18 | assert isinstance(seed, int) 19 | self.rng_state = self.get_rng_state() 20 | 21 | torch.manual_seed(seed) 22 | if torch.cuda.is_available(): 23 | torch.cuda.manual_seed(seed) 24 | 25 | def get_rng_state(self): 26 | state = {"torch_rng_state": torch.get_rng_state()} 27 | if torch.cuda.is_available(): 28 | state["cuda_rng_state"] = torch.cuda.get_rng_state() 29 | return state 30 | 31 | def set_rng_state(self, state): 32 | torch.set_rng_state(state["torch_rng_state"]) 33 | if torch.cuda.is_available(): 34 | torch.cuda.set_rng_state(state["cuda_rng_state"]) 35 | 36 | def __enter__(self): 37 | return self 38 | 39 | def __exit__(self, *exc): 40 | self.set_rng_state(self.rng_state) 41 | 42 | 43 | def make_experts(args, embed_dim, expert_ffn_dim): 44 | world_size = ( 45 | 1 46 | if not torch.distributed.is_initialized() 47 | else torch.distributed.get_world_size() 48 | ) 49 | expert_list = [] 50 | ddp_rank = args.ddp_rank 51 | start_seed = torch.randint(1000000, (1,)).item() 52 | # at least as many 
experts than gpus 53 | if args.moe_expert_count >= world_size: 54 | assert ( 55 | args.moe_expert_count % world_size == 0 56 | ), f"{args.moe_expert_count}, {world_size}" 57 | local_moe_expert_count = args.moe_expert_count // world_size 58 | for i in range(local_moe_expert_count): 59 | with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i): 60 | expert_list.append( 61 | FeedForwardNetwork( 62 | embed_dim, 63 | expert_ffn_dim, 64 | args.activation_fn, 65 | args.dropout, 66 | args.activation_dropout, 67 | args.layernorm_eps, 68 | args.subln, 69 | ) 70 | ) 71 | else: 72 | assert ( 73 | world_size % args.moe_expert_count == 0 74 | ), f"{world_size}, {args.moe_expert_count}" 75 | 76 | moe_idx, _ = get_moe_group(args.moe_expert_count) 77 | 78 | with set_torch_seed(start_seed + moe_idx): 79 | expert_list.append( 80 | FeedForwardNetwork( 81 | embed_dim, 82 | expert_ffn_dim, 83 | args.activation_fn, 84 | args.dropout, 85 | args.activation_dropout, 86 | args.layernorm_eps, 87 | args.subln, 88 | ) 89 | ) 90 | experts = nn.ModuleList(expert_list) 91 | return experts 92 | 93 | 94 | def get_activation_fn(activation): 95 | if activation == "relu": 96 | return F.relu 97 | elif activation == "gelu": 98 | return F.gelu 99 | elif activation == "swish": 100 | return F.silu 101 | else: 102 | raise NotImplementedError 103 | 104 | 105 | class FeedForwardNetwork(nn.Module): 106 | def __init__( 107 | self, 108 | embed_dim, 109 | ffn_dim, 110 | activation_fn, 111 | dropout, 112 | activation_dropout, 113 | layernorm_eps, 114 | subln=False, 115 | ): 116 | super().__init__() 117 | self.embed_dim = embed_dim 118 | self.activation_fn = get_activation_fn(activation=str(activation_fn)) 119 | self.activation_dropout_module = torch.nn.Dropout(activation_dropout) 120 | self.dropout_module = torch.nn.Dropout(dropout) 121 | self.fc1 = nn.Linear(self.embed_dim, ffn_dim) 122 | self.fc2 = nn.Linear(ffn_dim, self.embed_dim) 123 | self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None 124 | 125 | def reset_parameters(self): 126 | self.fc1.reset_parameters() 127 | self.fc2.reset_parameters() 128 | if self.ffn_layernorm is not None: 129 | self.ffn_layernorm.reset_parameters() 130 | 131 | def forward(self, x): 132 | x_shape = x.shape 133 | x = x.reshape(-1, x.size(-1)) 134 | x = self.fc1(x) 135 | x = self.activation_fn(x.float()).type_as(x) 136 | x = self.activation_dropout_module(x) 137 | if self.ffn_layernorm is not None: 138 | x = self.ffn_layernorm(x) 139 | x = self.fc2(x) 140 | x = x.view(x_shape) 141 | x = self.dropout_module(x) 142 | return x 143 | -------------------------------------------------------------------------------- /torchscale/component/flash_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | 5 | from typing import Any, Optional 6 | import torch 7 | 8 | if torch.cuda.is_available(): 9 | try: 10 | if torch.cuda.get_device_capability()[0] > 7: 11 | from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func 12 | 13 | def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False): 14 | assert bias is None 15 | attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True) 16 | return attn, lse 17 | 18 | else: 19 | from xformers.ops.fmha import ( 20 | cutlass, 21 | Inputs, 22 | Context, 23 | 
_memory_efficient_attention_forward_requires_grad, 24 | _memory_efficient_attention_backward, 25 | LowerTriangularMask, 26 | ) 27 | 28 | class FlashAttnFunc(torch.autograd.Function): 29 | @staticmethod 30 | # type: ignore 31 | def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False): 32 | if is_causal: 33 | assert bias is None 34 | attn_bias = LowerTriangularMask() 35 | else: 36 | attn_bias = bias 37 | 38 | inp = Inputs( 39 | query=q, 40 | key=k, 41 | value=v, 42 | attn_bias=attn_bias, 43 | p=dropout, 44 | scale=softmax_scale, 45 | ) 46 | op_fw = cutlass.FwOp 47 | op_bw = cutlass.BwOp 48 | 49 | out, op_ctx = _memory_efficient_attention_forward_requires_grad( 50 | inp=inp, op=op_fw 51 | ) 52 | 53 | # Saving attn_bias is a bit complicated, as the 54 | # torch part should go in `save_for_backward` 55 | if isinstance(inp.attn_bias, torch.Tensor): 56 | attn_bias_tensor = inp.attn_bias 57 | attn_bias_ctx = None 58 | else: 59 | attn_bias_tensor = None 60 | attn_bias_ctx = inp.attn_bias 61 | 62 | ctx.save_for_backward( 63 | inp.query, 64 | inp.key, 65 | inp.value, 66 | op_ctx.out, 67 | op_ctx.lse, 68 | ) 69 | ctx.rng_state = op_ctx.rng_state 70 | ctx.attn_bias_tensor = attn_bias_tensor 71 | if op_ctx.op_bw is not None: 72 | if op_bw is not None and op_bw is not op_ctx.op_bw: 73 | raise ValueError( 74 | f"Specified op_bw={op_bw.NAME}, but forward op " 75 | f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None." 76 | ) 77 | op_bw = op_ctx.op_bw 78 | ctx.op_fw = op_fw 79 | ctx.op_bw = op_bw 80 | ctx.p = inp.p 81 | 82 | ctx.scale = inp.scale 83 | ctx.attn_bias_ctx = attn_bias_ctx 84 | return out, op_ctx.lse 85 | 86 | @staticmethod 87 | def deserialize_bias( 88 | attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor] 89 | ) -> Any: 90 | if attn_bias_tensor is None: 91 | return attn_bias_ctx 92 | return attn_bias_tensor 93 | 94 | @classmethod 95 | @torch.autograd.function.once_differentiable 96 | def backward(cls, ctx, grad, dlse): 97 | # Re-create context 98 | query, key, value, out, lse = ctx.saved_tensors 99 | attn_bias_tensor = ctx.attn_bias_tensor 100 | rng_state = ctx.rng_state 101 | inp = Inputs( 102 | query=query, 103 | key=key, 104 | value=value, 105 | attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor), 106 | p=ctx.p, 107 | scale=ctx.scale, 108 | ) 109 | op_ctx = Context( 110 | lse=lse, 111 | out=out, 112 | rng_state=rng_state, 113 | ) 114 | grads = _memory_efficient_attention_backward( 115 | ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw 116 | ) 117 | return grads.dq, grads.dk, grads.dv, None, grads.db, None, None 118 | 119 | flash_attn_func = FlashAttnFunc.apply 120 | except ModuleNotFoundError: 121 | flash_attn_func = None 122 | else: 123 | flash_attn_func = None 124 | -------------------------------------------------------------------------------- /torchscale/component/gate_linear_unit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .feedforward_network import get_activation_fn 9 | 10 | 11 | class GLU(nn.Module): 12 | def __init__( 13 | self, 14 | embed_dim, 15 | ffn_dim, 16 | activation_fn, 17 | dropout, 18 | activation_dropout, 19 | ): 20 | super().__init__() 21 | self.embed_dim = embed_dim 22 | self.activation_fn = get_activation_fn(activation=str(activation_fn)) 23 | self.activation_dropout_module = 
torch.nn.Dropout(activation_dropout) 24 | self.dropout_module = torch.nn.Dropout(dropout) 25 | self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False) 26 | self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False) 27 | self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False) 28 | 29 | def reset_parameters(self): 30 | self.fc1.reset_parameters() 31 | self.fc2.reset_parameters() 32 | self.gate.reset_parameters() 33 | 34 | def forward(self, x): 35 | x_shape = x.shape 36 | x = x.reshape(-1, x.size(-1)) 37 | g = self.gate(x) 38 | x = self.fc1(x) 39 | x = self.activation_fn(x.float()).type_as(x) * g 40 | x = self.activation_dropout_module(x) 41 | x = self.fc2(x) 42 | x = x.view(x_shape) 43 | x = self.dropout_module(x) 44 | return x 45 | -------------------------------------------------------------------------------- /torchscale/component/multihead_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from einops import rearrange 10 | try: 11 | from apex.normalization import FusedLayerNorm as LayerNorm 12 | except ModuleNotFoundError: 13 | from torch.nn import LayerNorm 14 | 15 | from .multiway_network import MultiwayWrapper 16 | from .xpos_relative_position import XPOS 17 | from .flash_attention import flash_attn_func 18 | 19 | 20 | class MultiheadAttention(nn.Module): 21 | def __init__( 22 | self, 23 | args, 24 | embed_dim, 25 | num_heads, 26 | dropout=0.0, 27 | self_attention=False, 28 | encoder_decoder_attention=False, 29 | subln=False, 30 | ): 31 | super().__init__() 32 | self.args = args 33 | self.embed_dim = embed_dim 34 | self.num_heads = num_heads 35 | self.head_dim = embed_dim // num_heads 36 | self.scaling = self.head_dim**-0.5 37 | self.dropout = dropout 38 | 39 | self.self_attention = self_attention 40 | self.encoder_decoder_attention = encoder_decoder_attention 41 | assert self.self_attention ^ self.encoder_decoder_attention 42 | 43 | self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) 44 | self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) 45 | self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) 46 | self.out_proj = MultiwayWrapper( 47 | args, nn.Linear(embed_dim, embed_dim, bias=True) 48 | ) 49 | self.inner_attn_ln = ( 50 | MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps)) 51 | if subln and self.self_attention 52 | else None 53 | ) 54 | self.dropout_module = torch.nn.Dropout(dropout) 55 | self.xpos = ( 56 | XPOS(self.head_dim, args.xpos_scale_base) 57 | if args.xpos_rel_pos and self.self_attention 58 | else None 59 | ) 60 | 61 | def reset_parameters(self): 62 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) 63 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) 64 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) 65 | nn.init.xavier_uniform_(self.out_proj.weight) 66 | nn.init.constant_(self.out_proj.bias, 0.0) 67 | 68 | def attention_ops(self, q, k, v, key_padding_mask=None, attn_mask=None, rel_pos=None, is_causal=False): 69 | if not self.args.flash_attention: 70 | q *= self.scaling 71 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 72 | 73 | if attn_mask is not None: 74 | attn_weights = torch.nan_to_num(attn_weights) 75 | attn_mask = attn_mask.unsqueeze(0) 76 | attn_weights 
+= attn_mask 77 | 78 | if key_padding_mask is not None: 79 | attn_weights = rearrange(attn_weights, '(b h) t s -> b h t s', h=self.num_heads) 80 | attn_weights = attn_weights.masked_fill( 81 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 82 | float("-inf"), 83 | ) 84 | attn_weights = rearrange(attn_weights, 'b h t s -> (b h) t s') 85 | 86 | if rel_pos is not None: 87 | rel_pos = rel_pos.view(attn_weights.size()) 88 | attn_weights = attn_weights + rel_pos 89 | 90 | attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( 91 | attn_weights 92 | ) 93 | attn_probs = self.dropout_module(attn_weights) 94 | 95 | attn = torch.bmm(attn_probs, v) 96 | attn = rearrange(attn, '(b h) l d -> b l (h d)', h=self.num_heads) 97 | else: 98 | assert flash_attn_func is not None 99 | assert rel_pos is None 100 | q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads) 101 | k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads) 102 | v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads) 103 | attn, lse = flash_attn_func(q, k, v, self.dropout, attn_mask, None, is_causal) 104 | attn = rearrange(attn, 'b l h d -> b l (h d)') 105 | attn_weights = lse[:, :, :attn.size(1)] 106 | 107 | return attn, attn_weights 108 | 109 | def forward( 110 | self, 111 | query, 112 | key, 113 | value, 114 | incremental_state=None, 115 | key_padding_mask=None, 116 | attn_mask=None, 117 | rel_pos=None, 118 | is_first_step=False, 119 | is_causal=False, 120 | ): 121 | bsz, tgt_len, embed_dim = query.size() 122 | src_len = tgt_len 123 | assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" 124 | 125 | key_bsz, src_len, _ = key.size() 126 | assert key_bsz == bsz, f"{query.size(), key.size()}" 127 | assert value is not None 128 | assert bsz, src_len == value.shape[:2] 129 | 130 | q = self.q_proj(query) 131 | k = self.k_proj(key) 132 | v = self.v_proj(value) 133 | 134 | q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads) 135 | k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads) 136 | v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads) 137 | 138 | if incremental_state is not None: 139 | if "prev_key" in incremental_state: 140 | prev_key = incremental_state["prev_key"].view( 141 | bsz * self.num_heads, -1, self.head_dim 142 | ) 143 | prev_value = incremental_state["prev_value"].view( 144 | bsz * self.num_heads, -1, self.head_dim 145 | ) 146 | k = torch.cat([prev_key, k], dim=1) 147 | v = torch.cat([prev_value, v], dim=1) 148 | incremental_state["prev_key"] = k.view( 149 | bsz, self.num_heads, -1, self.head_dim 150 | ) 151 | incremental_state["prev_value"] = v.view( 152 | bsz, self.num_heads, -1, self.head_dim 153 | ) 154 | src_len = k.size(1) 155 | 156 | if self.xpos is not None: 157 | if incremental_state is not None and not is_first_step: 158 | offset = src_len - 1 159 | else: 160 | offset = 0 161 | k = self.xpos(k, offset=0, downscale=True) 162 | q = self.xpos(q, offset=offset, downscale=False) 163 | 164 | attn, attn_weights = self.attention_ops(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal) 165 | 166 | if self.inner_attn_ln is not None: 167 | attn = self.inner_attn_ln(attn) 168 | 169 | attn = self.out_proj(attn) 170 | 171 | return attn, attn_weights 172 | -------------------------------------------------------------------------------- /torchscale/component/multiscale_retention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 
Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from .rms_norm import RMSNorm 9 | 10 | from .multiway_network import MultiwayWrapper 11 | 12 | def rotate_every_two(x): 13 | x1 = x[:, :, :, ::2] 14 | x2 = x[:, :, :, 1::2] 15 | x = torch.stack((-x2, x1), dim=-1) 16 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ 17 | 18 | def duplicate_interleave(m): 19 | """ 20 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 21 | """ 22 | dim0 = m.shape[0] 23 | m = m.view(-1, 1) # flatten the matrix 24 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension 25 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy 26 | return m 27 | 28 | def theta_shift(x, sin, cos): 29 | return (x * cos) + (rotate_every_two(x) * sin) 30 | 31 | def get_activation_fn(activation): 32 | if activation == "swish": 33 | return F.silu 34 | elif activation == "gelu": 35 | return F.gelu 36 | else: 37 | raise NotImplementedError 38 | 39 | class MultiScaleRetention(nn.Module): 40 | def __init__( 41 | self, 42 | args, 43 | embed_dim, 44 | value_dim, 45 | num_heads, 46 | gate_fn="swish", 47 | ): 48 | super().__init__() 49 | self.args = args 50 | self.embed_dim = embed_dim 51 | self.value_dim = value_dim 52 | self.num_heads = num_heads 53 | self.head_dim = self.value_dim // num_heads 54 | self.key_dim = self.embed_dim // num_heads 55 | self.scaling = self.key_dim ** -0.5 56 | 57 | self.gate_fn = get_activation_fn(activation=str(gate_fn)) 58 | 59 | self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False)) 60 | self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False)) 61 | self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False)) 62 | self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False)) 63 | 64 | self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False)) 65 | 66 | self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False)) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5) 71 | nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5) 72 | nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) 73 | nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) 74 | nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -1) 75 | 76 | def parallel_forward(self, qr, kr, v, mask): 77 | bsz, tgt_len, embed_dim = v.size() 78 | 79 | vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) 80 | 81 | qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len 82 | qk_mat = qk_mat * mask 83 | # invariant after normalization 84 | qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4) 85 | output = torch.matmul(qk_mat, vr) 86 | output = output.transpose(1, 2) 87 | return output 88 | 89 | def recurrent_forward( 90 | self, 91 | qr, kr, v, 92 | decay, 93 | incremental_state 94 | ): 95 | bsz = v.size(0) 96 | 97 | v = v.view(bsz, self.num_heads, self.head_dim, 1) 98 | kv = kr * v 99 | if "prev_key_value" in incremental_state: 100 | prev_kv = incremental_state["prev_key_value"] 101 | prev_scale = incremental_state["scale"] 102 | scale = prev_scale * decay + 1 103 | kv = prev_kv * (prev_scale.sqrt() * decay / 
scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1) 104 | # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv 105 | else: 106 | scale = torch.ones_like(decay) 107 | 108 | incremental_state["prev_key_value"] = kv 109 | incremental_state["scale"] = scale 110 | 111 | output = torch.sum(qr * kv, dim=3) 112 | return output 113 | 114 | def chunk_recurrent_forward( 115 | self, 116 | qr, kr, v, 117 | inner_mask 118 | ): 119 | mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask 120 | bsz, tgt_len, embed_dim = v.size() 121 | chunk_len = mask.size(1) 122 | num_chunks = tgt_len // chunk_len 123 | 124 | assert tgt_len % chunk_len == 0 125 | 126 | qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2) 127 | kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2) 128 | v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3) 129 | 130 | kr_t = kr.transpose(-1, -2) 131 | 132 | qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len 133 | qk_mat = qk_mat * mask 134 | inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1) 135 | qk_mat = qk_mat / inner_scale 136 | inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim 137 | 138 | # reduce kv in one chunk 139 | kv = kr_t @ (v * value_inner_decay) 140 | 141 | kv_recurrent = [] 142 | cross_scale = [] 143 | kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v) 144 | kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v) 145 | 146 | # accumulate kv by loop 147 | for i in range(num_chunks): 148 | kv_recurrent.append(kv_state / kv_scale) 149 | cross_scale.append(kv_scale) 150 | kv_state = kv_state * cross_decay + kv[:, i] 151 | kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1) 152 | 153 | kv_recurrent = torch.stack(kv_recurrent, dim=1) 154 | cross_scale = torch.stack(cross_scale, dim=1) 155 | 156 | all_scale = torch.maximum(inner_scale, cross_scale) 157 | align_inner_scale = all_scale / inner_scale 158 | align_cross_scale = all_scale / cross_scale 159 | 160 | cross_output = (qr * query_inner_decay) @ kv_recurrent 161 | output = inner_output / align_inner_scale + cross_output / align_cross_scale 162 | # output = inner_output / cross_scale + cross_output / inner_scale 163 | 164 | output = output.transpose(2, 3) 165 | return output 166 | 167 | def forward( 168 | self, 169 | x, 170 | rel_pos, 171 | chunkwise_recurrent=False, 172 | incremental_state=None 173 | ): 174 | bsz, tgt_len, _ = x.size() 175 | (sin, cos), inner_mask = rel_pos 176 | 177 | q = self.q_proj(x) 178 | k = self.k_proj(x) 179 | v = self.v_proj(x) 180 | g = self.g_proj(x) 181 | 182 | k *= self.scaling 183 | q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) 184 | k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) 185 | 186 | qr = theta_shift(q, sin, cos) 187 | kr = theta_shift(k, sin, cos) 188 | 189 | if incremental_state is not None: 190 | output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state) 191 | elif chunkwise_recurrent: 192 | output = self.chunk_recurrent_forward(qr, kr, v, inner_mask) 193 | else: 194 | output = self.parallel_forward(qr, kr, v, inner_mask) 195 | 196 | output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads) 197 | 198 | output = self.gate_fn(g) * output 199 | 200 | output = self.out_proj(output) 201 | 
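        # `output` is now (bsz, tgt_len, embed_dim): the per-head retention
        # states were normalized head-wise by RMSNorm (group_norm), modulated
        # by the gate gate_fn(g_proj(x)) (swish by default), and projected
        # from value_dim back to embed_dim by out_proj.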
202 | return output 203 | 204 | 205 | -------------------------------------------------------------------------------- /torchscale/component/multiway_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import copy 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def MultiwayWrapper(args, module, dim=1): 11 | if args.multiway: 12 | return MultiwayNetwork(module, dim=dim) 13 | return module 14 | 15 | 16 | def set_split_position(position): 17 | def apply_fn(module): 18 | if hasattr(module, "split_position"): 19 | module.split_position = position 20 | 21 | return apply_fn 22 | 23 | 24 | class MultiwayNetwork(nn.Module): 25 | def __init__(self, module, dim=1): 26 | super().__init__() 27 | self.dim = dim 28 | self.A = module 29 | self.B = copy.deepcopy(module) 30 | self.B.reset_parameters() 31 | self.split_position = -1 32 | 33 | def forward(self, x, **kwargs): 34 | if self.split_position == -1: 35 | return self.A(x, **kwargs) 36 | if self.split_position == 0: 37 | return self.B(x, **kwargs) 38 | x1, x2 = torch.split( 39 | x, 40 | [self.split_position, x.size(self.dim) - self.split_position], 41 | dim=self.dim, 42 | ) 43 | # x1, x2 = x[:self.split_position], x[self.split_position:] 44 | y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs) 45 | return torch.cat([y1, y2], dim=self.dim) 46 | 47 | 48 | class MutliwayEmbedding(MultiwayNetwork): 49 | def __init__(self, modules, dim=1): 50 | super(MultiwayNetwork, self).__init__() 51 | self.dim = dim 52 | assert len(modules) == 2 53 | self.A = modules[0] 54 | self.B = modules[1] 55 | self.split_position = -1 -------------------------------------------------------------------------------- /torchscale/component/relative_position_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class RelativePositionBias(nn.Module): 11 | def __init__( 12 | self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12 13 | ): 14 | super().__init__() 15 | self.bidirectional = bidirectional 16 | self.num_buckets = num_buckets 17 | self.max_distance = max_distance 18 | self.n_heads = n_heads 19 | self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads) 20 | 21 | @staticmethod 22 | def _relative_position_bucket( 23 | relative_position, bidirectional=True, num_buckets=32, max_distance=128 24 | ): 25 | ret = 0 26 | n = -relative_position 27 | if bidirectional: 28 | num_buckets //= 2 29 | ret += (n < 0).to(torch.long) * num_buckets 30 | n = torch.abs(n) 31 | else: 32 | n = torch.max(n, torch.zeros_like(n)) 33 | 34 | max_exact = num_buckets // 2 35 | is_small = n < max_exact 36 | 37 | val_if_large = max_exact + ( 38 | torch.log(n.float() / max_exact) 39 | / math.log(max_distance / max_exact) 40 | * (num_buckets - max_exact) 41 | ).to(torch.long) 42 | val_if_large = torch.min( 43 | val_if_large, torch.full_like(val_if_large, num_buckets - 1) 44 | ) 45 | 46 | ret += torch.where(is_small, n, val_if_large) 47 | return ret 48 | 49 | def compute_bias(self, qlen, klen, step=None): 50 | step = 0 if step is None else step 51 | context_position = torch.arange( 52 | step, 53 | step + qlen, 54 | dtype=torch.long, 55 | device=self.relative_attention_bias.weight.device, 56 | )[:, None] 57 | 
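        # memory_position spans the key axis (0..klen-1); the (qlen, klen)
        # offsets computed below are bucketed T5-style (exact buckets for small
        # offsets, log-spaced buckets for larger ones) and looked up in
        # relative_attention_bias, giving a (1, num_heads, qlen, klen) bias.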
memory_position = torch.arange( 58 | klen, dtype=torch.long, device=self.relative_attention_bias.weight.device 59 | )[None, :] 60 | relative_position = memory_position - context_position # shape (qlen, klen) 61 | 62 | rp_bucket = self._relative_position_bucket( 63 | relative_position, # shape (qlen, klen) 64 | bidirectional=self.bidirectional, 65 | num_buckets=self.num_buckets, 66 | max_distance=self.max_distance, 67 | ) 68 | rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) 69 | values = self.relative_attention_bias( 70 | rp_bucket 71 | ) # shape (qlen, klen, num_heads) 72 | values = values.permute([2, 0, 1]).unsqueeze( 73 | 0 74 | ) # shape (1, num_heads, qlen, klen) 75 | return values 76 | 77 | def forward(self, batch_size, qlen, klen, step=None): 78 | # shape (batch * num_heads, qlen, klen) 79 | return ( 80 | self.compute_bias(qlen, klen, step) 81 | .repeat(batch_size, 1, 1, 1) 82 | .view(-1, qlen, klen) 83 | ) 84 | -------------------------------------------------------------------------------- /torchscale/component/rms_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class RMSNorm(nn.Module): 8 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): 9 | super().__init__() 10 | self.eps = eps 11 | self.elementwise_affine = elementwise_affine 12 | if self.elementwise_affine: 13 | self.weight = nn.Parameter(torch.ones(dim)) 14 | else: 15 | self.register_parameter('weight', None) 16 | 17 | def _norm(self, x): 18 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 19 | 20 | def forward(self, x): 21 | output = self._norm(x.float()).type_as(x) 22 | if self.weight is not None: 23 | output = output * self.weight 24 | return output 25 | -------------------------------------------------------------------------------- /torchscale/component/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | def padding_to_multiple_of(n, mult): 8 | remainder = n % mult 9 | if remainder == 0: 10 | return 0 11 | return mult - remainder 12 | 13 | def get_data_parallel_group(): 14 | if torch.distributed.is_initialized(): 15 | if not hasattr(get_data_parallel_group, "_global_group"): 16 | get_data_parallel_group._global_group = dist.new_group() 17 | return get_data_parallel_group._global_group 18 | else: 19 | return None 20 | 21 | def get_rank(group): 22 | return dist.get_rank(group=group) 23 | 24 | def get_world_size(group): 25 | if torch.distributed.is_initialized(): 26 | return dist.get_world_size(group=group) 27 | else: 28 | return 1 29 | 30 | def get_data_parallel_rank(): 31 | return get_rank(get_data_parallel_group()) 32 | 33 | def get_data_parallel_world_size(): 34 | return get_world_size(get_data_parallel_group()) 35 | 36 | 37 | class Allgather(torch.autograd.Function): 38 | 39 | @staticmethod 40 | def forward(ctx, input_): 41 | world_size = get_data_parallel_world_size() 42 | dim_size = list(input_.size()) 43 | dim_size[0] = dim_size[0] * world_size 44 | 45 | output = torch.empty(dim_size, dtype=input_.dtype, 46 | device=torch.cuda.current_device()) 47 | torch.distributed._all_gather_base(output, input_.contiguous(), 48 | group=get_data_parallel_group()) 49 | 50 | 
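        # forward gathers every data-parallel rank's shard into `output`
        # (the first dim grows by world_size); backward below mirrors this with
        # _reduce_scatter_base, so each rank receives the gradient slice that
        # corresponds to its own input shard.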
return output 51 | 52 | @staticmethod 53 | def backward(ctx, grad_output): 54 | world_size = get_data_parallel_world_size() 55 | 56 | dim_size = list(grad_output.size()) 57 | assert dim_size[0] % world_size == 0, \ 58 | "First dimension of the tensor should be divisible by tensor parallel size" 59 | 60 | dim_size[0] = dim_size[0] // world_size 61 | 62 | output = torch.empty(dim_size, dtype=grad_output.dtype, 63 | device=torch.cuda.current_device()) 64 | 65 | torch.distributed._reduce_scatter_base(output, grad_output.contiguous(), 66 | group=get_data_parallel_group()) 67 | 68 | return output 69 | 70 | all_gather_func = Allgather.apply 71 | -------------------------------------------------------------------------------- /torchscale/component/xmoe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | -------------------------------------------------------------------------------- /torchscale/component/xmoe/global_groups.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def _find_my_group_index(grouped_ranks): 5 | my_rank = dist.get_rank() 6 | for i, group in enumerate(grouped_ranks): 7 | if my_rank in group: 8 | return i 9 | raise RuntimeError 10 | 11 | def get_moe_group(moe_expert_count=None): 12 | if dist.is_initialized(): 13 | if not hasattr(get_moe_group, "_moe_groups"): 14 | world_size = dist.get_world_size() 15 | 16 | if world_size <= moe_expert_count: 17 | assert moe_expert_count % world_size == 0 18 | moe_groups = [[i] for i in range(world_size)] 19 | 20 | else: 21 | assert world_size % moe_expert_count == 0 22 | ranks_per_group = world_size // moe_expert_count 23 | moe_groups = [ 24 | [i + j * moe_expert_count for j in range(ranks_per_group)] 25 | for i in range(moe_expert_count) 26 | ] 27 | 28 | get_moe_group._moe_expert_count = moe_expert_count 29 | get_moe_group._moe_group_idx = moe_groups 30 | get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups] 31 | 32 | my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx) 33 | return my_group_idx, get_moe_group._moe_groups[my_group_idx] 34 | 35 | 36 | def get_all2all_group(moe_expert_count): 37 | if dist.is_initialized(): 38 | if not hasattr(get_all2all_group, "_all2all_groups"): 39 | world_size = dist.get_world_size() 40 | 41 | # more experts than world size 42 | if world_size <= moe_expert_count: 43 | assert moe_expert_count % world_size == 0 44 | all2all_groups = [[i for i in range(world_size)]] 45 | 46 | # larger world than num experts 47 | else: 48 | assert world_size % moe_expert_count == 0 49 | ranks_per_group = world_size // moe_expert_count 50 | all2all_groups = [ 51 | [i * moe_expert_count + j for j in range(moe_expert_count)] 52 | for i in range(ranks_per_group) 53 | ] 54 | 55 | get_all2all_group._all2all_group_idx = all2all_groups 56 | get_all2all_group._all2all_groups = [ 57 | dist.new_group(g) for g in all2all_groups 58 | ] 59 | 60 | my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx) 61 | return get_all2all_group._all2all_groups[my_group_idx] 62 | -------------------------------------------------------------------------------- /torchscale/component/xpos_relative_position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | 
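# XPOS is a rotary position embedding with an extra exponential length scale
# (controlled by scale_base); it is applied with downscale=True to keys and
# downscale=False to queries, so the rotary contribution to the attention score
# decays smoothly as the query-key distance grows.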
import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | def fixed_pos_embedding(x): 9 | seq_len, dim = x.shape 10 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) 11 | sinusoid_inp = ( 12 | torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) 13 | ) 14 | return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) 15 | 16 | def rotate_every_two(x): 17 | x1 = x[:, :, ::2] 18 | x2 = x[:, :, 1::2] 19 | x = torch.stack((-x2, x1), dim=-1) 20 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ 21 | 22 | def duplicate_interleave(m): 23 | """ 24 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 25 | """ 26 | dim0 = m.shape[0] 27 | m = m.view(-1, 1) # flatten the matrix 28 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension 29 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy 30 | return m 31 | 32 | def apply_rotary_pos_emb(x, sin, cos, scale=1): 33 | sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) 34 | # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) 35 | return (x * cos) + (rotate_every_two(x) * sin) 36 | 37 | 38 | class XPOS(nn.Module): 39 | def __init__( 40 | self, head_dim, scale_base=512 41 | ): 42 | super().__init__() 43 | self.head_dim = head_dim 44 | self.scale_base = scale_base 45 | self.register_buffer( 46 | "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) 47 | ) 48 | 49 | def forward(self, x, offset=0, downscale=False): 50 | length = x.shape[1] 51 | min_pos = -(length + offset) // 2 52 | max_pos = length + offset + min_pos 53 | scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] 54 | sin, cos = fixed_pos_embedding(scale) 55 | 56 | if scale.shape[0] > length: 57 | scale = scale[-length:] 58 | sin = sin[-length:] 59 | cos = cos[-length:] 60 | 61 | if downscale: 62 | scale = 1 / scale 63 | 64 | x = apply_rotary_pos_emb(x, sin, cos, scale) 65 | return x 66 | -------------------------------------------------------------------------------- /torchscale/model/BEiT3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchscale.architecture.encoder import Encoder 8 | from torchscale.component.embedding import ( 9 | PositionalEmbedding, 10 | TextEmbedding, 11 | VisionEmbedding, 12 | ) 13 | from torchscale.component.multiway_network import MutliwayEmbedding 14 | 15 | 16 | class BEiT3(nn.Module): 17 | def __init__(self, args, **kwargs): 18 | super().__init__() 19 | self.args = args 20 | assert args.multiway 21 | assert args.vocab_size > 0 22 | assert not args.share_encoder_input_output_embed 23 | self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim) 24 | self.vision_embed = VisionEmbedding( 25 | args.img_size, 26 | args.patch_size, 27 | args.in_chans, 28 | args.encoder_embed_dim, 29 | contain_mask_token=True, 30 | prepend_cls_token=True, 31 | ) 32 | # being consistent with Fairseq, which starts from 2 for position embedding 33 | embed_positions = MutliwayEmbedding( 34 | modules=[ 35 | PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim), 36 | PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim), 37 | ], 38 | dim=1, 39 | 
) 40 | self.encoder = Encoder( 41 | args, 42 | embed_tokens=None, 43 | embed_positions=embed_positions, 44 | output_projection=None, 45 | is_encoder_decoder=False, 46 | ) 47 | 48 | def forward( 49 | self, 50 | textual_tokens=None, 51 | visual_tokens=None, 52 | text_padding_position=None, 53 | attn_mask=None, 54 | vision_masked_position=None, 55 | incremental_state=None, 56 | positions=None, 57 | ): 58 | assert textual_tokens is not None or visual_tokens is not None 59 | 60 | if textual_tokens is None: 61 | x = self.vision_embed(visual_tokens, vision_masked_position) 62 | encoder_padding_mask = None 63 | multiway_split_position = -1 64 | elif visual_tokens is None: 65 | x = self.text_embed(textual_tokens) 66 | encoder_padding_mask = text_padding_position 67 | multiway_split_position = 0 68 | else: 69 | x1 = self.vision_embed(visual_tokens, vision_masked_position) 70 | multiway_split_position = x1.size(1) 71 | x2 = self.text_embed(textual_tokens) 72 | x = torch.cat([x1, x2], dim=1) 73 | 74 | if text_padding_position is not None: 75 | encoder_padding_mask = torch.cat( 76 | [ 77 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(), 78 | text_padding_position, 79 | ], 80 | dim=1, 81 | ) 82 | else: 83 | encoder_padding_mask = None 84 | 85 | encoder_out = self.encoder( 86 | src_tokens=None, 87 | encoder_padding_mask=encoder_padding_mask, 88 | attn_mask=attn_mask, 89 | token_embeddings=x, 90 | multiway_split_position=multiway_split_position, 91 | incremental_state=incremental_state, 92 | positions=positions, 93 | ) 94 | encoder_out["multiway_split_position"] = multiway_split_position 95 | 96 | return encoder_out 97 | -------------------------------------------------------------------------------- /torchscale/model/LongNet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | from torchscale.architecture.decoder import Decoder, DecoderLayer 5 | from torchscale.architecture.encoder import Encoder, EncoderLayer 6 | from torchscale.component.dilated_attention import DilatedAttention 7 | from fairscale.nn import checkpoint_wrapper, wrap 8 | 9 | 10 | class LongNetDecoderLayer(DecoderLayer): 11 | 12 | def build_self_attention(self, embed_dim, args): 13 | return DilatedAttention( 14 | args, 15 | embed_dim, 16 | args.decoder_attention_heads, 17 | dropout=args.attention_dropout, 18 | self_attention=True, 19 | encoder_decoder_attention=False, 20 | subln=args.subln, 21 | ) 22 | 23 | class LongNetDecoder(Decoder): 24 | 25 | def build_decoder_layer( 26 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False 27 | ): 28 | layer = LongNetDecoderLayer( 29 | args, 30 | depth, 31 | is_moe_layer=is_moe_layer, 32 | is_encoder_decoder=is_encoder_decoder, 33 | ) 34 | if args.checkpoint_activations: 35 | layer = checkpoint_wrapper(layer) 36 | if args.fsdp: 37 | layer = wrap(layer) 38 | return layer 39 | 40 | class LongNetEncoderLayer(EncoderLayer): 41 | 42 | def build_self_attention(self, embed_dim, args): 43 | return DilatedAttention( 44 | args, 45 | embed_dim, 46 | args.encoder_attention_heads, 47 | dropout=args.attention_dropout, 48 | self_attention=True, 49 | encoder_decoder_attention=False, 50 | subln=args.subln, 51 | ) 52 | 53 | class LongNetEncoder(Encoder): 54 | 55 | def build_encoder_layer( 56 | self, args, depth, is_moe_layer=False, is_encoder_decoder=False 57 | ): 58 | layer = LongNetEncoderLayer( 59 | args, 60 | depth, 61 | is_moe_layer=is_moe_layer, 62 | 
is_encoder_decoder=is_encoder_decoder, 63 | ) 64 | if args.checkpoint_activations: 65 | layer = checkpoint_wrapper(layer) 66 | if args.fsdp: 67 | layer = wrap(layer) 68 | return layer 69 | -------------------------------------------------------------------------------- /torchscale/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | --------------------------------------------------------------------------------
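Usage sketch (illustrative only, not part of the repository): one way to exercise the MultiheadAttention module defined above in isolation. The SimpleNamespace below is an assumption for illustration — it carries only the config attributes this module actually reads (multiway, flash_attention, xpos_rel_pos, layernorm_eps); the real configs in torchscale/architecture/config.py define many more fields. The constructor asserts self_attention XOR encoder_decoder_attention, so exactly one of them must be True.

from types import SimpleNamespace

import torch

from torchscale.component.multihead_attention import MultiheadAttention

# hypothetical minimal config: only the fields MultiheadAttention reads
args = SimpleNamespace(
    multiway=False,         # MultiwayWrapper then returns the plain module
    flash_attention=False,  # take the bmm/softmax path in attention_ops
    xpos_rel_pos=False,     # skip XPOS rotary scaling
    layernorm_eps=1e-5,     # only read when subln=True; kept for completeness
)

attn = MultiheadAttention(args, embed_dim=64, num_heads=4, self_attention=True)
x = torch.randn(2, 16, 64)                       # (batch, seq_len, embed_dim)
out, attn_weights = attn(query=x, key=x, value=x)
print(out.shape)                                 # torch.Size([2, 16, 64])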