├── .circleci └── config.yml ├── .coveragerc ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ └── questions-help-support.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── pre-commit.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── RELEASE.md ├── benchmarks ├── __init__.py ├── datasets │ ├── __init__.py │ ├── mnist.py │ └── wikitext2_data.py ├── experimental │ ├── benchmark_dataset.py │ ├── benchmark_mevo.py │ ├── experimental_async_approaches.py │ ├── offload.py │ └── sync_batchnorm.py ├── fsdp.py ├── golden_configs │ ├── __init__.py │ ├── lm_wikitext2.py │ └── oss_mnist.py ├── models │ ├── __init__.py │ └── transformer_lm.py ├── moe.py ├── oss.py ├── pipe.py └── utils.py ├── codecov.yml ├── docs ├── Makefile ├── requirements.txt └── source │ ├── _static │ ├── css │ │ └── customize.css │ └── img │ │ ├── ddp.png │ │ ├── fairscale-logo.png │ │ ├── flowchart.png │ │ ├── fsdp.png │ │ ├── global.png │ │ ├── offload.png │ │ ├── oss.png │ │ ├── pipe.png │ │ └── sdp.png │ ├── _templates │ ├── layout.html │ └── theme_variables.jinja │ ├── api │ ├── experimental │ │ └── nn │ │ │ ├── offload_model.rst │ │ │ └── slowmo_ddp.rst │ ├── index.rst │ ├── nn │ │ ├── checkpoint │ │ │ └── checkpoint_activations.rst │ │ ├── fsdp.rst │ │ ├── moe.rst │ │ ├── pipe.rst │ │ └── sharded_ddp.rst │ └── optim │ │ ├── adascale.rst │ │ └── oss.rst │ ├── blogs_and_press.rst │ ├── conf.py │ ├── deep_dive │ ├── activation_checkpointing.rst │ ├── adascale.rst │ ├── offload.rst │ ├── oss_sdp_fsdp.rst │ ├── pipeline_parallelism.rst │ └── slowmo_ddp.rst │ ├── getting_involved.rst │ ├── getting_started.rst │ ├── index.rst │ ├── installation_instructions.rst │ ├── integrations.rst │ ├── tutorials │ ├── _static │ │ └── img │ │ │ ├── all_gathered_memory.png │ │ │ ├── layer_memory_activations.png │ │ │ ├── layer_memory_parameters.png │ │ │ ├── layer_memory_profile_optimized.png │ │ │ └── layer_memory_profiles.png │ ├── activation_checkpointing.rst │ ├── adascale.rst │ ├── layer_memory_tracking.rst │ ├── offload_model.rst │ ├── oss.rst │ ├── pipe.rst │ └── slowmo_ddp.rst │ └── what_is_fairscale.rst ├── fairscale ├── README.md ├── __init__.py ├── clib │ └── fused_adam_cuda │ │ ├── compat.h │ │ ├── fused_adam_cuda.cpp │ │ ├── fused_adam_cuda_kernel.cu │ │ └── multi_tensor_apply.cuh ├── experimental │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── ampnet_pipe │ │ │ ├── __init__.py │ │ │ ├── ampnet.py │ │ │ └── pipe.py │ │ ├── auto_shard.py │ │ ├── data_parallel │ │ │ ├── __init__.py │ │ │ └── gossip │ │ │ │ ├── __init__.py │ │ │ │ ├── distributed.py │ │ │ │ ├── gossiper.py │ │ │ │ ├── graph_manager.py │ │ │ │ ├── mixing_manager.py │ │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── cuda_metering.py │ │ │ │ └── helpers.py │ │ ├── distributed_pipeline │ │ │ ├── __init__.py │ │ │ ├── data.py │ │ │ ├── graph.py │ │ │ ├── loss.py │ │ │ ├── partition_handler.py │ │ │ ├── pipeline.py │ │ │ └── trace.py │ │ ├── mevo.py │ │ ├── offload.py │ │ └── sync_batchnorm.py │ ├── optim │ │ ├── __init__.py │ │ └── dynamic_loss_scaler.py │ ├── tooling │ │ ├── __init__.py │ │ └── layer_memory_tracker.py │ └── wgit │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── cli.py │ │ ├── pygit.py │ │ ├── repo.py │ │ ├── sha1_store.py │ │ ├── signal_sparsity.py │ │ ├── signal_sparsity_profiling.py │ │ ├── utils.py │ │ └── version.py ├── fair_dev │ ├── __init__.py 
│ ├── common_paths.py │ └── testing │ │ ├── __init__.py │ │ ├── golden_testing_data.py │ │ ├── testing.py │ │ └── testing_memory.py ├── internal │ ├── __init__.py │ ├── containers.py │ ├── object.py │ ├── parallel.py │ ├── params.py │ ├── reduce_scatter_bucketer.py │ ├── state_dict.py │ └── version.py ├── nn │ ├── __init__.py │ ├── checkpoint │ │ ├── __init__.py │ │ ├── checkpoint_activations.py │ │ └── checkpoint_utils.py │ ├── data_parallel │ │ ├── __init__.py │ │ ├── fsdp_optim_utils.py │ │ ├── fully_sharded_data_parallel.py │ │ └── sharded_ddp.py │ ├── misc │ │ ├── __init__.py │ │ ├── flatten_params_wrapper.py │ │ └── param_bucket.py │ ├── model_parallel │ │ ├── __init__.py │ │ ├── cross_entropy.py │ │ ├── initialize.py │ │ ├── layers.py │ │ ├── mappings.py │ │ ├── random.py │ │ └── utils.py │ ├── moe │ │ ├── __init__.py │ │ ├── moe_layer.py │ │ └── top2gate.py │ ├── pipe │ │ ├── __init__.py │ │ ├── async_pipe.py │ │ ├── async_pipeline.py │ │ ├── async_schedule.py │ │ ├── balance │ │ │ ├── __init__.py │ │ │ ├── blockpartition.py │ │ │ ├── profile.py │ │ │ └── py.typed │ │ ├── batchnorm.py │ │ ├── checkpoint.py │ │ ├── copy.py │ │ ├── dependency.py │ │ ├── messages.py │ │ ├── microbatch.py │ │ ├── phony.py │ │ ├── pipe.py │ │ ├── pipeline.py │ │ ├── py.typed │ │ ├── rpc.py │ │ ├── skip │ │ │ ├── __init__.py │ │ │ ├── layout.py │ │ │ ├── namespace.py │ │ │ ├── portal.py │ │ │ ├── skippable.py │ │ │ └── tracker.py │ │ ├── stream.py │ │ ├── types.py │ │ └── worker.py │ └── wrap │ │ ├── __init__.py │ │ └── auto_wrap.py ├── optim │ ├── __init__.py │ ├── adam.py │ ├── adascale.py │ ├── grad_scaler.py │ ├── layerwise_gradient_scaler.py │ └── oss.py └── version.py ├── pyproject.toml ├── release_utils.py ├── requirements-benchmarks.txt ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── stubs └── torch │ ├── __init__.pyi │ ├── autograd │ ├── __init__.pyi │ ├── grad_mode.pyi │ └── profiler.pyi │ ├── backends │ ├── __init__.pyi │ └── cudnn.pyi │ ├── cuda │ ├── __init__.pyi │ ├── amp │ │ ├── __init__.pyi │ │ └── grad_scaler.pyi │ └── comm │ │ └── __init__.pyi │ ├── distributed │ ├── __init__.pyi │ ├── distributed_c10d.pyi │ ├── nn │ │ └── functional.pyi │ └── rpc │ │ └── __init__.pyi │ ├── fft │ └── __init__.pyi │ ├── functional.pyi │ ├── futures.pyi │ ├── jit.pyi │ ├── multiprocessing │ └── __init__.pyi │ ├── nn │ ├── __init__.pyi │ ├── common_types.pyi │ ├── functional.pyi │ ├── modules │ │ ├── __init__.pyi │ │ ├── activation.pyi │ │ ├── adaptive.pyi │ │ ├── batchnorm.pyi │ │ ├── container.pyi │ │ ├── conv.pyi │ │ ├── distance.pyi │ │ ├── dropout.pyi │ │ ├── flatten.pyi │ │ ├── fold.pyi │ │ ├── instancenorm.pyi │ │ ├── linear.pyi │ │ ├── loss.pyi │ │ ├── module.pyi │ │ ├── normalization.pyi │ │ ├── padding.pyi │ │ ├── pixelshuffle.pyi │ │ ├── pooling.pyi │ │ ├── rnn.pyi │ │ ├── sparse.pyi │ │ └── upsampling.pyi │ ├── parallel │ │ ├── __init__.pyi │ │ ├── common_types.pyi │ │ ├── data_parallel.pyi │ │ ├── distributed.pyi │ │ ├── parallel_apply.pyi │ │ ├── replicate.pyi │ │ └── scatter_gather.pyi │ └── parameter.pyi │ ├── optim │ ├── __init__.pyi │ ├── adam.pyi │ ├── lr_scheduler.pyi │ ├── optimizer.pyi │ └── sgd.pyi │ ├── random.pyi │ ├── serialization.pyi │ ├── testing │ └── __init__.pyi │ ├── utils │ ├── __init__.pyi │ ├── checkpoint.pyi │ └── data │ │ ├── __init__.pyi │ │ ├── dataloader.pyi │ │ ├── dataset.pyi │ │ ├── distributed.pyi │ │ └── sampler.pyi │ └── version.pyi └── tests ├── __init__.py ├── ci_test_list_1.txt ├── ci_test_list_2.txt ├── ci_test_list_3.txt ├── 
ci_test_list_check.sh ├── experimental ├── __init__.py ├── nn │ ├── __init__.py │ ├── ampnet_pipe_process │ │ ├── __init__.py │ │ └── test_ampnet_pipe.py │ ├── data_parallel │ │ └── test_gossip.py │ ├── test_auto_shard.py │ ├── test_mevo.py │ ├── test_multiprocess_pipe.py │ ├── test_offload.py │ └── test_sync_batchnorm.py ├── optim │ ├── __init__.py │ └── test_dynamic_loss_scaler.py ├── tooling │ ├── __init__.py │ └── test_layer_memory_tracker.py └── wgit │ ├── __init__.py │ ├── test_api.py │ ├── test_cli.py │ ├── test_pygit.py │ ├── test_sha1_store.py │ ├── test_signal_sparsity.py │ └── test_signal_sparsity_profiling.py ├── nn ├── __init__.py ├── checkpoint │ ├── __init__.py │ ├── test_checkpoint_activations.py │ └── test_checkpoint_activations_norm.py ├── data_parallel │ ├── __init__.py │ ├── test_fsdp.py │ ├── test_fsdp_apply.py │ ├── test_fsdp_freezing_weights.py │ ├── test_fsdp_fwd_fwd_bwd_bwd.py │ ├── test_fsdp_grad_acc.py │ ├── test_fsdp_hf_transformer_eval.py │ ├── test_fsdp_input.py │ ├── test_fsdp_memory.py │ ├── test_fsdp_metadata.py │ ├── test_fsdp_multiple_forward.py │ ├── test_fsdp_multiple_forward_checkpoint.py │ ├── test_fsdp_multiple_wrapping.py │ ├── test_fsdp_optimizer_utils.py │ ├── test_fsdp_overlap.py │ ├── test_fsdp_pre_backward_hook.py │ ├── test_fsdp_regnet.py │ ├── test_fsdp_shared_weights.py │ ├── test_fsdp_shared_weights_mevo.py │ ├── test_fsdp_state_dict.py │ ├── test_fsdp_summon_full_params.py │ ├── test_fsdp_uneven.py │ ├── test_fsdp_with_checkpoint_wrapper.py │ ├── test_sharded_ddp_features.py │ └── test_sharded_ddp_pytorch_parity.py ├── misc │ ├── __init__.py │ ├── test_flatten_params_wrapper.py │ ├── test_grad_bucket.py │ └── test_param_bucket.py ├── model_parallel │ ├── __init__.py │ ├── test_cross_entropy.py │ ├── test_initialize.py │ ├── test_layers.py │ └── test_random.py ├── moe │ ├── __init__.py │ ├── test_moe_layer.py │ └── test_top2gating.py ├── pipe │ ├── __init__.py │ ├── conftest.py │ ├── skip │ │ ├── __init__.py │ │ ├── test_api.py │ │ ├── test_gpipe.py │ │ ├── test_inspect_skip_layout.py │ │ ├── test_leak.py │ │ ├── test_portal.py │ │ ├── test_stash_pop.py │ │ ├── test_tracker.py │ │ └── test_verify_skippables.py │ ├── test_balance.py │ ├── test_bugs.py │ ├── test_checkpoint.py │ ├── test_checkpoint_ddp.py │ ├── test_copy.py │ ├── test_deferred_batch_norm.py │ ├── test_dependency.py │ ├── test_inplace.py │ ├── test_microbatch.py │ ├── test_parity.py │ ├── test_phony.py │ ├── test_pipe.py │ ├── test_pipeline.py │ ├── test_stream.py │ ├── test_transparency.py │ └── test_worker.py ├── pipe_process │ ├── __init__.py │ ├── conftest.py │ ├── skip │ │ └── __init__.py │ ├── test_bugs.py │ ├── test_inplace.py │ ├── test_pipe.py │ ├── test_rpc.py │ └── test_transparency.py └── wrap │ ├── __init__.py │ └── test_wrap.py ├── optim ├── __init__.py ├── test_adam.py ├── test_ddp_adascale.py ├── test_layerwise_gradient_scaler.py ├── test_oss.py ├── test_oss_adascale.py └── test_single_node_adascale.py ├── run_mpi_tests.sh └── utils ├── __init__.py ├── test_containers.py ├── test_parallel.py ├── test_reduce_scatter_bucketer.py ├── test_state_dict.py └── test_version.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | docs/* 4 | tests/* 5 | setup.py 6 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.py] 4 | charset = utf-8 5 | 
trim_trailing_whitespace = true 6 | end_of_line = lf 7 | insert_final_newline = true 8 | indent_style = space 9 | indent_size = 4 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | --- 4 | name: "\U0001F41B Bug Report" 5 | about: Submit a bug report to help us improve fairscale 6 | 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ## Command 14 | 15 | ## To Reproduce 16 | 17 | Steps to reproduce the behavior: 18 | 19 | 20 | 21 | 1. 22 | 2. 23 | 3. 24 | 25 | 26 | 27 | ## Expected behavior 28 | 29 | 30 | 31 | ## Environment 32 | 33 | Please copy and paste the output from the 34 | environment collection script from PyTorch 35 | (or fill out the checklist below manually). 36 | 37 | You can run the script with: 38 | ``` 39 | # For security purposes, please check the contents of collect_env.py before running it. 40 | python -m torch.utils.collect_env 41 | ``` 42 | 43 | - PyTorch Version (e.g., 1.0): 44 | - OS (e.g., Linux): 45 | - How you installed PyTorch (`conda`, `pip`, source): 46 | - Build command you used (if compiling from source): 47 | - Python version: 48 | - CUDA/cuDNN version: 49 | - GPU models and configuration: 50 | - Any other relevant information: 51 | 52 | ## Additional context 53 | 54 | 55 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | --- 4 | name: "\U0001F680Feature Request" 5 | about: Submit a proposal/request for a new fairscale feature 6 | 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | ## Motivation 13 | 14 | 15 | 16 | ## Pitch 17 | 18 | 19 | 20 | ## Alternatives 21 | 22 | 23 | 24 | ## Additional context 25 | 26 | 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | --- 4 | name: "❓Questions/Help/Support" 5 | about: Do you need support? 6 | 7 | --- 8 | 9 | ## ❓ Questions and Help 10 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | Fixes # (issue). 3 | 4 | ## Before submitting 5 | 6 | - [ ] Did you have fun? 7 | - Make sure you had fun coding 🙃 8 | - [ ] Did you read the [contributor guideline](https://github.com/facebookresearch/fairscale/blob/main/CONTRIBUTING.md)? 9 | - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) 10 | - [ ] N/A 11 | - [ ] Did you make sure to update the docs? 12 | - [ ] N/A 13 | - [ ] Did you write any new necessary tests? 14 | - [ ] N/A 15 | - [ ] Did you update the [changelog](https://github.com/facebookresearch/fairscale/blob/main/CHANGELOG.md)? (if needed) 16 | - [ ] N/A 17 | 18 | 19 | ## PR review 20 | Anyone in the community is free to review the PR once the tests have passed. 21 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
22 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-20.04 11 | strategy: 12 | matrix: 13 | # make sure python versions are consistent with those used in .circleci/config.yml 14 | python-version: ['3.8.12', '3.9.7', '3.10.1'] 15 | steps: 16 | - uses: actions/checkout@v2 17 | - uses: actions/setup-python@v2 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - uses: pre-commit/action@v2.0.3 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Editor 2 | *~ 3 | *.swp 4 | 5 | # IDEs 6 | .idea/ 7 | 8 | # Testing 9 | *.pyc 10 | *.pyo 11 | .mypy_cache/ 12 | *.egg-info/ 13 | .testmondata 14 | 15 | # Build and release 16 | build/ 17 | dist/ 18 | .eggs/ 19 | 20 | # Pytest verbose output 21 | test-results/ 22 | 23 | # Coverage reports 24 | .coverage 25 | .coverage.* 26 | ./coverage.xml 27 | 28 | # Environments 29 | .env 30 | .venv 31 | .vscode 32 | env/ 33 | venv/ 34 | ENV/ 35 | env.bak/ 36 | venv.bak/ 37 | .vscode/* 38 | *.DS_Store 39 | 40 | # Data generated by tests 41 | cached_datasets/ 42 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # If you change the versions below, please make sure they are in-sync 7 | # with requirements-dev.txt 8 | 9 | exclude: 'build|stubs' 10 | 11 | default_language_version: 12 | python: python3 13 | 14 | repos: 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v4.0.1 17 | hooks: 18 | - id: trailing-whitespace 19 | - id: check-ast 20 | - id: check-merge-conflict 21 | - id: check-added-large-files 22 | args: ['--maxkb=500'] 23 | - id: end-of-file-fixer 24 | 25 | - repo: https://github.com/ambv/black 26 | rev: 22.3.0 27 | hooks: 28 | - id: black 29 | 30 | - repo: https://github.com/PyCQA/flake8 31 | rev: 4.0.1 32 | hooks: 33 | - id: flake8 34 | args: [--show-source, --statistics] 35 | 36 | - repo: https://github.com/pycqa/isort 37 | rev: 5.10.1 38 | hooks: 39 | - id: isort 40 | exclude: README.md 41 | additional_dependencies: [toml] 42 | 43 | - repo: https://github.com/pre-commit/mirrors-mypy 44 | rev: 'v0.910' 45 | hooks: 46 | - id: mypy 47 | args: [--no-strict-optional, --ignore-missing-imports, --scripts-are-modules, --pretty] 48 | # See requirements-dev.txt for the reason for a fixed version of numpy here. 49 | additional_dependencies: [numpy==1.21.5] 50 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # We need python > 3.8 due to a dependency on numpy. 
9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | From fairscale: 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates 4 | 5 | === 6 | 7 | From torchgpipe (fairscale/nn/pipe): 8 | 9 | Copyright 2019 Kakao Brain 10 | 11 | All contributions by Facebook: 12 | Copyright (c) Facebook, Inc. and its affiliates 13 | 14 | === 15 | 16 | Redistribution and use in source and binary forms, with or without 17 | modification, are permitted provided that the following conditions are met: 18 | 19 | 1. Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | 2. Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in the 24 | documentation and/or other materials provided with the distribution. 25 | 26 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 27 | and IDIAP Research Institute nor the names of its contributors may be 28 | used to endorse or promote products derived from this software without 29 | specific prior written permission. 30 | 31 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 32 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 33 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 34 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 35 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 36 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 37 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 38 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 39 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 40 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 41 | POSSIBILITY OF SUCH DAMAGE. 42 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include requirements.txt 3 | recursive-include fairscale *.h *.cuh 4 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | ## Steps to do a release 2 | 3 | ### New Approach 4 | - Go to the [fairscale release workflow](https://github.com/facebookresearch/fairscale/actions/workflows/release.yml) in Github actions. 5 | - In the __Run Workflow__ dropdown, select the branch from which you wish to release. The default value is __main__ and should be used in almost all cases. 
6 | - In adherence to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), enter one of the following three values for _Release Type_: 7 | - _patch_ 8 | - _minor_ 9 | - _major_ 10 | - Click __Run Workflow__. 11 | - Verify [fairscale/version.py](https://github.com/facebookresearch/fairscale/blob/main/fairscale/version.py) has been updated. 12 | - Verify a new [PyPI package](https://pypi.org/project/fairscale/) has been published. 13 | - Verify a new [Github release](https://github.com/facebookresearch/fairscale/releases) has been created. 14 | 15 | --- 16 | ### Old Approach 17 | 18 | - Update the CHANGELOG.md 19 | - Update "what's new" in README.md 20 | - If needed, update the PyTorch versions in README.md in the Testing section. 21 | - Update `fairscale/__init__.py` and `docs/source/conf.py` for the new version number 22 | - git commit the change with title like "[chore] 0.3.1 release" 23 | - make a tag, like `git tag v0.3.1` 24 | - git push --tags origin [your/branch] 25 | - `python3 setup.py sdist` to build a new package (will be in dist/) 26 | - `python3 -m twine upload --repository pypi dist/*` to upload to pypi 27 | - visit [this page](https://github.com/facebookresearch/fairscale/tags) and create the newly 28 | tagged release. 29 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /benchmarks/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /benchmarks/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | import logging 7 | from pathlib import Path 8 | import shutil 9 | import tempfile 10 | 11 | from torchvision.datasets import MNIST 12 | 13 | TEMPDIR = tempfile.gettempdir() 14 | 15 | 16 | def setup_cached_mnist(): 17 | done, tentatives = False, 0 18 | while not done and tentatives < 5: 19 | # Monkey patch the resource URLs to work around a possible blacklist 20 | MNIST.mirrors = ["https://github.com/blefaudeux/mnist_dataset/raw/main/"] + MNIST.mirrors 21 | 22 | # This will automatically skip the download if the dataset is already there, and check the checksum 23 | try: 24 | _ = MNIST(transform=None, download=True, root=TEMPDIR) 25 | done = True 26 | except RuntimeError as e: 27 | logging.warning(e) 28 | mnist_root = Path(TEMPDIR + "/MNIST") 29 | # Corrupted data, erase and restart 30 | shutil.rmtree(str(mnist_root)) 31 | 32 | tentatives += 1 33 | 34 | if done is False: 35 | logging.error("Could not download MNIST dataset") 36 | exit(-1) 37 | else: 38 | logging.info("Dataset downloaded") 39 | -------------------------------------------------------------------------------- /benchmarks/experimental/benchmark_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | # TODO(sidgoyal): Refactor benchmarks to remove this file eventually. 10 | 11 | 12 | def collate_sentences_lm(samples): 13 | 14 | if len(samples) == 0: 15 | return {} 16 | 17 | id = torch.LongTensor([s["id"] for s in samples]) 18 | src_tokens = torch.stack([s["source"] for s in samples], 0) 19 | tgt_tokens = torch.stack([s["target"] for s in samples], 0) 20 | ntokens = len(samples) * len(samples[0]["target"]) 21 | src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) 22 | 23 | batch = { 24 | "id": id, 25 | "nsentences": len(samples), 26 | "ntokens": ntokens, 27 | "input": src_tokens, 28 | "target": tgt_tokens, 29 | } 30 | return batch 31 | 32 | 33 | class BenchmarkLMDataset(Dataset): 34 | """ 35 | Dataset to benchmark a translation like seq2seq task. 36 | Args: 37 | vocab_size (int, optional): size of the vocabulary (default 10000). 38 | max_source_positions (int, optional): max number of tokens in the 39 | source sentence (default: 1024). 40 | total_samples (int, optional): the total number of rows in the 41 | dataset (default: 10000). 42 | """ 43 | 44 | def __init__( 45 | self, 46 | vocab_size=10000, 47 | max_source_positions=1024, 48 | total_samples=10000, 49 | ): 50 | self.vocab_size = vocab_size 51 | self.max_source_positions = max_source_positions 52 | self.total_samples = total_samples 53 | self.sizes = [self.max_source_positions] * self.total_samples 54 | 55 | def __getitem__(self, index): 56 | length = self.sizes[index] 57 | source = torch.randint(1, self.vocab_size, (length,)) 58 | target = source.clone() 59 | return { 60 | "id": index, 61 | "source": source, 62 | "target": target, 63 | } 64 | 65 | def __len__(self): 66 | return self.total_samples 67 | -------------------------------------------------------------------------------- /benchmarks/experimental/sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import tempfile 7 | import time 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | 14 | import fairscale.experimental.nn 15 | 16 | 17 | def benchmark_bn(rank, world_size, init_file, bn_cls): 18 | dist.init_process_group(dist.Backend.NCCL, init_method="file://" + init_file, rank=rank, world_size=world_size) 19 | x = torch.randn(50, 2048, 7, 7).to(rank) 20 | bn = bn_cls(2048).to(rank) 21 | bn = DDP(bn, device_ids=[rank]) 22 | # Warmup 23 | for i in range(50): 24 | with torch.no_grad(): 25 | x = bn(x) 26 | torch.cuda.synchronize(rank) 27 | t0 = time.time() 28 | for i in range(100): 29 | with torch.no_grad(): 30 | x = bn(x) 31 | torch.cuda.synchronize(rank) 32 | t1 = time.time() 33 | print("Elapsed time is ", t1 - t0) 34 | 35 | 36 | if __name__ == "__main__": 37 | world_size = torch.cuda.device_count() 38 | for cls in [torch.nn.BatchNorm2d, torch.nn.SyncBatchNorm, fairscale.experimental.nn.SyncBatchNorm]: 39 | print(cls) 40 | mp.spawn(benchmark_bn, args=(world_size, tempfile.mkstemp()[1], cls), nprocs=world_size) 41 | -------------------------------------------------------------------------------- /benchmarks/golden_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /benchmarks/golden_configs/oss_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | def get_golden_real_stats(): 8 | 9 | return { 10 | "reference_speed": 578, 11 | "reference_memory": 945, 12 | "reference_loss": 0.026, 13 | } 14 | 15 | 16 | def get_golden_synthetic_stats(): 17 | # TODO(anj-s): Add support for synthetic regression benchmarks 18 | raise NotImplementedError("Synthetic data benchmarks are not supported.") 19 | -------------------------------------------------------------------------------- /benchmarks/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | codecov: 7 | require_ci_to_pass: yes 8 | coverage: 9 | status: 10 | project: 11 | default: 12 | target: 94% 13 | threshold: 0.1% 14 | parsers: 15 | gcov: 16 | branch_detection: 17 | conditional: yes 18 | loop: yes 19 | method: no 20 | macro: no 21 | comment: 22 | layout: "reach,diff,flags,tree" 23 | behavior: default 24 | require_changes: no 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Minimal makefile for Sphinx documentation 7 | # 8 | 9 | # You can set these variables from the command line, and also 10 | # from the environment for the first two. 11 | SPHINXOPTS ?= 12 | SPHINXBUILD ?= sphinx-build 13 | SOURCEDIR = source 14 | BUILDDIR = build 15 | 16 | # Put it first so that "make" without argument is like "make help". 17 | help: 18 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | 20 | setup: 21 | pip install -r requirements.txt 22 | 23 | .PHONY: help Makefile setup 24 | 25 | # Catch-all target: route all unknown targets to Sphinx using the new 26 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 27 | %: Makefile 28 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 29 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | recommonmark==0.5.0 2 | sphinx==4.2.0 3 | sphinx_rtd_theme==0.4.3 4 | sphinxcontrib-programoutput==0.16 5 | git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 6 | # Need to sync with fairscale's requirements.txt below 7 | torch>=1.6.0 8 | numpy >= 1.22.0 9 | -------------------------------------------------------------------------------- /docs/source/_static/css/customize.css: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
*/ 2 | /* 3 | * some extra css to make markdown look similar between github/sphinx 4 | */ 5 | 6 | 7 | .tutorials-header .header-logo { 8 | background-image: url("../images/fairscale-logo-dark.svg"); 9 | background-repeat: no-repeat; 10 | background-position: center; 11 | } 12 | 13 | /* .header-logo { 14 | background-image: url("../images/fairscale-logo.svg"); 15 | } */ 16 | 17 | /* .footer-logo { 18 | background-image: url("../images/fairscale-logo-icon.svg"); 19 | } */ 20 | -------------------------------------------------------------------------------- /docs/source/_static/img/ddp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/ddp.png -------------------------------------------------------------------------------- /docs/source/_static/img/fairscale-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/fairscale-logo.png -------------------------------------------------------------------------------- /docs/source/_static/img/flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/flowchart.png -------------------------------------------------------------------------------- /docs/source/_static/img/fsdp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/fsdp.png -------------------------------------------------------------------------------- /docs/source/_static/img/global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/global.png -------------------------------------------------------------------------------- /docs/source/_static/img/offload.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/offload.png -------------------------------------------------------------------------------- /docs/source/_static/img/oss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/oss.png -------------------------------------------------------------------------------- /docs/source/_static/img/pipe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/pipe.png -------------------------------------------------------------------------------- /docs/source/_static/img/sdp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/_static/img/sdp.png 
-------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- 2 | set external_urls = { 3 | 'github': 'https://github.com/facebookresearch/fairscale', 4 | 'github_issues': 'https://github.com/facebookresearch/fairscale/issues', 5 | 'contributing': 'https://github.com/facebookresearch/fairscale/blob/main/CONTRIBUTING.md', 6 | 'docs': 'https://fairscale.readthedocs.io/', 7 | 'home': 'https://fairscale.readthedocs.io/', 8 | 'get_started': 'https://github.com/facebookresearch/fairscale/blob/main/README.md', 9 | 'brand_guidelines': 'https://pytorch.org/assets/brand-guidelines/PyTorch-Brand-Guidelines.pdf' 10 | } 11 | -%} 12 | {%- 13 | set og = { 14 | 'description': 'API docs for FairScale. FairScale is a PyTorch extension library for high performance and large scale training.' 15 | } 16 | -%} 17 | -------------------------------------------------------------------------------- /docs/source/api/experimental/nn/offload_model.rst: -------------------------------------------------------------------------------- 1 | Offload Model 2 | ============== 3 | 4 | .. autoclass:: fairscale.experimental.nn.OffloadModel 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/api/experimental/nn/slowmo_ddp.rst: -------------------------------------------------------------------------------- 1 | SlowMo Distributed Data Parallel 2 | ================================ 3 | 4 | .. autoclass:: fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel 5 | :members: 6 | :undoc-members: 7 | :exclude-members: eval, forward, load_state_dict, state_dict, train, training 8 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | optim/adascale 8 | optim/oss 9 | nn/moe 10 | nn/pipe 11 | nn/sharded_ddp 12 | nn/fsdp 13 | nn/checkpoint/checkpoint_activations 14 | experimental/nn/offload_model 15 | experimental/nn/slowmo_ddp 16 | -------------------------------------------------------------------------------- /docs/source/api/nn/checkpoint/checkpoint_activations.rst: -------------------------------------------------------------------------------- 1 | Activation Checkpoint 2 | ====================== 3 | 4 | .. autoclass:: fairscale.nn.checkpoint.checkpoint_wrapper 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/api/nn/fsdp.rst: -------------------------------------------------------------------------------- 1 | Fully Sharded Data Parallel 2 | ======================================================= 3 | 4 | .. autoclass:: fairscale.nn.FullyShardedDataParallel 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/api/nn/moe.rst: -------------------------------------------------------------------------------- 1 | Mixture Of Experts 2 | ================== 3 | 4 | .. autoclass:: fairscale.nn.MOELayer 5 | -------------------------------------------------------------------------------- /docs/source/api/nn/pipe.rst: -------------------------------------------------------------------------------- 1 | Pipe 2 | ==== 3 | 4 | .. 
autoclass:: fairscale.nn.Pipe 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/api/nn/sharded_ddp.rst: -------------------------------------------------------------------------------- 1 | Sharded Data Parallel 2 | ===================== 3 | 4 | .. autoclass:: fairscale.nn.ShardedDataParallel 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/api/optim/adascale.rst: -------------------------------------------------------------------------------- 1 | AdaScale SGD 2 | ============ 3 | 4 | .. autoclass:: fairscale.optim.AdaScale 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/api/optim/oss.rst: -------------------------------------------------------------------------------- 1 | Optimizer State Sharding 2 | ======================== 3 | 4 | .. autoclass:: fairscale.optim.OSS 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/blogs_and_press.rst: -------------------------------------------------------------------------------- 1 | Blogs and Press 2 | ================= 3 | 4 | 1. `Hugging Face with ZeRO `_ 5 | 2. `PyTorch Lightning `_ 6 | 3. `MMT `_ 7 | 4. `SEER `_ 8 | -------------------------------------------------------------------------------- /docs/source/deep_dive/activation_checkpointing.rst: -------------------------------------------------------------------------------- 1 | Enhanced Activation Checkpointing 2 | ================================= 3 | 4 | Activation checkpointing is a technique used to reduce GPU memory usage during training. This is 5 | done by avoiding the need to store intermediate activation tensors during the forward pass. Instead, 6 | the forward pass is recomputed during the backward pass by keeping track of the original inputs. 7 | There is a slight increase in computation cost (about 33%), but this reduces the need to store 8 | large activation tensors, which allows us to increase the batch size and thereby the net throughput 9 | of the model. 10 | 11 | 12 | Activation checkpointing is implemented by overriding `torch.autograd.Function`. In the `forward` 13 | function, which handles the forward pass of the module, using `no_grad` we can prevent the creation 14 | of the forward graph and the materialization of intermediate activation tensors for a long period of 15 | time (i.e., until the backward pass). Instead, during the backward pass, the forward pass is executed 16 | again, followed by the backward pass. The inputs to the forward pass are saved using a context object 17 | that is then accessed in the backward pass to retrieve the original inputs. We also save the 18 | Random Number Generator (RNG) state for the forward and backward passes, as required for Dropout layers. 19 | 20 | The above functionality is already implemented as part of the `torch.utils.checkpoint.checkpoint` 21 | API, whereby different modules in the forward pass can be wrapped. The wrapper in FairScale offers 22 | functionality beyond that provided by the PyTorch API: specifically, you can use 23 | `fairscale.nn.checkpoint.checkpoint_wrapper` to wrap an `nn.Module`, handle kwargs in the forward 24 | pass, offload intermediate activations to the CPU, and handle non-tensor outputs returned from the 25 | forward function.
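For example, a minimal sketch of wrapping a single submodule with the FairScale wrapper (the layer sizes and the ``offload_to_cpu`` flag below are illustrative, not prescriptive):

.. code-block:: python

    import torch.nn as nn

    from fairscale.nn.checkpoint import checkpoint_wrapper

    # Any nn.Module can be wrapped; every subsequent forward call through it is
    # checkpointed, i.e. its intermediate activations are freed after the forward
    # pass and recomputed during the backward pass.
    ffn = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
    ffn = checkpoint_wrapper(ffn, offload_to_cpu=True)  # also move saved inputs to the CPU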
26 | 27 | Best practices for `fairscale.nn.checkpoint.checkpoint_wrapper` 28 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 29 | 30 | 1. Memory savings depend entirely on the model and the segmentation of checkpoint wrapping. 31 | Each backprop consists of several mini-forward and backprop passes. The gain is entirely dependent 32 | on the memory footprint of the layer’s activations. 33 | 34 | 2. When using BatchNormalization, you may need to freeze the calculation of statistics, since we run 35 | the forward pass twice. 36 | 37 | 3. Ensure that the input tensor’s `requires_grad` field is set to True. In order to trigger the 38 | backward function, the output needs to have this field set. By setting it on the input tensor we 39 | ensure that this is propagated to the output and the `backward` function is triggered. 40 | -------------------------------------------------------------------------------- /docs/source/deep_dive/offload.rst: -------------------------------------------------------------------------------- 1 | OffloadModel 2 | ============= 3 | 4 | Heavily inspired by the `Layer-to-Layer `_ algorithm and 5 | `Zero-Offload `_, OffloadModel uses the CPU to store 6 | the entire model, optimizer state and gradients. OffloadModel then brings one layer (or a group of 7 | layers) at a time onto the GPU for training during the forward and backward passes. The intermediate 8 | activations for the layer boundaries are also stored on the CPU and copied to the GPU as needed for 9 | the backward pass. Once the backward pass is completed, all the parameters are updated with the 10 | gradients present on the CPU. 11 | 12 | .. image:: ../_static/img/offload.png 13 | :height: 500px 14 | :width: 500px 15 | 16 | Offload uses the following techniques to enable large model training: 17 | 18 | 1. The model is assumed to be nn.Sequential and is sharded (almost) equally, based on the number of 19 | parameters, into a list of nn.Modules. Each nn.Module now contains a fraction of the whole model, 20 | which we shall refer to as model shards. 21 | 22 | 2. At each iteration, each of the model shards is copied from CPU -> GPU, the FW pass is computed 23 | using the mini-batch of data, and the model shard is copied back from GPU -> CPU. In the BW pass, the 24 | same process is repeated. 25 | 26 | 3. The optimizer remains on the CPU; gradients and parameters are all moved onto the CPU before 27 | running optimizer.step. This ensures that the CPU is responsible for updating the parameters and 28 | holding onto the optimizer state. 29 | 30 | 4. If activation checkpointing is enabled, we use torch.autograd.Function to disable graph construction 31 | in the FW pass and copy intermediate activations from GPU -> CPU after the FW pass of a given shard is 32 | complete. The reverse copy is carried out in the BW pass. 33 | 34 | 5. Micro-batches are used to enable larger throughput and offset the cost of moving model parameters 35 | and activations from CPU <-> GPU. Micro-batches allow you to specify large mini-batches, which are 36 | broken down into micro-batches and fed to the model shards at each iteration. In short, it is a way 37 | to allow more computation at a given time on a model shard to offset the cost of copying from CPU <-> GPU. 38 | 39 | Best practices for using `fairscale.experimental.nn.OffloadModel` 40 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 41 | 42 | 1. Using OffloadModel to train large models can result in a loss of throughput, which can be overcome by using activation checkpointing and micro-batches. 43 | 44 | 2. OffloadModel currently only works for `nn.Sequential` models. 45 |
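A minimal construction sketch is shown below (the layer sizes, slice count and micro-batch count are illustrative only):

.. code-block:: python

    import torch
    import torch.nn as nn

    from fairscale.experimental.nn import OffloadModel

    # The model must be nn.Sequential so that it can be split into model shards.
    model = nn.Sequential(*[nn.Linear(512, 512) for _ in range(8)])

    offload_model = OffloadModel(
        model=model,
        device=torch.device("cuda"),         # device that runs the compute
        offload_device=torch.device("cpu"),  # device that holds the full model
        num_slices=4,                        # number of model shards (technique 1)
        checkpoint_activation=True,          # offload activations as well (technique 4)
        num_microbatches=4,                  # split each mini-batch (technique 5)
    )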
-------------------------------------------------------------------------------- /docs/source/deep_dive/pipeline_parallelism.rst: -------------------------------------------------------------------------------- 1 | Pipeline Parallelism 2 | ===================== 3 | 4 | Training large models can lead to out-of-memory errors when the model is too large for a single GPU. 5 | To train such a large model, layers can be pipelined across different GPU devices as described in GPipe. 6 | `fairscale.nn.Pipe` is an implementation of GPipe, which has been adapted from torchgpipe. This API 7 | has also been upstreamed to PyTorch in the 1.8 release with the experimental tag. 8 | 9 | .. image:: ../_static/img/pipe.png 10 | 11 | GPipe first shards the model across different devices, where each device hosts a shard of the model. 12 | A shard can be a single layer or a series of layers. GPipe then splits a mini-batch of data into 13 | micro-batches and feeds them to the device hosting the first shard. The layers on each device process 14 | the micro-batches and send the output to the following shard/device. In the meantime, each device is ready to 15 | process the next micro-batch from the previous shard/device. By pipelining the input in this way, GPipe is 16 | able to reduce the idle time of devices. 17 | 18 | Best practices for using `fairscale.nn.Pipe` 19 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 20 | 21 | 1. The choice of micro-batch size can affect GPU utilization. A smaller micro-batch can reduce the latency of shards waiting for the previous shard's outputs, but a larger micro-batch utilizes the GPUs better. 22 | 23 | 2. Sharding the model can also impact GPU utilization: layers with heavier computation can slow down the shards downstream. 24 | -------------------------------------------------------------------------------- /docs/source/getting_involved.rst: -------------------------------------------------------------------------------- 1 | Getting Involved 2 | ================= 3 | 4 | We welcome contributions from everyone! Please see the `CONTRIBUTING <https://github.com/facebookresearch/fairscale/blob/main/CONTRIBUTING.md>`_ 5 | guide on GitHub for more details on how you can contribute to FairScale. 6 | -------------------------------------------------------------------------------- /docs/source/getting_started.rst: -------------------------------------------------------------------------------- 1 | User Workflow 2 | ============== 3 | 4 | A user workflow diagram with an explanation of the various decision points: 5 | 6 | .. image:: _static/img/flowchart.png 7 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. FairScale documentation master file, created by 2 | sphinx-quickstart on Tue Sep 8 16:19:17 2020. 3 | You can adapt this file completely to your liking, 4 | but it should at least contain the root `toctree` 5 | directive. 6 | 7 | FairScale Documentation 8 | ======================= 9 | 10 | FairScale is a PyTorch extension library for high performance and large scale training. 11 | FairScale makes available the latest distributed training techniques in the form of composable 12 | modules and easy to use APIs. 13 |
14 | .. toctree:: 15 | :maxdepth: 1 16 | :caption: Index 17 | 18 | what_is_fairscale 19 | getting_started 20 | blogs_and_press 21 | getting_involved 22 | integrations 23 | 24 | | 25 | | 26 | 27 | .. toctree:: 28 | :maxdepth: 1 29 | :caption: Installation 30 | 31 | installation_instructions 32 | 33 | | 34 | | 35 | 36 | .. toctree:: 37 | :maxdepth: 1 38 | :caption: Deep Dive 39 | 40 | deep_dive/oss_sdp_fsdp 41 | deep_dive/offload 42 | deep_dive/adascale 43 | deep_dive/pipeline_parallelism 44 | deep_dive/activation_checkpointing 45 | deep_dive/slowmo_ddp 46 | 47 | | 48 | | 49 | 50 | .. toctree:: 51 | :maxdepth: 1 52 | :caption: Tutorials 53 | 54 | tutorials/oss 55 | tutorials/activation_checkpointing 56 | tutorials/offload_model 57 | tutorials/adascale 58 | tutorials/pipe 59 | tutorials/layer_memory_tracking 60 | tutorials/slowmo_ddp 61 | 62 | | 63 | | 64 | 65 | .. toctree:: 66 | :maxdepth: 1 67 | :caption: API Documentation 68 | 69 | api/index 70 | -------------------------------------------------------------------------------- /docs/source/installation_instructions.rst: -------------------------------------------------------------------------------- 1 | Installing FairScale 2 | ==================== 3 | 4 | Installing FairScale is extremely simple with the pre-built binaries (pip) that we provide. You can also build 5 | from source using the instructions below. 6 | 7 | 8 | Requirements 9 | ~~~~~~~~~~~~ 10 | 11 | * PyTorch >= 1.8.1 12 | 13 | 14 | Installing the pip package (stable) 15 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 16 | 17 | .. code-block:: bash 18 | 19 | pip install fairscale 20 | 21 | 22 | Installing with conda 23 | ~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | Fairscale is packaged by conda-forge (see `here `_) 26 | for both linux & osx, with GPU-enabled builds available on linux. 27 | 28 | .. code-block:: bash 29 | 30 | conda install -c conda-forge fairscale 31 | 32 | 33 | Installing from source 34 | ~~~~~~~~~~~~~~~~~~~~~~ 35 | 36 | .. code-block:: bash 37 | 38 | git clone https://github.com/facebookresearch/fairscale.git 39 | cd fairscale 40 | pip install -r requirements.txt 41 | # -e signifies dev mode, since "e" stands for editable 42 | pip install -e . 43 | 44 | To build with GPU support enabled, be sure to set ``BUILD_CUDA_EXTENSIONS=1`` 45 | as well as an appropriate ``TORCH_CUDA_ARCH_LIST`` (see the sketch below). 46 | 47 | Note: If either of the above fails, add ``--no-build-isolation`` to the ``pip install`` 48 | command (this could be a problem with recent versions of pip). 49 |
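For example, a GPU-enabled source build might look like the following (the architecture list is illustrative; pick the values matching your GPUs):

.. code-block:: bash

    # Build the bundled CUDA extensions along with fairscale.
    export BUILD_CUDA_EXTENSIONS=1
    # Compile for V100 (7.0) and A100 (8.0) in this example.
    export TORCH_CUDA_ARCH_LIST="7.0;8.0"
    pip install -e .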
-------------------------------------------------------------------------------- /docs/source/integrations.rst: -------------------------------------------------------------------------------- 1 | Integrations 2 | ============ 3 | 4 | FairScale has been integrated with the following frameworks: 5 | 6 | 1. Fairseq 7 | 2. VISSL 8 | 3. PyTorch Lightning 9 | 4. Hugging Face 10 | -------------------------------------------------------------------------------- /docs/source/tutorials/_static/img/all_gathered_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/tutorials/_static/img/all_gathered_memory.png -------------------------------------------------------------------------------- /docs/source/tutorials/_static/img/layer_memory_activations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/tutorials/_static/img/layer_memory_activations.png -------------------------------------------------------------------------------- /docs/source/tutorials/_static/img/layer_memory_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/tutorials/_static/img/layer_memory_parameters.png -------------------------------------------------------------------------------- /docs/source/tutorials/_static/img/layer_memory_profile_optimized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/tutorials/_static/img/layer_memory_profile_optimized.png -------------------------------------------------------------------------------- /docs/source/tutorials/_static/img/layer_memory_profiles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/docs/source/tutorials/_static/img/layer_memory_profiles.png -------------------------------------------------------------------------------- /docs/source/tutorials/activation_checkpointing.rst: -------------------------------------------------------------------------------- 1 | Efficient memory usage using Activation Checkpointing 2 | ===================================================== 3 | 4 | Adapted from `torch.utils.checkpoint`, this is a friendlier wrapper for performing activation checkpointing. 5 | 6 | Compared to the PyTorch version, this version wraps an `nn.Module` and allows all subsequent calls to be 7 | checkpointed. 8 | 9 | .. code-block:: python 10 | 11 | import torch 12 | import torch.nn as nn 13 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 14 | 15 | class CheckpointModel(nn.Module): 16 | 17 | def __init__(self, **kwargs): 18 | super().__init__() 19 | torch.manual_seed(0) # make sure weights are deterministic.
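# The FFN block built below is the checkpointed region: its activations are freed after the forward pass and recomputed during the backward pass.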
20 | self.ffn_module = nn.Sequential( 21 | nn.Linear(32, 128), 22 | nn.Dropout(p=0.5), 23 | nn.Linear(128, 32), 24 | ) 25 | 26 | self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs) 27 | self.last_linear = nn.Linear(32, 1) 28 | 29 | def forward(self, input): 30 | output = self.ffn_module(input) 31 | return self.last_linear(output) 32 | -------------------------------------------------------------------------------- /docs/source/tutorials/pipe.rst: -------------------------------------------------------------------------------- 1 | Model sharding using Pipeline Parallel 2 | ====================================== 3 | 4 | Let us start with a toy model that contains two linear layers. 5 | 6 | .. code-block:: default 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | class ToyModel(nn.Module): 13 | def __init__(self): 14 | super(ToyModel, self).__init__() 15 | self.net1 = torch.nn.Linear(10, 10) 16 | self.relu = torch.nn.ReLU() 17 | self.net2 = torch.nn.Linear(10, 5) 18 | 19 | def forward(self, x): 20 | x = self.relu(self.net1(x)) 21 | return self.net2(x) 22 | 23 | model = ToyModel() 24 | 25 | To run this model on 2 GPUs we need to convert the model 26 | to ``torch.nn.Sequential`` and then wrap it with ``fairscale.nn.Pipe``. 27 | 28 | .. code-block:: default 29 | 30 | 31 | import fairscale 32 | import torch 33 | import torch.nn as nn 34 | 35 | model = nn.Sequential( 36 | torch.nn.Linear(10, 10), 37 | torch.nn.ReLU(), 38 | torch.nn.Linear(10, 5) 39 | ) 40 | 41 | model = fairscale.nn.Pipe(model, balance=[2, 1]) 42 | 43 | This will run the first two layers on ``cuda:0`` and the last 44 | layer on ``cuda:1``. To learn more, visit the `Pipe <../api/nn/pipe.html>`_ documentation. 45 | 46 | You can then define any optimizer and loss function 47 | 48 | .. code-block:: default 49 | 50 | 51 | import torch.optim as optim 52 | import torch.nn.functional as F 53 | 54 | optimizer = optim.SGD(model.parameters(), lr=0.001) 55 | loss_fn = F.nll_loss 56 | 57 | optimizer.zero_grad() 58 | target = torch.randint(0,2,size=(20,1)).squeeze() 59 | data = torch.randn(20, 10) 60 | 61 | 62 | 63 | Finally, to run the model and compute the loss function, make sure that outputs and target are on the same device. 64 | 65 | .. code-block:: default 66 | 67 | device = model.devices[0] 68 | ## outputs and target need to be on the same device 69 | # forward step 70 | outputs = model(data.to(device)) 71 | # compute loss 72 | loss = loss_fn(outputs.to(device), target.to(device)) 73 | 74 | # backward + optimize 75 | loss.backward() 76 | optimizer.step() 77 | 78 | 79 | -------------------------------------------------------------------------------- /docs/source/what_is_fairscale.rst: -------------------------------------------------------------------------------- 1 | What is FairScale? 2 | ==================== 3 | 4 | FairScale is a PyTorch extension library for high performance and large scale training. 5 | This library extends basic PyTorch capabilities while adding new SOTA scaling techniques. 6 | FairScale makes available the latest distributed training techniques in the form of composable 7 | modules and easy to use APIs. These APIs are a fundamental part of a researcher's toolbox as 8 | they attempt to scale models with limited resources. 9 | 10 | .. image:: _static/img/global.png 11 | :width: 400px 12 | :height: 400px 13 | :align: center 14 | 15 | FairScale was designed with the following values in mind: 16 | 17 | 1. 
**Usability** - Users should be able to understand and use FairScale APIs with minimal cognitive overhead. 18 | 19 | 2. **Modularity** - Users should be able to combine multiple FairScale APIs as part of their training loop seamlessly. 20 | 21 | 3. **Performance** - FairScale APIs aim to provide the best performance in terms of scaling and efficiency. 22 | 23 | .. image:: _static/img/ddp.png 24 | 25 | ML training at scale traditionally means `data parallelism `_, 26 | which uses multiple devices at the same 27 | time to train on a larger batch size per step, thereby reaching the goal accuracy in less time 28 | than training on a single device. With recent advances in ML research, the size of ML models 29 | has only increased over the years, and data parallelism no longer serves all “scaling” purposes. 30 | 31 | There are multiple axes across which you can scale training, and FairScale provides the following broad 32 | categories of solutions: 33 | 34 | 1. **Parallelism** → These techniques allow scaling of models through layer parallelism and tensor parallelism. 35 | 36 | 2. **Sharding Methods** → Memory and computation usually trade off against each other; in this category we attempt to achieve both low memory utilization and efficient computation by sharding model layers or parameters, optimizer state, and gradients. 37 | 38 | 3. **Optimization** → This bucket covers optimizing memory usage irrespective of model scale, training without hyperparameter tuning, and all other techniques that attempt to optimize training performance in some way. 39 | -------------------------------------------------------------------------------- /fairscale/README.md: -------------------------------------------------------------------------------- 1 | NOTE: 2 | 3 | The experimental and fair_dev submodules are not part of the fairscale public 4 | API. There can be breaking changes in them at any time. 5 | -------------------------------------------------------------------------------- /fairscale/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ################################################################################ 7 | # Import most common subpackages 8 | # 9 | # NOTE: we don't maintain any public APIs in either the experimental or fair_dev 10 | # sub-modules. Code in them is experimental or for developers only. It 11 | # can be changed or removed at any time. 12 | ################################################################################ 13 | 14 | from typing import List 15 | 16 | from .
import nn 17 | from .version import __version_tuple__ 18 | 19 | __version__ = ".".join([str(x) for x in __version_tuple__]) 20 | __all__: List[str] = [] 21 | -------------------------------------------------------------------------------- /fairscale/clib/fused_adam_cuda/compat.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCH_CHECK 2 | #define TORCH_CHECK AT_CHECK 3 | #endif 4 | 5 | #define DATA_PTR data_ptr 6 | -------------------------------------------------------------------------------- /fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | // CUDA forward declaration 4 | void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, float optim_scale, at::Tensor& found_inf, int step, int mode, int bias_correction, float decay); 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation."); 8 | } 9 | -------------------------------------------------------------------------------- /fairscale/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ################################################################################ 7 | # Import most common subpackages 8 | ################################################################################ 9 | 10 | from typing import List 11 | 12 | # Don't import sub-modules here; otherwise experimental code gets imported directly 13 | # when a user does an `import fairscale`. This can cause experimental code's import 14 | # dependencies (like pygit2) to leak into fairscale's main dependencies. 15 | 16 | __all__: List[str] = [] 17 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | from .mevo import BaselineSoftmaxNllLoss 9 | from .mevo import MemoryEfficientVocabOutput as MEVO 10 | from .offload import OffloadModel 11 | from .sync_batchnorm import SyncBatchNorm 12 | 13 | __all__: List[str] = [] 14 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/ampnet_pipe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | __all__: List[str] = [] 9 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/ampnet_pipe/pipe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """The AMPnetPipe interface.""" 7 | 8 | from typing import Any 9 | 10 | from torch import nn 11 | from torch.optim.optimizer import Optimizer 12 | from torch.utils.data import DataLoader 13 | 14 | from fairscale.nn.pipe import AsyncPipe 15 | 16 | from .ampnet import AsyncAMPnetEventLoop 17 | 18 | __all__ = ["AMPnetPipe"] 19 | 20 | 21 | class AMPnetPipe(AsyncPipe): 22 | """ 23 | AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation, 24 | which avoids the pipeline bubble issue by using stale weights and gradients. 25 | The implementation closely follows the paper: https://arxiv.org/abs/1705.09786 26 | """ 27 | 28 | def __init__(self, **kwargs: Any) -> None: 29 | super().__init__(**kwargs) 30 | 31 | def interleave( 32 | self, 33 | lm_dataloader: DataLoader, 34 | criterion: nn.Module, 35 | optimizer: Optimizer, 36 | transform_logger_object: Any, 37 | min_update_interval: int = 1, 38 | weight_prediction: bool = False, 39 | ) -> None: 40 | 41 | partitions = self.partitions 42 | n = len(partitions) 43 | 44 | # AMPnet implementation doesn't handle skip_trackers! 45 | 46 | assert self.group 47 | rank = self.group.rank() 48 | 49 | transport = self.pipeline.transport 50 | checkpoint_stop = self.pipeline.checkpoint_stop 51 | ampnet_event_loop = AsyncAMPnetEventLoop( 52 | partitions, 53 | self.group, 54 | transport, 55 | min_update_interval, 56 | weight_prediction, 57 | checkpoint_stop, 58 | self.input_device, 59 | self.chunks, 60 | ) 61 | 62 | if rank == 0: 63 | ampnet_event_loop.event_loop_head_across_minibatches( 64 | lm_dataloader, criterion, optimizer, transform_logger_object 65 | ) 66 | elif self.final_stage: 67 | ampnet_event_loop.event_loop_tail_across_minibatches( 68 | lm_dataloader, criterion, optimizer, transform_logger_object 69 | ) 70 | else: 71 | ampnet_event_loop.event_loop_across_minibatches( 72 | lm_dataloader, criterion, optimizer, transform_logger_object 73 | ) 74 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/data_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .gossip import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel # noqa 7 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/data_parallel/gossip/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree.
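# Public surface of the gossip package: the SlowMo distributed data parallel wrapper, the gossip primitives it builds on, and the communication-graph and mixing-weight managers.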
5 | 6 | from .distributed import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel 7 | from .gossiper import PushPull, PushSum 8 | from .graph_manager import ( 9 | DynamicBipartiteExponentialGraph, 10 | DynamicBipartiteLinearGraph, 11 | DynamicDirectedExponentialGraph, 12 | DynamicDirectedLinearGraph, 13 | GraphManager, 14 | NPeerDynamicDirectedExponentialGraph, 15 | RingGraph, 16 | ) 17 | from .mixing_manager import MixingManager, UniformMixing 18 | from .utils import communicate 19 | from .utils.cuda_metering import CudaEventRecorder 20 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/data_parallel/gossip/mixing_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Mixing Manager Class 8 | 9 | :description: Class provides an API for dynamically selecting mixing weights 10 | for gossip 11 | """ 12 | 13 | from abc import ABC, abstractmethod 14 | from typing import Dict, Optional, Union 15 | 16 | import torch 17 | 18 | from .graph_manager import GraphManager 19 | 20 | 21 | class MixingManager(ABC): 22 | def __init__(self, graph: GraphManager, device: Optional[torch.device]) -> None: 23 | self.graph_manager = graph 24 | self.device = device 25 | 26 | def is_regular(self) -> bool: 27 | """ 28 | Whether there is bias accumulated in local entry of stationary 29 | distribution of mixing matrix 30 | """ 31 | return self.graph_manager.is_regular_graph() and self.is_uniform() 32 | 33 | @abstractmethod 34 | def is_uniform(self) -> bool: 35 | """Whether mixing weights are distributed uniformly over peers""" 36 | raise NotImplementedError 37 | 38 | @abstractmethod 39 | def get_mixing_weights(self, residual_adjusted: bool = True) -> Dict[Union[str, int], torch.Tensor]: 40 | """Create mixing weight dictionary using uniform allocation""" 41 | raise NotImplementedError 42 | 43 | 44 | class UniformMixing(MixingManager): 45 | def get_mixing_weights(self, residual_adjusted: bool = True) -> Dict[Union[str, int], torch.Tensor]: 46 | """Create mixing weight dictionary using uniform allocation""" 47 | mixing_weights: Dict[Union[str, int], torch.Tensor] = {} 48 | out_peers, _ = self.graph_manager.get_peers() 49 | 50 | w = torch.tensor([1.0 / (len(out_peers) + 1.0)], device=self.device) 51 | mixing_weights["lo"] = w.clone() 52 | w_op = w if not residual_adjusted else w / mixing_weights["lo"] 53 | mixing_weights["uniform"] = w_op.clone() 54 | for op in out_peers: 55 | mixing_weights[op] = w_op.clone() 56 | return mixing_weights 57 | 58 | def is_uniform(self) -> bool: 59 | return True 60 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/data_parallel/gossip/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
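# Shared helpers for the gossip code: a multi-process logging adapter, process-group creation, and tensor flatten/unflatten utilities used when communicating buckets of tensors.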
5 | 6 | from .helpers import ( 7 | MultiProcessAdapter, 8 | communicate, 9 | create_process_group, 10 | flatten_tensors, 11 | group_by_dtype, 12 | make_logger, 13 | unflatten_tensors, 14 | ) 15 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/distributed_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .graph import PipelineModulesGraph 7 | from .loss import DistributedLoss 8 | from .pipeline import DistributedPipeline 9 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/distributed_pipeline/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from dataclasses import dataclass 7 | from typing import Generic, TypeVar 8 | 9 | ConsumerType = TypeVar("ConsumerType") 10 | 11 | 12 | @dataclass 13 | class DataConsumer(Generic[ConsumerType]): 14 | """A data class representing a consumer of an output of a module.""" 15 | 16 | consumer: ConsumerType 17 | consumer_input_idx: int # indicating which input of the consumer module 18 | output_idx: int # indicating which output of the producer module 19 | -------------------------------------------------------------------------------- /fairscale/experimental/nn/distributed_pipeline/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Callable, Dict, Tuple 7 | 8 | from torch import nn 9 | from torch.distributed import rpc 10 | 11 | 12 | def _rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef: 13 | return loss_func(input_rref.to_here(), target_rref.to_here()) 14 | 15 | 16 | def DistributedLoss(loss: nn.Module, *args: Tuple, **kwargs: Dict) -> Callable: 17 | loss_func = loss(*args, **kwargs) 18 | 19 | def dloss(input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef: 20 | return rpc.remote(input_rref.owner(), _rloss, args=(loss_func, input_rref, target_rref)) 21 | 22 | return dloss 23 | -------------------------------------------------------------------------------- /fairscale/experimental/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | from .dynamic_loss_scaler import DynamicLossScaler 9 | 10 | __all__: List[str] = [] 11 | -------------------------------------------------------------------------------- /fairscale/experimental/tooling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | __all__: List[str] = [] 9 | -------------------------------------------------------------------------------- /fairscale/experimental/wgit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import sys 7 | from typing import List 8 | 9 | # Check for user requirements before we import our code. 10 | try: 11 | import pygit2 12 | except ImportError: 13 | print("Error: please `pip install pygit2` to use wgit") 14 | sys.exit(1) 15 | 16 | try: 17 | import pgzip 18 | except ImportError: 19 | print("Error: please `pip install pgzip` to use wgit") 20 | sys.exit(1) 21 | 22 | 23 | from .repo import Repo 24 | from .signal_sparsity import Algo, SignalSparsity, random_sparse_mask 25 | from .signal_sparsity_profiling import EnergyConcentrationProfile 26 | from .version import __version_tuple__ 27 | 28 | __version__ = ".".join([str(x) for x in __version_tuple__]) 29 | __all__: List[str] = [] 30 | -------------------------------------------------------------------------------- /fairscale/experimental/wgit/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .cli import main 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /fairscale/experimental/wgit/signal_sparsity_profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | 12 | class EnergyConcentrationProfile: 13 | """Compute "energy" concentration level for a tensor 14 | 15 | Args: 16 | dim (int): 17 | The dimension to measure. 18 | top_k_percents (List[float]): 19 | List of percentage values. For each value, the `measure` 20 | function will compute and return the percentage of "energy" 21 | concentrated on that top-K percent of values in the dimension 22 | to measure. Note: this is the opposite of the sparsity percentage. 23 | """ 24 | 25 | def __init__(self, dim: int, top_k_percents: List[float]) -> None: 26 | assert isinstance(dim, int) 27 | self.dim = dim 28 | self.percents = [] 29 | last_p = 0.0 30 | for p in top_k_percents: 31 | assert isinstance(p, (int, float)) 32 | assert p > 0, p 33 | assert p <= 100, p 34 | assert p > last_p, f"p {p} should be larger than last_p {last_p}" 35 | self.percents.append(float(p)) 36 | last_p = p 37 | 38 | def measure(self, in_tensor: Tensor) -> List[Tensor]: 39 | """Compute and return the results 40 | 41 | Note: we want this function to be nonblocking and async. 42 | 43 | Returns: 44 | (List[Tensor]) 45 | List of tensors. Each tensor is a singleton float 46 | that contains the energy measure for that top_k_percent.
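Example (an illustrative sketch, not from the library docs; the tensor shape and percent values are arbitrary)::

    ecp = EnergyConcentrationProfile(dim=1, top_k_percents=[5.0, 25.0])
    fractions = ecp.measure(torch.randn(8, 100))  # two singleton tensors, each in (0, 1]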
47 | """ 48 | assert in_tensor.is_floating_point(), in_tensor.dtype 49 | assert self.dim < len(in_tensor.shape), f"tensor shape {in_tensor.shape} not compatible with dim {self.dim}" 50 | dim_size = in_tensor.shape[self.dim] 51 | abs_tensor = in_tensor.abs() 52 | full_energy = abs_tensor.sum() 53 | return_tensors = [] 54 | for p in self.percents: 55 | k = max(1, round(p / 100 * dim_size)) 56 | abs_top_k_values, _ = abs_tensor.topk(k, dim=self.dim) 57 | return_tensors.append(abs_top_k_values.sum() / full_energy) 58 | return return_tensors 59 | 60 | def measure_fft(self, in_tensor: Tensor) -> List[Tensor]: 61 | """Like measure, but do it in FFT frequency domain.""" 62 | return self.measure(torch.fft.fft(in_tensor, dim=self.dim).real) 63 | -------------------------------------------------------------------------------- /fairscale/experimental/wgit/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | 8 | 9 | class ExitCode(Enum): 10 | """Collections of the Exit codes as an Enum class""" 11 | 12 | CLEAN = 0 13 | FILE_EXISTS_ERROR = 1 14 | FILE_DOES_NOT_EXIST_ERROR = 2 15 | -------------------------------------------------------------------------------- /fairscale/experimental/wgit/version.py: -------------------------------------------------------------------------------- 1 | __version_tuple__ = (0, 0, 1) 2 | -------------------------------------------------------------------------------- /fairscale/fair_dev/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | __all__: List[str] = [] 9 | -------------------------------------------------------------------------------- /fairscale/fair_dev/common_paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | "Common cache root for torchvision.datasets and others." 7 | DATASET_CACHE_ROOT = "cached_datasets" 8 | -------------------------------------------------------------------------------- /fairscale/fair_dev/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | __all__: List[str] = [] 9 | -------------------------------------------------------------------------------- /fairscale/fair_dev/testing/golden_testing_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | """ Golden data used in unit tests. 
""" 8 | 9 | adascale_test_data = [ 10 | # "input" value is a list of input tensors for micro-batch/rank 0 and micro-batch/rank 1. 11 | { 12 | "input": [[1.0, 0], [0, 1.0]], 13 | "expected_gain": 4.0 / 3, 14 | "expected_grad": [[0.5, 0.5], [0.5, 0.5]], 15 | "expected_bias_grad": [1.0, 1.0], 16 | }, 17 | { 18 | "input": [[1.0, 1.0], [1.0, 1.0]], 19 | "expected_gain": 1.0000001249999846, 20 | "expected_grad": [[1.0, 1.0], [1.0, 1.0]], 21 | "expected_bias_grad": [1.0, 1.0], 22 | }, 23 | { 24 | "input": [[-1.0, 1.0], [1.0, -1.0]], 25 | "expected_gain": 2.0, 26 | "expected_grad": [[0.0, 0.0], [0.0, 0.0]], 27 | "expected_bias_grad": [1.0, 1.0], 28 | }, 29 | { 30 | "input": [[1.0, 4.0], [5.0, 0.5]], 31 | "expected_gain": 1.4688796680497926, 32 | "expected_grad": [[3.0, 2.25], [3.0, 2.25]], 33 | "expected_bias_grad": [1.0, 1.0], 34 | }, 35 | { 36 | "input": [[-0.2, 3.0], [5.0, 0.5]], 37 | "expected_gain": 1.8472893901708, 38 | "expected_grad": [[2.4000000953674316, 1.75], [2.4000000953674316, 1.75]], 39 | "expected_bias_grad": [1.0, 1.0], 40 | }, 41 | # "inputs" to trigger multiple iteration tests, which make sure the 42 | # smoothing factor calculation is also covered. 43 | { 44 | "inputs": [[[-0.2, 3.3], [5.2, 0.7]], [[1.0, 4.0], [3.1, 0.1]]], 45 | "expected_gain": 1.6720968158031417, 46 | "expected_grad": [[2.049999952316284, 2.049999952316284], [2.049999952316284, 2.049999952316284]], 47 | "expected_bias_grad": [1.0, 1.0], 48 | }, 49 | ] 50 | 51 | corr_mean_test_data = [ 52 | { 53 | "inputs": [ 54 | [[1.0, 0.0, 2.0], [2.0, 0.0, 1.0]], 55 | [[0.0, 1.0, 2.0], [2.0, 1.0, 0]], 56 | [[3.0, 1.0, 2.0], [2.0, 1.0, -1.0]], 57 | ], 58 | "expected_grad": [[1.5, 0.0, 1.5], [1.0, 1.0, 1.0], [2.5, 1.0, 0.5]], 59 | # expected pearson correlation of two micro-batches 60 | "expected_corr": [0.5, -1.0, 0.327327], 61 | "expected_cos_similarity": [float("nan"), 0.8165, 0.8433], 62 | } 63 | ] 64 | -------------------------------------------------------------------------------- /fairscale/fair_dev/testing/testing_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ Shared functions related to testing GPU memory sizes. """ 7 | 8 | import gc 9 | from typing import Tuple 10 | 11 | import torch 12 | 13 | 14 | def find_tensor_by_shape(target_shape: Tuple, only_param: bool = True) -> bool: 15 | """Find a tensor from the heap 16 | 17 | Args: 18 | target_shape (tuple): 19 | Tensor shape to locate. 20 | only_param (bool): 21 | Only match Parameter type (e.g. for weights). 22 | 23 | Returns: 24 | (bool): 25 | Return True if found. 26 | """ 27 | for obj in gc.get_objects(): 28 | try: 29 | # Only need to check parameter type objects if asked. 30 | if only_param and "torch.nn.parameter.Parameter" not in str(type(obj)): 31 | continue 32 | if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)): 33 | if obj.shape == target_shape: 34 | return True 35 | except Exception as e: 36 | pass 37 | return False 38 | -------------------------------------------------------------------------------- /fairscale/internal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .version import * 7 | -------------------------------------------------------------------------------- /fairscale/internal/object.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import pickle 7 | from typing import Any 8 | 9 | import torch 10 | 11 | 12 | def pyobject_to_tensor(obj: Any, fixed_buffer_size: int = 0) -> torch.Tensor: 13 | pickled = pickle.dumps(obj) 14 | result: torch.Tensor = torch.ByteTensor(bytearray(pickled)) 15 | if fixed_buffer_size: 16 | delta = fixed_buffer_size - len(result) 17 | if delta < 0: 18 | raise ValueError( 19 | f"message too big to send, increase `fixed_buffer_size`? - {len(result)} > {fixed_buffer_size}" 20 | ) 21 | elif delta > 0: 22 | result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) 23 | 24 | return result 25 | 26 | 27 | def tensor_to_pyobject(tensor: torch.Tensor) -> Any: 28 | nparray = tensor.cpu().numpy() 29 | return pickle.loads(nparray.tobytes()) 30 | -------------------------------------------------------------------------------- /fairscale/internal/params.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import collections.abc as abc 7 | from dataclasses import dataclass 8 | from math import inf 9 | from typing import Any, Callable, Dict, List, Optional 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | 15 | @dataclass 16 | class Workhandle: 17 | handle: Any 18 | callback: Optional[Callable] = None 19 | 20 | 21 | def get_global_rank(group: Any, rank: int) -> int: 22 | if group is dist.group.WORLD: 23 | return rank 24 | 25 | return dist.distributed_c10d._get_global_rank(group, rank) 26 | 27 | 28 | # Credits: classy_vision/generic/distributed_util.py 29 | def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any: 30 | """ 31 | Recursively searches lists, tuples, dicts and copies tensors to device if 32 | possible. Non-tensor values are passed as-is in the result. 33 | 34 | NOTE: These are all copies, so if there are two objects that reference 35 | the same object, then after this call, there will be two different objects 36 | referenced on the device. 37 | """ 38 | 39 | if isinstance(value, torch.Tensor): 40 | return value.to(device, non_blocking=non_blocking) 41 | 42 | if isinstance(value, (list, tuple)): 43 | values = [] 44 | for val in value: 45 | values.append(recursive_copy_to_device(val, non_blocking=non_blocking, device=device)) 46 | 47 | return values if isinstance(value, list) else tuple(values) 48 | 49 | if isinstance(value, abc.Mapping): 50 | device_val: Dict[str, Any] = {} 51 | for key, val in value.items(): 52 | device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device) 53 | 54 | return device_val 55 | 56 | return value 57 | 58 | 59 | def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor: 60 | r"""Calculate gradient norm of an iterable of parameters. 
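Args:
    parameters (List[torch.nn.Parameter]):
        Parameters whose ``.grad`` values contribute to the norm; entries
        with ``grad is None`` are filtered out.
    p (float):
        The norm type. Use ``inf`` for the max-norm.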
61 | Returns: 62 | Total norm of the parameters (viewed as a single vector). 63 | """ 64 | if isinstance(parameters, torch.Tensor): 65 | parameters = [parameters] 66 | parameters = list(filter(lambda par: par.grad is not None, parameters)) 67 | 68 | if len(parameters) == 0: 69 | return torch.tensor(0.0) 70 | p = float(p) 71 | if p == inf: 72 | local_norm = max(par.grad.detach().abs().max() for par in parameters) # type: ignore 73 | else: 74 | # Compute the norm in full precision no matter what 75 | local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]), p).to(dtype=parameters[0].dtype) # type: ignore 76 | return local_norm 77 | -------------------------------------------------------------------------------- /fairscale/internal/state_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Useful functions for manipulating state_dicts.""" 7 | 8 | from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union 9 | 10 | from torch import Tensor, nn 11 | 12 | if TYPE_CHECKING: 13 | from collections import OrderedDict # noqa: F401 14 | 15 | 16 | def find_module_instances(module: nn.Module, search_class: Type[nn.Module]) -> List[Tuple[str, nn.Module]]: 17 | """ 18 | Find all occurrences of a given search_class among the given Modules's 19 | children and return the corresponding paths in the same format as 20 | state_dicts. 21 | 22 | Usage:: 23 | 24 | net = nn.Sequential( 25 | nn.Linear(1, 1), 26 | nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}), 27 | nn.LayerNorm(1) 28 | ) 29 | 30 | >>> find_module_instances(net, nn.LayerNorm) 31 | [('1.ln.', LayerNorm((1,), eps=1e-05, elementwise_affine=True)), ('2.', LayerNorm((1,), eps=1e-05, elementwise_affine=True))] 32 | >>> find_module_instances(net, nn.Dropout) 33 | [] 34 | >>> find_module_instances(net, nn.Sequential) 35 | [('', Sequential( 36 | (0): Linear(in_features=1, out_features=1, bias=True) 37 | (1): ModuleDict( 38 | (ln): LayerNorm((1,), eps=1e-05, elementwise_affine=True) 39 | (linear): Linear(in_features=1, out_features=1, bias=True) 40 | ) 41 | (2): LayerNorm((1,), eps=1e-05, elementwise_affine=True) 42 | ))] 43 | """ 44 | paths = [] 45 | 46 | def add_paths_(module: nn.Module, prefix: str = "") -> None: 47 | if isinstance(module, search_class): 48 | paths.append((prefix, module)) 49 | for name, child in module.named_children(): 50 | add_paths_(child, prefix + name + ".") 51 | 52 | add_paths_(module) 53 | return paths 54 | 55 | 56 | def replace_by_prefix_( 57 | state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], old_prefix: str, new_prefix: str 58 | ) -> None: 59 | """ 60 | Replace all keys that match a given old_prefix with a new_prefix (in-place). 
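Raises ``ValueError`` if ``old_prefix`` and ``new_prefix`` are identical.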
61 | 62 | Usage:: 63 | 64 | state_dict = {"layer.xyz": torch.tensor(1)} 65 | replace_by_prefix_(state_dict, "layer.", "module.layer.") 66 | assert state_dict == {"module.layer.xyz": torch.tensor(1)} 67 | """ 68 | if old_prefix == new_prefix: 69 | raise ValueError("old_prefix and new_prefix must be distinct") 70 | for key in list(state_dict.keys()): 71 | if not key.startswith(old_prefix): 72 | continue 73 | new_key = new_prefix + key[len(old_prefix) :] 74 | state_dict[new_key] = state_dict[key] 75 | del state_dict[key] 76 | -------------------------------------------------------------------------------- /fairscale/internal/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | import re 8 | from typing import List, Tuple 9 | 10 | import torch 11 | 12 | __all__: List[str] = ["torch_version"] 13 | 14 | _logged = False 15 | 16 | def torch_version(version: str = torch.__version__) -> Tuple[int, ...]: 17 | numbering = re.search(r"^(\d+).(\d+).(\d+)([^\+]*)(\+\S*)?$", version) 18 | if not numbering: 19 | return tuple() 20 | # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`, 21 | global _logged 22 | if numbering.group(4) and not _logged: 23 | # Two options here: 24 | # - either skip this version (minor number check is not relevant) 25 | # - or check that our codebase is not broken by this ongoing development. 26 | 27 | # Assuming that we're interested in the second use-case more than the first, 28 | # return the pre-release or dev numbering 29 | logging.warning(f"Pytorch pre-release version {version} - assuming intent to test it") 30 | _logged = True 31 | 32 | return tuple(int(numbering.group(n)) for n in range(1, 4)) 33 | -------------------------------------------------------------------------------- /fairscale/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | import torch.distributed as dist 9 | 10 | from .checkpoint import checkpoint_wrapper 11 | from .data_parallel import FullyShardedDataParallel 12 | 13 | if dist.is_available(): 14 | # Prevent import failure if dist is not available. #1057 15 | from .data_parallel import ShardedDataParallel 16 | from .moe import MOELayer, Top2Gate 17 | from .pipe import Pipe, PipeRPCWrapper 18 | 19 | from .misc import FlattenParamsWrapper 20 | from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap 21 | 22 | __all__: List[str] = [] 23 | -------------------------------------------------------------------------------- /fairscale/nn/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from typing import List 7 | 8 | from .checkpoint_activations import checkpoint_wrapper, is_checkpointing, is_recomputing 9 | 10 | __all__: List[str] = [] 11 | -------------------------------------------------------------------------------- /fairscale/nn/checkpoint/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | from torch.nn.modules.batchnorm import _BatchNorm 11 | 12 | 13 | def patch_batchnorm(module: nn.Module) -> List: 14 | """Patch all batchnorm instances (1d, 2d, 3d, sync_bn, etc.) of a module 15 | so that they don't track running stats when torch.no_grad() is enabled. 16 | 17 | This is important in activation checkpointing to ensure stats are tracked 18 | correctly as if there were no activation checkpointing. The reason is 19 | that activation checkpointing runs the forward function twice, first 20 | under torch.no_grad(), then again with gradients enabled. 21 | 22 | Args: 23 | module (nn.Module): 24 | The module to be patched in-place. 25 | 26 | Returns: 27 | (list): 28 | A list of hook handles that can later be freed. 29 | """ 30 | 31 | def pre_forward(module: _BatchNorm, input: Tensor) -> None: 32 | if torch.is_grad_enabled(): 33 | return 34 | module._track_running_stats_backup = module.track_running_stats 35 | module.track_running_stats = False 36 | 37 | def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None: 38 | if torch.is_grad_enabled(): 39 | return 40 | module.track_running_stats = module._track_running_stats_backup 41 | 42 | hooks = [] 43 | for name, child in module.named_modules(): 44 | # _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc. 45 | if isinstance(child, _BatchNorm) and not hasattr(child, "disable_patch_batchnorm"): 46 | # Register the pre/post hooks. 47 | pre_handle = child.register_forward_pre_hook(pre_forward) 48 | post_handle = child.register_forward_hook(post_forward) 49 | hooks += [pre_handle, post_handle] 50 | return hooks 51 | -------------------------------------------------------------------------------- /fairscale/nn/data_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | import torch.distributed as dist 9 | 10 | from .fully_sharded_data_parallel import ( 11 | FullyShardedDataParallel, 12 | TrainingState, 13 | auto_wrap_bn, 14 | get_fsdp_instances, 15 | no_pre_load_state_dict_hook, 16 | ) 17 | 18 | if dist.is_available(): 19 | # Prevent import failure if dist is not available. #1057 20 | from .sharded_ddp import ShardedDataParallel 21 | 22 | __all__: List[str] = [] 23 | -------------------------------------------------------------------------------- /fairscale/nn/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | from typing import List 7 | 8 | # TODO(anj-s): Remove this once we have deprecated fairscale.nn.misc.checkpoint_wrapper path 9 | # in favor of fairscale.nn.checkpoint.checkpoint_wrapper. 10 | from fairscale.nn.checkpoint import checkpoint_wrapper 11 | 12 | from .flatten_params_wrapper import FlattenParamsWrapper, _enable_pre_load_state_dict_hook 13 | from .param_bucket import GradBucket, ParamBucket 14 | 15 | __all__: List[str] = [] 16 | -------------------------------------------------------------------------------- /fairscale/nn/model_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | from .cross_entropy import vocab_parallel_cross_entropy 9 | from .initialize import ( 10 | destroy_model_parallel, 11 | get_data_parallel_group, 12 | get_data_parallel_rank, 13 | get_data_parallel_world_size, 14 | get_model_parallel_group, 15 | get_model_parallel_rank, 16 | get_model_parallel_src_rank, 17 | get_model_parallel_world_size, 18 | get_pipeline_parallel_group, 19 | get_pipeline_parallel_ranks, 20 | initialize_model_parallel, 21 | ) 22 | from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding 23 | from .mappings import copy_to_model_parallel_region, gather_from_model_parallel_region 24 | from .random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed 25 | 26 | __all__: List[str] = [] 27 | -------------------------------------------------------------------------------- /fairscale/nn/moe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | from .moe_layer import MOELayer 9 | from .top2gate import Top2Gate 10 | 11 | __all__: List[str] = [] 12 | -------------------------------------------------------------------------------- /fairscale/nn/pipe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | """A Pipe implementation in PyTorch.""" 21 | from .async_pipe import AsyncPipe 22 | from .checkpoint import is_checkpointing, is_recomputing 23 | from .pipe import Pipe 24 | from .rpc import PipeRPCWrapper 25 | from .types import LazyModule 26 | 27 | __all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"] 28 | -------------------------------------------------------------------------------- /fairscale/nn/pipe/balance/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/fairscale/nn/pipe/balance/py.typed -------------------------------------------------------------------------------- /fairscale/nn/pipe/dependency.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Arbitrary dependency between two autograd lanes.""" 21 | from typing import List, Tuple 22 | 23 | import torch 24 | from torch import Tensor 25 | 26 | from .phony import get_phony 27 | 28 | __all__: List[str] = [] 29 | 30 | 31 | def fork(input: Tensor) -> Tuple[Tensor, Tensor]: 32 | """Branches out from an autograd lane of the given tensor.""" 33 | if torch.is_grad_enabled() and input.requires_grad: 34 | input, phony = Fork.apply(input) 35 | else: 36 | phony = get_phony(input.device, requires_grad=False) 37 | 38 | return input, phony 39 | 40 | 41 | class Fork(torch.autograd.Function): 42 | @staticmethod 43 | def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore 44 | phony = get_phony(input.device, requires_grad=False) 45 | return input.detach(), phony.detach() 46 | 47 | @staticmethod 48 | def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore 49 | return grad_input 50 | 51 | 52 | def join(input: Tensor, phony: Tensor) -> Tensor: 53 | """Merges two autograd lanes.""" 54 | if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): 55 | input = Join.apply(input, phony) 56 | 57 | return input 58 | 59 | 60 | class Join(torch.autograd.Function): 61 | @staticmethod 62 | def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore 63 | return input.detach() 64 | 65 | @staticmethod 66 | def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore 67 | return grad_input, None 68 | -------------------------------------------------------------------------------- /fairscale/nn/pipe/phony.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Provides a phony for arbitrary dependencies in an autograd graph.""" 21 | from typing import Dict, List, Tuple 22 | 23 | import torch 24 | from torch import Tensor 25 | 26 | from .stream import default_stream, use_stream 27 | 28 | __all__: List[str] = [] 29 | 30 | 31 | _phonies: Dict[Tuple[torch.device, bool], Tensor] = {} 32 | 33 | 34 | def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: 35 | """Gets a phony. A phony is a tensor that takes up no space. It is useful for making 36 | an arbitrary dependency in an autograd graph because it doesn't require any 37 | gradient accumulation. 38 | 39 | .. note:: 40 | 41 | Phonies for each device are cached. If an autograd function gets a phony 42 | internally, the phony must be detached to be returned. Otherwise, the 43 | autograd engine will mutate the cached phony in-place:: 44 | 45 | class Phonify(torch.autograd.Function): 46 | @staticmethod 47 | def forward(ctx, input): 48 | phony = get_phony(input.device, requires_grad=False) 49 | return phony.detach() # detach() is necessary. 50 | 51 | """ 52 | key = (device, requires_grad) 53 | 54 | try: 55 | phony = _phonies[key] 56 | except KeyError: 57 | with use_stream(default_stream(device)): 58 | # Creating phony with size 1 instead of zero, since currently 59 | # tensorpipe does not work with tensors of size zero. 60 | phony = torch.empty(1, device=device, requires_grad=requires_grad) 61 | 62 | _phonies[key] = phony 63 | 64 | return phony 65 | -------------------------------------------------------------------------------- /fairscale/nn/pipe/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/fairscale/nn/pipe/py.typed -------------------------------------------------------------------------------- /fairscale/nn/pipe/skip/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Supports efficiency with skip connections.""" 21 | from .namespace import Namespace 22 | from .skippable import pop, skippable, stash, verify_skippables 23 | 24 | __all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] 25 | -------------------------------------------------------------------------------- /fairscale/nn/pipe/skip/namespace.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Provides isolated namespace of skip tensors.""" 21 | import abc 22 | from functools import total_ordering 23 | from typing import Any 24 | import uuid 25 | 26 | __all__ = ["Namespace"] 27 | 28 | 29 | @total_ordering 30 | class Namespace(metaclass=abc.ABCMeta): 31 | """Namespace for isolating skip tensors used by 32 | :meth:`isolate`. 33 | """ 34 | 35 | __slots__ = ("id",) 36 | 37 | def __init__(self) -> None: 38 | self.id = uuid.uuid4() 39 | 40 | def __repr__(self) -> str: 41 | return f"<Namespace {self.id}>" 42 | 43 | def __hash__(self) -> int: 44 | return hash(self.id) 45 | 46 | # Namespaces should support ordering, since SkipLayout will sort tuples 47 | # including a namespace. But the actual order between namespaces is not 48 | # important. That's why they are ordered by version-4 UUIDs, which are 49 | # effectively random. 50 | def __lt__(self, other: Any) -> bool: 51 | if isinstance(other, Namespace): 52 | return self.id < other.id 53 | return False 54 | 55 | def __eq__(self, other: Any) -> bool: 56 | if isinstance(other, Namespace): 57 | return self.id == other.id 58 | return False 59 | 60 | 61 | # 'None' is the default namespace, 62 | # which means that 'isinstance(None, Namespace)' is 'True'. 63 | Namespace.register(type(None)) 64 | -------------------------------------------------------------------------------- /fairscale/nn/pipe/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from dataclasses import dataclass 7 | from typing import Any, Callable, List, Optional, Tuple, Union 8 | 9 | import torch 10 | from torch import Tensor, nn 11 | 12 | ACTIVATIONS_GRADS_QUEUE = 0 13 | SKIP_TENSOR_QUEUE = 1 14 | PORTAL_QUEUE = 2 15 | EVENT_LOOP_QUEUE = 3 16 | EVENT_LOOP_ACTIVATIONS_QUEUE = 4 17 | EVENT_LOOP_GRADIENTS_QUEUE = 5 18 | MESSAGE_GENERATION_START = 6 19 | 20 | MessageGeneration = MESSAGE_GENERATION_START 21 | 22 | Tensors = Tuple[Tensor, ...]
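# Payload aliases used throughout the pipe code: messages between partitions carry either a single tensor or a tuple of tensors.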
23 | TensorOrTensors = Union[Tensor, Tensors] 24 | 25 | InputDevice = Union[None, int, str, torch.device] 26 | 27 | 28 | class LazyModule: 29 | def __init__(self, function: Callable[[], nn.Module]): 30 | self.function = function 31 | 32 | def __call__(self) -> nn.Module: 33 | return self.function() 34 | 35 | 36 | @dataclass(init=False) 37 | class PipeMessage: 38 | src: int 39 | dest: int 40 | queue_name: int 41 | args: Any 42 | tensors: Tensors 43 | tensor_shapes: List[torch.Size] 44 | tensor_dtypes: List[torch.dtype] 45 | tag: int = 0 46 | 47 | def __init__( 48 | self, 49 | src: int, 50 | dest: int, 51 | queue_name: int, 52 | args: Any = None, 53 | tensors: Optional[Tensors] = None, 54 | tensor_count: int = 0, 55 | ): 56 | self.src = src 57 | self.dest = dest 58 | self.queue_name = queue_name 59 | self.args = args 60 | self.tensors = tensors or tuple() 61 | self.tensor_shapes = [] 62 | self.tensor_dtypes = [] 63 | 64 | global MessageGeneration 65 | self.tag = MessageGeneration 66 | if tensors is None: 67 | MessageGeneration += tensor_count 68 | else: 69 | MessageGeneration += len(self.tensors) 70 | -------------------------------------------------------------------------------- /fairscale/nn/wrap/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List 7 | 8 | from .auto_wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap 9 | 10 | __all__: List[str] = [] 11 | -------------------------------------------------------------------------------- /fairscale/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | :mod:`fairscale.optim` is a package implementing various torch optimization algorithms. 
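The package re-exports :class:`OSS` and :class:`AdaScale` unconditionally, and :class:`Adam` / :class:`GradScaler` only when their platform-specific dependencies are available.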
8 | """ 9 | import logging 10 | from typing import List 11 | 12 | from .adascale import AdaScale, AdaScaleWrapper 13 | from .oss import OSS 14 | 15 | try: 16 | from .adam import Adam, Precision 17 | except ImportError: # pragma: no cover 18 | pass # pragma: no cover 19 | try: 20 | from .grad_scaler import GradScaler 21 | except ImportError: 22 | logging.warning("Torch AMP is not available on this platform") 23 | 24 | __all__: List[str] = [] 25 | -------------------------------------------------------------------------------- /fairscale/version.py: -------------------------------------------------------------------------------- 1 | __version_tuple__ = (0, 4, 13) 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 40.6.2", 4 | "wheel >= 0.30.0" 5 | ] 6 | build-backend = "setuptools.build_meta:__legacy__" 7 | 8 | [tool.black] 9 | line-length = 120 10 | exclude = ''' 11 | /( 12 | \.git 13 | | \.mypy_cache 14 | | \.pytest_cache 15 | | build 16 | | docs 17 | | stubs 18 | )/ 19 | ''' 20 | 21 | [tool.isort] 22 | line_length = 120 23 | multi_line_output = 3 24 | include_trailing_comma = true 25 | force_grid_wrap = 0 26 | use_parentheses = true 27 | skip_glob = ["build/*", "stubs/*"] 28 | # Don't split "import" and "from". 29 | force_sort_within_sections = true 30 | -------------------------------------------------------------------------------- /release_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from typing import Tuple 4 | 5 | from setup import find_version 6 | 7 | 8 | def get_next_version(release_type) -> Tuple[Tuple[int, int, int], str, str]: 9 | current_ver = find_version("fairscale/version.py") 10 | version_list = [int(x) for x in current_ver.strip("'").split(".")] 11 | major, minor, patch = version_list[0], version_list[1], version_list[2] 12 | if release_type == "patch": 13 | patch += 1 14 | elif release_type == "minor": 15 | minor += 1 16 | patch = 0 17 | elif release_type == "major": 18 | major += 1 19 | minor = patch = 0 20 | else: 21 | raise ValueError("Incorrect release type specified. Acceptable types are major, minor and patch.") 22 | 23 | new_version_tuple = (major, minor, patch) 24 | new_version_str = ".".join([str(x) for x in new_version_tuple]) 25 | new_tag_str = "v" + new_version_str 26 | return new_version_tuple, new_version_str, new_tag_str 27 | 28 | 29 | def update_version(new_version_tuple) -> None: 30 | """ 31 | given the current version, update the version to the 32 | next version depending on the type of release. 
33 | """ 34 | 35 | with open("fairscale/version.py", "r") as reader: 36 | current_version_data = reader.read() 37 | 38 | # for line in current_version_data: 39 | version_match = re.search(r"^__version_tuple__ ", current_version_data) 40 | 41 | if version_match: 42 | new_version_data = "__version_tuple__ = %s\n" % str(new_version_tuple) 43 | current_version_data = current_version_data.replace(version_match.string, new_version_data) 44 | 45 | with open("fairscale/version.py", "w") as writer: 46 | writer.write(current_version_data) 47 | else: 48 | raise RuntimeError("__version_tuple__ not found in version.py") 49 | 50 | 51 | def main(args): 52 | if args.release_type in ["major", "minor", "patch"]: 53 | new_version_tuple, new_version, new_tag = get_next_version(args.release_type) 54 | else: 55 | raise ValueError("Incorrect release type specified") 56 | 57 | if args.update_version: 58 | update_version(new_version_tuple) 59 | 60 | print(new_version, new_tag) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser(description="Versioning utils") 65 | parser.add_argument("--release-type", type=str, required=True, help="type of release = major/minor/patch") 66 | parser.add_argument( 67 | "--update-version", action="store_true", required=False, help="updates the version in fairscale/version.py" 68 | ) 69 | 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /requirements-benchmarks.txt: -------------------------------------------------------------------------------- 1 | # Bring in everything that tests depends on. 2 | -r requirements-dev.txt 3 | 4 | # Benchmark dependencies. 5 | torchtext == 0.6.0 6 | torchvision >= 0.6.0 7 | timm == 0.3.4 8 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Core deps. 2 | -r requirements.txt 3 | 4 | # Tools for static checking. 5 | # - flake8-annotations is needed to avoid F811 error with overload 6 | # function typing with mypy. 7 | # - if you change versions below, please make sure it is in-sync with 8 | # .pre-commit-config.yaml for pre-commit. 9 | black == 22.3.0 10 | flake8 == 4.0.1 11 | flake8-annotations == 2.7.0 12 | isort == 5.10.1 13 | mypy == 0.910 14 | pre-commit >= 2.15.0 15 | 16 | # Tools for unit tests & coverage. 17 | pytest == 7.0.0 18 | pytest-cov == 3.0.0 19 | pytest-timeout == 2.1.0 20 | remote-pdb >= 2.1.0 21 | parameterized >= 0.8.1 22 | 23 | # Tools for testing docs 24 | docutils == 0.17 25 | 26 | # For torch.cuda.list_gpu_processes() 27 | pynvml == 8.0.4 28 | 29 | # For mypy typing. It is important to have a fixed version. Otherwise, you 30 | # may run into mypy errors out differently for different versions. 31 | numpy == 1.22.0 32 | 33 | # For layerwise gradient scaler 34 | scikit-learn == 1.1.3 35 | 36 | # For weigit. These are actually user requirements, not developer requirements. 37 | # However, due to the experimental nature of weigit, we don't expose to the 38 | # general users of fairscale yet. We check for them in weigit's init code. 39 | pygit2==1.11.1 40 | pgzip==0.3.1 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # FairScale should only depends on torch, not things higher level than torch. 
2 | # Note1: setup.py automatically reads this file to setup install-time dependencies. 3 | # Note2: we use >= in this file but == in requirements-dev.txt for determinism 4 | # in testing. 5 | # Note3: update docs/requirements.txt if you change this file. 6 | torch >= 1.8.0 7 | numpy >= 1.22.0 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # pytest 3 | # ----------------------------------------------------------------------------- 4 | 5 | [tool:pytest] 6 | testpaths = tests 7 | addopts = --verbose 8 | junit_family = xunit2 9 | 10 | [aliases] 11 | test = pytest 12 | 13 | # ----------------------------------------------------------------------------- 14 | # coverage 15 | # ----------------------------------------------------------------------------- 16 | 17 | [coverage:report] 18 | # Coverage couldn't detect backward functions because they are called by C++. 19 | # Append "# pragma: no cover" to the definition lines to ignore them. 20 | # https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ 21 | exclude_lines = pragma: no cover 22 | 23 | # ----------------------------------------------------------------------------- 24 | # flake8 25 | # ----------------------------------------------------------------------------- 26 | 27 | [flake8] 28 | select = B,C,E,F,P,T4,W,B9 29 | max-line-length = 120 30 | # C408 ignored because we like the dict keyword argument syntax 31 | # E501 is not flexible enough, we're using B950 instead 32 | ignore = 33 | E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 34 | per-file-ignores = __init__.py: F401 35 | exclude = build,*.pyi,.git 36 | 37 | # ----------------------------------------------------------------------------- 38 | # mypy 39 | # ----------------------------------------------------------------------------- 40 | 41 | # Docs for mypy config: https://mypy.readthedocs.io/en/latest/config_file.html 42 | [mypy] 43 | mypy_path = ./stubs/ 44 | follow_imports = normal 45 | plugins = numpy.typing.mypy_plugin 46 | 47 | # This project must be strictly typed. 48 | [mypy-fairscale.*] 49 | check_untyped_defs = true 50 | disallow_untyped_defs = true 51 | disallow_untyped_calls = true 52 | disallow_untyped_decorators = true 53 | disallow_incomplete_defs = true 54 | warn_unused_ignores = true 55 | 56 | [mypy-fairscale.experimental.nn.distributed_pipeline.trace] 57 | ignore_errors = True 58 | 59 | [mypy-fairscale.experimental.nn.auto_shard] 60 | ignore_errors = True 61 | 62 | [mypy-benchmarks.*] 63 | ignore_errors = True 64 | 65 | # Ignore missing imports from untyped third-party libraries. 66 | [mypy-torch.*,torchvision.*,setuptools.*,pytest.*] 67 | ignore_missing_imports = true 68 | -------------------------------------------------------------------------------- /stubs/torch/autograd/grad_mode.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Callable, Optional, TypeVar 4 | 5 | # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'. 6 | # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators 7 | FuncType = Callable[..., Any] 8 | T = TypeVar('T', bound=FuncType) 9 | 10 | class no_grad: 11 | def __enter__(self) -> None: ... 
12 | def __exit__(self, *args: Any) -> Optional[bool]: ... 13 | def __call__(self, func: T) -> T: ... 14 | 15 | class enable_grad: 16 | def __enter__(self) -> None: ... 17 | def __exit__(self, *args: Any) -> Optional[bool]: ... 18 | def __call__(self, func: T) -> T: ... 19 | 20 | class set_grad_enabled: 21 | def __init__(self, mode: bool) -> None: ... 22 | def __enter__(self) -> None: ... 23 | def __exit__(self, *args: Any) -> Optional[bool]: ... 24 | -------------------------------------------------------------------------------- /stubs/torch/autograd/profiler.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, ContextManager, Optional 4 | 5 | class record_function(ContextManager[None]): 6 | def __init__(self, name: str) -> None: ... 7 | def __enter__(self) -> None: ... 8 | def __exit__(self, *args: Any) -> Optional[bool]: ... 9 | -------------------------------------------------------------------------------- /stubs/torch/backends/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #MODIFIED BY TORCHGPIPE 4 | from . import cudnn 5 | #END 6 | -------------------------------------------------------------------------------- /stubs/torch/backends/cudnn.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #MODIFIED BY TORCHGPIPE 4 | def version() -> int: ... 5 | #END 6 | deterministic: bool 7 | benchmark: bool 8 | 9 | -------------------------------------------------------------------------------- /stubs/torch/cuda/amp/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any 4 | 5 | from .grad_scaler import GradScaler as GradScaler 6 | 7 | class autocast: 8 | def __init__(self, enabled: bool = True) -> None: ... 9 | def __enter__(self) -> None: ... 10 | def __exit__(self, *args: Any) -> None: ... 11 | -------------------------------------------------------------------------------- /stubs/torch/cuda/amp/grad_scaler.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ...optim import Optimizer 4 | from ... import device, Tensor 5 | from typing import Dict, Any, Optional 6 | 7 | class GradScaler(object): 8 | _scale: Optional[Tensor] 9 | _growth_tracker: Optional[Tensor] 10 | _per_optimizer_states: Dict[int, Dict[str, Any]] 11 | 12 | def __init__(self, init_scale: float, growth_factor: float, backoff_factor: float, growth_interval: int, enabled: bool): ... 13 | def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]: ... 14 | def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ... 15 | def update(self, new_scale: Optional[float] = None) -> None: ... 16 | def unscale_(self, optimizer: Optimizer) -> None: ... 17 | -------------------------------------------------------------------------------- /stubs/torch/cuda/comm/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved. 2 | 3 | #MODIFIED BY TORCHGPIPE 4 | from typing import Iterable, Optional, Tuple 5 | 6 | from torch import Tensor 7 | 8 | 9 | def scatter(tensor: Tensor, 10 | devices: Iterable[int], 11 | chunk_sizes: Optional[Iterable[int]] = None, 12 | dim: int = 0, 13 | ) -> Tuple[Tensor, ...]: ... 14 | 15 | 16 | def gather(tensors: Iterable[Tensor], 17 | dim: int = 0, 18 | destination: Optional[int] = None, 19 | ) -> Tensor: ... 20 | 21 | 22 | def broadcast_coalesced(tensors: Iterable[Tensor], 23 | devices: Iterable[int], 24 | buffer_size: int = 10485760, 25 | ) -> Tuple[Tensor, ...]: ... 26 | 27 | 28 | def reduce_add_coalesced(inputs: Iterable[Iterable[Tensor]], 29 | destination: Optional[int] = None, 30 | buffer_size: int = 10485760, 31 | ) -> Tuple[Tensor, ...]: ... 32 | 33 | #END 34 | -------------------------------------------------------------------------------- /stubs/torch/distributed/distributed_c10d.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, List, Union, Optional 4 | 5 | from . import ProcessGroup 6 | 7 | def _get_global_rank(group: ProcessGroup, rank: int) -> int: ... 8 | 9 | def _get_default_group() -> ProcessGroup: ... -------------------------------------------------------------------------------- /stubs/torch/distributed/nn/functional.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Optional 4 | from torch import Tensor 5 | from torch.distributed import ProcessGroup, ReduceOp 6 | 7 | def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None): ... 8 | -------------------------------------------------------------------------------- /stubs/torch/distributed/rpc/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Union, Callable, Optional, Any 4 | from torch.futures import Future 5 | 6 | class RRef: 7 | def __init__(self, t: Any) -> None: ... 8 | def local_value(self) -> Any: ... 9 | def owner(self) -> WorkerInfo: ... 10 | def remote(self) -> Any: ... 11 | def rpc_sync(self) -> Any: ... 12 | def to_here(self) -> Any: ... 13 | class WorkerInfo: ... 14 | 15 | class BackendType: 16 | TENSORPIPE: Any 17 | PROCESS_GROUP: Any 18 | 19 | def TensorPipeRpcBackendOptions(init_method: str) -> Any: ... 20 | def ProcessGroupRpcBackendOptions(init_method: str) -> Any: ... 21 | def remote( 22 | to: Union[str, WorkerInfo], 23 | func: Callable, 24 | args: Optional[tuple] = None, 25 | kwargs: Optional[dict] = None, 26 | timeout=-1.0, 27 | ) -> RRef: ... 28 | def rpc_async( 29 | to: Union[str, WorkerInfo], 30 | func: Callable, 31 | args: Optional[tuple] = None, 32 | kwargs: Optional[dict] = None, 33 | timeout=-1.0, 34 | ) -> Future: ... 35 | def rpc_sync( 36 | to: Union[str, WorkerInfo], 37 | func: Callable, 38 | args: Optional[tuple] = None, 39 | kwargs: Optional[dict] = None, 40 | timeout=-1.0, 41 | ) -> Any: ... 42 | def init_rpc( 43 | name: str, 44 | backend: Optional[Any] = None, 45 | rank: int = -1, 46 | world_size: Optional[int] = None, 47 | rpc_backend_options: Optional[Any] = None, 48 | ) -> None: ... 49 | def shutdown(graceful: Optional[bool] = True) -> None: ... 
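# Illustrative sketch (comments only, not part of the stub): the entry points
# above are typically driven as follows; the worker names and tensor ``x`` are
# placeholders.
#
#   rpc.init_rpc("worker0", rank=0, world_size=2)
#   rref = rpc.remote("worker1", torch.add, args=(x, 1))  # returns an RRef
#   y = rref.to_here()                                    # fetch the result
#   rpc.shutdown()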
50 | -------------------------------------------------------------------------------- /stubs/torch/fft/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Optional 4 | from torch import Tensor 5 | 6 | # See https://github.com/python/mypy/issues/4146 for why these workarounds 7 | # are necessary 8 | #_int = builtins.int 9 | #_float = builtins.float 10 | #_bool = builtins.bool 11 | #_size = Union[Size, List[int], Tuple[int, ...]] 12 | 13 | 14 | def fft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ... 15 | def ifft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ... 16 | def rfft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ... 17 | def irfft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ... 18 | -------------------------------------------------------------------------------- /stubs/torch/functional.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from . import Tensor 4 | from typing import Tuple, List, Union, Optional, Any 5 | 6 | 7 | def split(tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int=0) -> Tuple[Tensor,...]: ... 8 | 9 | def einsum(equation: str, *operands: Tensor) -> Tensor: ... 10 | 11 | def norm(input: Tensor, p: Union[int, float, Any], dim: Optional[List[int]]=None, keepdim: Optional[bool]=False, out: Optional[Tensor]=None, dtype: Optional[int]=None) -> Tensor: ... 12 | -------------------------------------------------------------------------------- /stubs/torch/futures.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any 4 | 5 | class Future: 6 | def wait(self) -> Any: ... 7 | -------------------------------------------------------------------------------- /stubs/torch/jit.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Callable 4 | 5 | def script(fn: Callable) -> Callable: ... 6 | -------------------------------------------------------------------------------- /stubs/torch/multiprocessing/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Callable, Optional, Tuple 4 | 5 | from torch import Tensor 6 | 7 | def spawn( 8 | fn: Callable[..., Any], 9 | args: Tuple[Optional[Any], ...] = (), 10 | nprocs: int = 1, 11 | join: bool = True, 12 | daemon: bool = False, 13 | start_method: str = "spawn", 14 | ): ... 15 | 16 | -------------------------------------------------------------------------------- /stubs/torch/nn/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .modules import * 4 | from .parameter import Parameter as Parameter 5 | from .parallel import DataParallel as DataParallel 6 | from .
import functional as functional 7 | -------------------------------------------------------------------------------- /stubs/torch/nn/common_types.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import TypeVar, Union, Tuple 4 | from .. import Tensor 5 | 6 | # Create some useful type aliases 7 | 8 | # Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally 9 | # broadcast to a tuple. 10 | # Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. 11 | T = TypeVar('T') 12 | _scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] 13 | _scalar_or_tuple_1_t = Union[T, Tuple[T]] 14 | _scalar_or_tuple_2_t = Union[T, Tuple[T, T]] 15 | _scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] 16 | _scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] 17 | _scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] 18 | _scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] 19 | 20 | # For arguments which represent size parameters (e.g., kernel size, padding) 21 | _size_any_t = _scalar_or_tuple_any_t[int] 22 | _size_1_t = _scalar_or_tuple_1_t[int] 23 | _size_2_t = _scalar_or_tuple_2_t[int] 24 | _size_3_t = _scalar_or_tuple_3_t[int] 25 | _size_4_t = _scalar_or_tuple_4_t[int] 26 | _size_5_t = _scalar_or_tuple_5_t[int] 27 | _size_6_t = _scalar_or_tuple_6_t[int] 28 | 29 | # For arguments that represent a ratio to adjust each dimension of an input with (e.g., upsampling parameters) 30 | _ratio_2_t = _scalar_or_tuple_2_t[float] 31 | _ratio_3_t = _scalar_or_tuple_3_t[float] 32 | _ratio_any_t = _scalar_or_tuple_any_t[float] 33 | 34 | _tensor_list_t = _scalar_or_tuple_any_t[Tensor] 35 | 36 | # For the return value of max pooling operations that may or may not return indices. 37 | # With the 'Literal' feature proposed for Python typing, it might be possible to 38 | # eventually eliminate this. 39 | _maybe_indices_t = _scalar_or_tuple_2_t[Tensor] 40 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/adaptive.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ... import Tensor 4 | from .module import Module 5 | from .linear import Linear 6 | from collections import namedtuple 7 | from typing import List, Sequence 8 | from .container import ModuleList 9 | 10 | _ASMoutput = namedtuple('_ASMoutput', ['output', 'loss']) 11 | 12 | 13 | class AdaptiveLogSoftmaxWithLoss(Module): 14 | in_features: int = ... 15 | n_classes: int = ... 16 | cutoffs: List[int] = ... 17 | div_value: float = ... 18 | head_bias: bool = ... 19 | head: Linear = ... 20 | tail: ModuleList = ... 21 | 22 | def __init__(self, in_features: int, n_classes: int, cutoffs: Sequence[int], div_value: float = ..., 23 | head_bias: bool = ...) -> None: ... 24 | 25 | def reset_parameters(self) -> None: ... 26 | 27 | def forward(self, input: Tensor, target: Tensor) -> _ASMoutput: ... # type: ignore 28 | 29 | def __call__(self, input: Tensor, target: Tensor) -> _ASMoutput: ... # type: ignore 30 | 31 | def log_prob(self, input: Tensor) -> Tensor: ... 32 | 33 | def predict(self, input: Tensor) -> Tensor: ...
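    # Illustrative sketch (comments only, not part of the stub): typical usage,
    # with hypothetical sizes.
    #
    #   asm = AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=1000,
    #                                    cutoffs=[10, 100, 500])
    #   output, loss = asm(hidden, target)  # hidden: (N, 64), target: (N,)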
34 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/batchnorm.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ... import Tensor 4 | from .. import Parameter 5 | from .module import Module 6 | from typing import Any, Optional 7 | 8 | 9 | class _BatchNorm(Module): 10 | num_features: int = ... 11 | eps: float = ... 12 | momentum: float = ... 13 | affine: bool = ... 14 | track_running_stats: bool = ... 15 | weight: Parameter = ... 16 | bias: Parameter = ... 17 | 18 | # This field is used by fairscale.nn.misc.misc::patch_batchnorm 19 | _track_running_stats_backup: bool 20 | 21 | #MODIFIED BY TORCHGPIPE 22 | running_mean: Tensor 23 | running_var: Tensor 24 | num_batches_tracked: Tensor 25 | 26 | def __init__(self, num_features: int, eps: float = ..., momentum: Optional[float] = ..., affine: bool = ..., 27 | track_running_stats: bool = ...) -> None: ... 28 | #END 29 | 30 | def reset_running_stats(self) -> None: ... 31 | 32 | def reset_parameters(self) -> None: ... 33 | 34 | 35 | class BatchNorm1d(_BatchNorm): ... 36 | 37 | 38 | class BatchNorm2d(_BatchNorm): ... 39 | 40 | 41 | class BatchNorm3d(_BatchNorm): ... 42 | 43 | 44 | class SyncBatchNorm(_BatchNorm): 45 | # TODO set process_group to the right type once torch.distributed is stubbed 46 | def __init__(self, num_features: int, eps: float = ..., momentum: float = ..., affine: bool = ..., 47 | track_running_stats: bool = ..., process_group: Optional[Any] = ...) -> None: ... 48 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/distance.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ... import Tensor 4 | from .module import Module 5 | 6 | 7 | class PairwiseDistance(Module): 8 | norm: float 9 | eps: float 10 | keepdim: bool 11 | 12 | def __init__(self, p: float = ..., eps: float = ..., keepdim: bool = ...) -> None: ... 13 | 14 | def forward(self, x1: Tensor, x2: Tensor) -> Tensor: ... # type: ignore 15 | 16 | def __call__(self, x1: Tensor, x2: Tensor) -> Tensor: ... # type: ignore 17 | 18 | 19 | class CosineSimilarity(Module): 20 | dim: int 21 | eps: float 22 | 23 | def __init__(self, dim: int = ..., eps: float = ...) -> None: ... 24 | 25 | def forward(self, x1: Tensor, x2: Tensor) -> Tensor: ... # type: ignore 26 | 27 | def __call__(self, x1: Tensor, x2: Tensor) -> Tensor: ... # type: ignore 28 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/dropout.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ... import Tensor 4 | from .module import Module 5 | 6 | 7 | class _DropoutNd(Module): 8 | p: float 9 | inplace: bool 10 | 11 | def __init__(self, p: float = ..., inplace: bool = ...) -> None: ... 12 | 13 | def extra_repr(self): ... 14 | 15 | 16 | class Dropout(_DropoutNd): 17 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 18 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 19 | 20 | 21 | class Dropout2d(_DropoutNd): 22 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 23 | def __call__(self, input: Tensor) -> Tensor: ...
# type: ignore 24 | 25 | 26 | class Dropout3d(_DropoutNd): 27 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 28 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 29 | 30 | 31 | class AlphaDropout(_DropoutNd): 32 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 33 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 34 | 35 | 36 | class FeatureAlphaDropout(_DropoutNd): 37 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 38 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 39 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/flatten.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any 4 | from .module import Module 5 | 6 | class Flatten(Module): 7 | __constants__: Any = ... 8 | start_dim: Any = ... 9 | end_dim: Any = ... 10 | def __init__(self, start_dim: int = ..., end_dim: int = ...) -> None: ... 11 | def forward(self, input: Any): ... # type: ignore 12 | def __call__(self, input: Any): ... # type: ignore 13 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/fold.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .module import Module 4 | from ... import Tensor 5 | from ..common_types import _size_any_t 6 | 7 | 8 | class Fold(Module): 9 | output_size: _size_any_t = ... 10 | kernel_size: _size_any_t = ... 11 | dilation: _size_any_t = ... 12 | padding: _size_any_t = ... 13 | stride: _size_any_t = ... 14 | 15 | def __init__(self, output_size: _size_any_t, kernel_size: _size_any_t, dilation: _size_any_t = ..., 16 | padding: _size_any_t = ..., stride: _size_any_t = ...) -> None: ... 17 | 18 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 19 | 20 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 21 | 22 | 23 | class Unfold(Module): 24 | kernel_size: _size_any_t = ... 25 | dilation: _size_any_t = ... 26 | padding: _size_any_t = ... 27 | stride: _size_any_t = ... 28 | 29 | def __init__(self, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., 30 | stride: _size_any_t = ...) -> None: ... 31 | 32 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 33 | 34 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 35 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/instancenorm.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ... import Tensor 4 | from .batchnorm import _BatchNorm 5 | 6 | 7 | class _InstanceNorm(_BatchNorm): 8 | def __init__(self, num_features: int, eps: float = ..., momentum: float = ..., affine: bool = ..., 9 | track_running_stats: bool = ...) -> None: ... 10 | 11 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 12 | 13 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 14 | 15 | 16 | class InstanceNorm1d(_InstanceNorm): ... 17 | 18 | 19 | class InstanceNorm2d(_InstanceNorm): ... 20 | 21 | 22 | class InstanceNorm3d(_InstanceNorm): ... 
23 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/linear.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .module import Module 4 | from .. import Parameter 5 | from ... import Tensor 6 | 7 | import torch 8 | from typing import Union 9 | 10 | 11 | class Identity(Module): 12 | 13 | def __init__(self) -> None: ... 14 | 15 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 16 | 17 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 18 | 19 | 20 | class Linear(Module): 21 | in_features: int = ... 22 | out_features: int = ... 23 | weight: Parameter = ... 24 | bias: Parameter = ... 25 | 26 | def __init__(self, in_features: int, out_features: int, bias: bool = ..., device:str = ..., dtype:Union[str, torch.dtype] = ...) -> None: ... 27 | 28 | def reset_parameters(self) -> None: ... 29 | 30 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 31 | 32 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 33 | 34 | 35 | class Bilinear(Module): 36 | in1_features: int = ... 37 | in2_features: int = ... 38 | out_features: int = ... 39 | weight: Parameter = ... 40 | bias: Parameter = ... 41 | 42 | def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = ...) -> None: ... 43 | 44 | def reset_parameters(self) -> None: ... 45 | 46 | def forward(self, input1: Tensor, input2: Tensor) -> Tensor: ... # type: ignore 47 | 48 | def __call__(self, input1: Tensor, input2: Tensor) -> Tensor: ... # type: ignore 49 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/normalization.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .module import Module 4 | from typing import Any, Union, List 5 | from ... import Tensor, Size 6 | from .. import Parameter 7 | 8 | 9 | class LocalResponseNorm(Module): 10 | size: int = ... 11 | alpha: float = ... 12 | beta: float = ... 13 | k: float = ... 14 | 15 | def __init__(self, size: int, alpha: float = ..., beta: float = ..., k: float = ...) -> None: ... 16 | 17 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 18 | 19 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 20 | 21 | 22 | class CrossMapLRN2d(Module): 23 | size: int = ... 24 | alpha: float = ... 25 | beta: float = ... 26 | k: float = ... 27 | 28 | def __init__(self, size: int, alpha: float = ..., beta: float = ..., k: float = ...) -> None: ... 29 | 30 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 31 | 32 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 33 | 34 | 35 | _shape_t = Union[int, List[int], Size] 36 | 37 | 38 | class LayerNorm(Module): 39 | normalized_shape: _shape_t = ... 40 | eps: float = ... 41 | elementwise_affine: bool = ... 42 | weight: Parameter = ... 43 | bias: Parameter = ... 44 | 45 | def __init__(self, normalized_shape: _shape_t, eps: float = ..., elementwise_affine: bool = ...) -> None: ... 46 | 47 | def reset_parameters(self) -> None: ... 48 | 49 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 50 | 51 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 52 | 53 | 54 | class GroupNorm(Module): 55 | num_groups: int = ... 56 | num_channels: int = ... 57 | eps: float = ... 
58 | affine: bool = ... 59 | weight: Parameter = ... 60 | bias: Parameter = ... 61 | 62 | def __init__(self, num_groups: int, num_channels: int, eps: float = ..., affine: bool = ...) -> None: ... 63 | 64 | def reset_parameters(self) -> None: ... 65 | 66 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 67 | 68 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 69 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/padding.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .module import Module 4 | from ... import Tensor 5 | from ..common_types import _size_2_t, _size_4_t, _size_6_t 6 | 7 | 8 | class _ConstantPadNd(Module): 9 | value: float 10 | 11 | def __init__(self, value: float) -> None: ... 12 | 13 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 14 | 15 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 16 | 17 | 18 | class ConstantPad1d(_ConstantPadNd): 19 | padding: _size_2_t = ... 20 | 21 | def __init__(self, padding: _size_2_t, value: float) -> None: ... 22 | 23 | 24 | class ConstantPad2d(_ConstantPadNd): 25 | padding: _size_4_t = ... 26 | 27 | def __init__(self, padding: _size_4_t, value: float) -> None: ... 28 | 29 | 30 | class ConstantPad3d(_ConstantPadNd): 31 | padding: _size_6_t = ... 32 | 33 | def __init__(self, padding: _size_6_t, value: float) -> None: ... 34 | 35 | 36 | class _ReflectionPadNd(Module): 37 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 38 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 39 | 40 | def extra_repr(self): ... 41 | 42 | 43 | class ReflectionPad1d(_ReflectionPadNd): 44 | padding: _size_2_t = ... 45 | 46 | def __init__(self, padding: _size_2_t) -> None: ... 47 | 48 | 49 | class ReflectionPad2d(_ReflectionPadNd): 50 | padding: _size_4_t = ... 51 | 52 | def __init__(self, padding: _size_4_t) -> None: ... 53 | 54 | 55 | class _ReplicationPadNd(Module): 56 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 57 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 58 | 59 | def extra_repr(self): ... 60 | 61 | 62 | class ReplicationPad1d(_ReplicationPadNd): 63 | padding: _size_2_t = ... 64 | 65 | def __init__(self, padding: _size_2_t) -> None: ... 66 | 67 | 68 | class ReplicationPad2d(_ReplicationPadNd): 69 | padding: _size_4_t = ... 70 | 71 | def __init__(self, padding: _size_4_t) -> None: ... 72 | 73 | 74 | class ReplicationPad3d(_ReplicationPadNd): 75 | padding: _size_6_t = ... 76 | 77 | def __init__(self, padding: _size_6_t) -> None: ... 78 | 79 | 80 | class ZeroPad2d(ConstantPad2d): 81 | padding: _size_4_t = ... 82 | 83 | def __init__(self, padding: _size_4_t) -> None: ... 84 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/pixelshuffle.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .module import Module 4 | from ... import Tensor 5 | 6 | 7 | class PixelShuffle(Module): 8 | upscale_factor: int = ... 9 | 10 | def __init__(self, upscale_factor: int) -> None: ... 11 | 12 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 13 | 14 | def __call__(self, input: Tensor) -> Tensor: ... 
# type: ignore 15 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/sparse.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .module import Module 4 | from typing import Optional 5 | from .. import Parameter 6 | from ... import Tensor 7 | 8 | 9 | class Embedding(Module): 10 | num_embeddings: int = ... 11 | embedding_dim: int = ... 12 | padding_idx: int = ... 13 | max_norm: float = ... 14 | norm_type: float = ... 15 | scale_grad_by_freq: bool = ... 16 | weight: Parameter = ... 17 | sparse: bool = ... 18 | 19 | def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = ..., 20 | max_norm: Optional[float] = ..., norm_type: float = ..., scale_grad_by_freq: bool = ..., 21 | sparse: bool = ..., _weight: Optional[Tensor] = ...) -> None: ... 22 | 23 | def reset_parameters(self) -> None: ... 24 | 25 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 26 | 27 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 28 | 29 | @classmethod 30 | def from_pretrained(cls, embeddings: Tensor, freeze: bool = ..., padding_idx: Optional[int] = ..., 31 | max_norm: Optional[float] = ..., norm_type: float = ..., scale_grad_by_freq: bool = ..., 32 | sparse: bool = ...): ... 33 | 34 | 35 | class EmbeddingBag(Module): 36 | num_embeddings: int = ... 37 | embedding_dim: int = ... 38 | max_norm: float = ... 39 | norm_type: float = ... 40 | scale_grad_by_freq: bool = ... 41 | weight: Parameter = ... 42 | mode: str = ... 43 | sparse: bool = ... 44 | 45 | def __init__(self, num_embeddings: int, embedding_dim: int, max_norm: Optional[float] = ..., norm_type: float = ..., 46 | scale_grad_by_freq: bool = ..., mode: str = ..., sparse: bool = ..., 47 | _weight: Optional[Tensor] = ...) -> None: ... 48 | 49 | def reset_parameters(self) -> None: ... 50 | 51 | def forward(self, input: Tensor, offsets: Optional[Tensor] = ...) -> Tensor: ... # type: ignore 52 | 53 | def __call__(self, input: Tensor, offsets: Optional[Tensor] = ...) -> Tensor: ... # type: ignore 54 | 55 | @classmethod 56 | def from_pretrained(cls, embeddings: Tensor, freeze: bool = ..., max_norm: Optional[float] = ..., 57 | norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ..., 58 | sparse: bool = ...): ... 59 | -------------------------------------------------------------------------------- /stubs/torch/nn/modules/upsampling.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ... import Tensor 4 | from .module import Module 5 | from typing import Optional 6 | from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t 7 | 8 | 9 | class Upsample(Module): 10 | name: str = ... 11 | size: _size_any_t = ... 12 | scale_factor: _ratio_any_t = ... 13 | mode: str = ... 14 | align_corners: bool = ... 15 | 16 | def __init__(self, size: Optional[_size_any_t] = ..., scale_factor: Optional[_ratio_any_t] = ..., mode: str = ..., 17 | align_corners: Optional[bool] = ...) -> None: ... 18 | 19 | def forward(self, input: Tensor) -> Tensor: ... # type: ignore 20 | 21 | def __call__(self, input: Tensor) -> Tensor: ... # type: ignore 22 | 23 | 24 | class UpsamplingNearest2d(Upsample): 25 | def __init__(self, size: Optional[_size_2_t] = ..., scale_factor: Optional[_ratio_2_t] = ...) -> None: ... 
26 | 27 | 28 | class UpsamplingBilinear2d(Upsample): 29 | def __init__(self, size: Optional[_size_2_t] = ..., scale_factor: Optional[_ratio_2_t] = ...) -> None: ... 30 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .data_parallel import DataParallel as DataParallel, data_parallel as data_parallel 4 | from .distributed import DistributedDataParallel as DistributedDataParallel 5 | from .parallel_apply import parallel_apply as parallel_apply 6 | from .replicate import replicate as replicate 7 | from .scatter_gather import gather as gather, scatter as scatter 8 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/common_types.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Union, Sequence 4 | from ... import device 5 | 6 | _device_t = Union[int, device] 7 | _devices_t = Sequence[_device_t] 8 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/data_parallel.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Optional, TypeVar 4 | from .common_types import _devices_t, _device_t 5 | from ..modules import Module 6 | from ... import device, Tensor 7 | 8 | T_co = TypeVar('T_co', covariant=True) 9 | class DataParallel(Module[T_co]): 10 | module: Module = ... 11 | device_ids: _devices_t = ... 12 | dim: int = ... 13 | output_device: _device_t = ... 14 | src_device_obj: device = ... 15 | 16 | def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ..., output_device: Optional[_device_t] = ..., 17 | dim: int = ...) -> None: ... 18 | 19 | def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ... 20 | def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ... 21 | 22 | 23 | def data_parallel(module: Module, inputs: Any, device_ids: Optional[_devices_t] = ..., 24 | output_device: Optional[_device_t] = ..., dim: int = ..., 25 | module_kwargs: Optional[Any] = ...) -> Tensor: ... 26 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/distributed.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from ..modules import Module 4 | from typing import Any, Optional, TypeVar 5 | from .common_types import _devices_t, _device_t 6 | 7 | T_co = TypeVar('T_co', covariant=True) 8 | 9 | def get_rank(group: Any): ... 10 | 11 | class DistributedDataParallel(Module[T_co]): 12 | process_group: Any = ... 13 | dim: int = ... 14 | module: Module[T_co] = ... 15 | device_ids: _devices_t = ... 16 | output_device: _device_t = ... 17 | broadcast_buffers: bool = ... 18 | check_reduction: bool = ... 19 | broadcast_bucket_size: float = ... 20 | bucket_bytes_cap: float = ... 21 | find_unused_parameters: bool = ... 
22 | 23 | # TODO type process_group once `distributed` module is stubbed 24 | def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ..., 25 | output_device: Optional[_device_t] = ..., dim: int = ..., 26 | broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ..., 27 | check_reduction: bool = ..., find_unused_parameters: bool = ...) -> None: ... 28 | 29 | def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ... 30 | 31 | def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ... 32 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/parallel_apply.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Optional, Sequence, List 4 | from .common_types import _devices_t 5 | from ..modules import Module 6 | 7 | 8 | def parallel_apply(modules: Sequence[Module], inputs: Sequence[Any], kwargs_tup: Optional[Any] = ..., 9 | devices: Optional[_devices_t] = ...) -> List[Any]: ... 10 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/replicate.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import List, Union, Sequence, TypeVar 4 | from ..modules import Module 5 | from .common_types import _devices_t 6 | 7 | T = TypeVar('T') 8 | 9 | 10 | def replicate(network: Module[T], devices: Union[_devices_t, Sequence[_devices_t]], detach: bool = ...) -> List[ 11 | Module[T]]: ... 12 | -------------------------------------------------------------------------------- /stubs/torch/nn/parallel/scatter_gather.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Dict, List, Tuple, overload, TypeVar 4 | from ... import Tensor 5 | from .common_types import _device_t, _devices_t 6 | 7 | 8 | T = TypeVar('T', Dict, List, Tuple) 9 | 10 | # For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise. 11 | @overload 12 | def scatter(inputs: Tensor, target_gpus: _devices_t, dim: int = ...) -> Tuple[Tensor, ...]: ... 13 | 14 | # flake8 will raise a spurious error here since `torch/__init__.pyi` has not been generated yet 15 | # so mypy will interpret `Tensor` as `Any` since it is an import from what it believes to be an 16 | # untyped module. Thus to mypy, the first definition of `scatter` looks strictly more general 17 | # than this overload. 18 | @overload 19 | def scatter(inputs: T, target_gpus: _devices_t, dim: int = ...) -> List[T]: ... # type: ignore 20 | 21 | 22 | # TODO More precise types here. 23 | def scatter_kwargs(inputs: Any, kwargs: Any, target_gpus: _devices_t, dim: int = ...) -> Any: ... 24 | 25 | 26 | def gather(outputs: Any, target_device: _device_t, dim: int = ...) -> Any: ... 27 | -------------------------------------------------------------------------------- /stubs/torch/nn/parameter.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Optional, Tuple, Any 4 | from ..
import Size, Tensor 5 | from ..cuda import Stream 6 | import builtins 7 | 8 | class Parameter(Tensor): 9 | # These are dynamic attributes added by shard_params_data_parallel class. 10 | # Added here for better type checking. 11 | _is_sharded: bool 12 | _is_shared: bool 13 | _orig_size: Size 14 | _cpu_grad: Tensor 15 | _full_param_padded: Tensor 16 | _fp32_shard: Tensor 17 | _fp16_shard: Optional[Tensor] 18 | _shard_bwd_hook: Tuple[Any, Any] 19 | _saved_grad_shard: Tensor 20 | _linked_param: Parameter 21 | 22 | def __new__(cls, data: Tensor, requires_grad: builtins.bool = True): ... 23 | 24 | def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... 25 | 26 | ... 27 | -------------------------------------------------------------------------------- /stubs/torch/optim/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .sgd import SGD as SGD 4 | from .adam import Adam as Adam 5 | from . import lr_scheduler as lr_scheduler 6 | from .optimizer import Optimizer as Optimizer 7 | #MODIFIED BY TORCHGPIPE 8 | from .rmsprop import RMSprop as RMSprop 9 | #END 10 | -------------------------------------------------------------------------------- /stubs/torch/optim/adam.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Tuple 4 | from .optimizer import _params_t, Optimizer 5 | 6 | class Adam(Optimizer): 7 | def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., amsgrad: bool = ...) -> None: ... 8 | -------------------------------------------------------------------------------- /stubs/torch/optim/lr_scheduler.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Iterable, Any, Optional, Callable 4 | from .optimizer import Optimizer 5 | 6 | class _LRScheduler: 7 | def __init__(self, optimizer: Optimizer, last_epoch: int=...) -> None: ... 8 | def state_dict(self) -> dict: ... 9 | def load_state_dict(self, state_dict: dict) -> None: ... 10 | #MODIFIED BY TORCHGPIPE 11 | from typing import List 12 | def get_lr(self) -> List[float]: ... 13 | def step(self, epoch: Optional[int] = ...) -> None: ... 14 | #END 15 | 16 | class LambdaLR(_LRScheduler): 17 | #MODIFIED BY TORCHGPIPE 18 | from typing import Callable, List, Union 19 | def __init__(self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int=...) -> None: ... 20 | #END 21 | 22 | class StepLR(_LRScheduler): 23 | def __init__(self, optimizer: Optimizer, step_size: int, gamma: float=..., last_epoch: int=...) -> None:... 24 | 25 | class MultiStepLR(_LRScheduler): 26 | def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float=..., last_epoch: int=...) -> None: ... 27 | 28 | class ExponentialLR(_LRScheduler): 29 | def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ... 30 | 31 | class CosineAnnealingLR(_LRScheduler): 32 | def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float, last_epoch: int=...) -> None: ... 
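# Illustrative sketch (comments only, not part of the stub): schedulers wrap an
# optimizer and are typically stepped once per epoch.
#
#   scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
#   for epoch in range(num_epochs):
#       train_one_epoch(...)  # hypothetical training loop body
#       scheduler.step()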
33 | 34 | class ReduceLROnPlateau: 35 | in_cooldown: bool 36 | 37 | def __init__(self, optimizer: Optimizer, mode: str=..., factor: float=..., patience: int=..., verbose: bool=..., threshold: float=..., threshold_mode: str=..., cooldown: int=..., min_lr: float=..., eps: float=...) -> None: ... 38 | def step(self, metrics: Any, epoch: Optional[int]=...) -> None: ... 39 | def state_dict(self) -> dict: ... 40 | def load_state_dict(self, state_dict: dict): ... 41 | 42 | class CyclicLR(_LRScheduler): 43 | def __init__(self, optimizer: Optimizer, base_lr: float=..., max_lr: float=..., step_size_up: int=..., step_size_down: int=..., mode: str=..., gamma: float=..., scale_fn: Optional[Callable[[float], float]]=..., scale_mode: str=..., cycle_momentum: bool=..., base_momentum: float=..., max_momentum: float=..., last_epoch: int=...) -> None: ... 44 | 45 | class CosineAnnealingWarmRestarts(_LRScheduler): 46 | def __init__(self, optimizer: Optimizer, T_0: int=..., T_mult: int=..., eta_min: int=..., last_epoch: int=...) -> None: ... 47 | def step(self, epoch: Optional[int] = ...) -> None: ... 48 | -------------------------------------------------------------------------------- /stubs/torch/optim/optimizer.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, List, Dict, Iterable, Union, Callable, Optional 4 | from .. import Tensor 5 | 6 | _params_t = Union[Iterable[Tensor], Iterable[Dict]] 7 | 8 | class Optimizer(object): 9 | param_groups: List[Dict] 10 | state: Dict 11 | def __init__(self, params: _params_t, defaults: Optional[Dict]=None, lr: Optional[float]=None) -> None: ... 12 | def __getattr__(self, name: str) -> Any: ... 13 | def state_dict(self) -> Dict: ... 14 | def load_state_dict(self, state_dict: Dict) -> None: ... 15 | def zero_grad(self) -> None: ... 16 | def step(self, closure: Optional[Callable[[], float]]=...) -> Optional[float]: ... 17 | def add_param_group(self, param_group: Dict) -> None: ... 18 | -------------------------------------------------------------------------------- /stubs/torch/optim/sgd.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .optimizer import _params_t, Optimizer 4 | 5 | class SGD(Optimizer): 6 | def __init__(self, params: _params_t, lr: float, momentum: float=..., dampening: float=..., weight_decay:float=..., nesterov:bool=...) -> None: ... 7 | -------------------------------------------------------------------------------- /stubs/torch/random.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #MODIFIED BY TORCHGPIPE 4 | from contextlib import contextmanager 5 | from typing import Any, Generator, Iterable, Union 6 | from torch import ByteTensor, device 7 | 8 | def set_rng_state(new_state: ByteTensor) -> None: ... 9 | def get_rng_state() -> ByteTensor: ... 10 | 11 | def manual_seed(seed: int) -> Any: ... 12 | def seed() -> int: ... 13 | def initial_seed() -> int: ... 14 | 15 | @contextmanager 16 | def fork_rng(devices: Iterable[Union[device, str, int]] = ..., enabled: bool = ...) -> Generator[None, None, None]: ... 
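# Illustrative sketch (comments only, not part of the stub): fork_rng saves and
# restores RNG state on exit, so seeding inside the block does not leak out.
#
#   with torch.random.fork_rng(devices=[0]):
#       torch.manual_seed(0)
#       noise = torch.randn(8)  # reproducible, without disturbing global state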
17 | #END 18 | -------------------------------------------------------------------------------- /stubs/torch/serialization.pyi: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from pathlib import Path 4 | from typing import Any, BinaryIO, Callable, IO, Union 5 | 6 | DEFAULT_PROTOCOL: int = 2 7 | 8 | def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], 9 | pickle_module: Any=pickle, pickle_protocol: int=DEFAULT_PROTOCOL, _use_new_zipfile_serialization: bool=True) -> None: ... 10 | 11 | def load(f: Union[str, BinaryIO, Path], map_location=None) -> Any: ... 12 | -------------------------------------------------------------------------------- /stubs/torch/testing/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any 4 | 5 | # Deprecate allclose when we move to newer versions. 6 | def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ... 7 | def assert_close(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ... 8 | -------------------------------------------------------------------------------- /stubs/torch/utils/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from . import checkpoint 4 | -------------------------------------------------------------------------------- /stubs/torch/utils/checkpoint.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Iterable, Tuple 4 | from .. import Tensor 5 | from torch.nn.modules.module import Module 6 | 7 | def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ... 8 | def checkpoint(function: Module, *args, **kwargs): ... 9 | def check_backward_validity(inputs: Iterable[Any]): ... 10 | def checkpoint_sequential(function: Module, segments: int, *args, **kwargs): ... 11 | -------------------------------------------------------------------------------- /stubs/torch/utils/data/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .sampler import Sampler as Sampler, SequentialSampler as SequentialSampler, RandomSampler as RandomSampler, \ 4 | SubsetRandomSampler as SubsetRandomSampler, WeightedRandomSampler as WeightedRandomSampler, BatchSampler as BatchSampler 5 | from .distributed import DistributedSampler as DistributedSampler 6 | from .dataset import Dataset as Dataset, TensorDataset as TensorDataset, ConcatDataset as ConcatDataset, \ 7 | Subset as Subset, random_split as random_split 8 | from .dataloader import DataLoader as DataLoader 9 | -------------------------------------------------------------------------------- /stubs/torch/utils/data/dataloader.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional 4 | from . 
import Dataset, Sampler 5 | 6 | T_co = TypeVar('T_co', covariant=True) 7 | T = TypeVar('T') 8 | _worker_init_fn_t = Callable[[int], None] 9 | 10 | # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that 11 | # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. 12 | # See https://github.com/python/mypy/issues/3737. 13 | _collate_fn_t = Callable[[List[T]], Any] 14 | 15 | class DataLoader(Generic[T_co]): 16 | dataset: Dataset[T_co] 17 | batch_size: int 18 | num_workers: int 19 | pin_memory: bool 20 | drop_last: bool 21 | timeout: float 22 | 23 | @overload 24 | def __init__(self, dataset: Dataset[T_co], batch_size: int=..., shuffle: bool=..., 25 | sampler: Optional[Sampler[int]]=..., num_workers: int=..., collate_fn: _collate_fn_t=..., 26 | pin_memory: bool=..., drop_last: bool=..., timeout: float=..., 27 | worker_init_fn: _worker_init_fn_t=...) -> None: ... 28 | @overload 29 | def __init__(self, dataset: Dataset[T_co], batch_sampler: Optional[Sampler[Sequence[int]]]=..., 30 | num_workers: int=..., collate_fn: _collate_fn_t=..., pin_memory: bool=..., timeout: float=..., 31 | worker_init_fn: _worker_init_fn_t=...) -> None: ... 32 | 33 | def __len__(self) -> int: ... 34 | # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up 35 | # since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic 36 | # analyzer is used that obviates the need for this but we leave the quoting in to support older 37 | # versions of mypy 38 | def __iter__(self) -> '_BaseDataLoaderIter':... 39 | 40 | class _BaseDataLoaderIter: 41 | def __init__(self, loader: DataLoader) -> None:... 42 | def __len__(self) -> int: ... 43 | def __iter__(self) -> _BaseDataLoaderIter: ... 44 | def __next__(self) -> Any: ... 45 | -------------------------------------------------------------------------------- /stubs/torch/utils/data/dataset.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import TypeVar, Generic, Iterable, Sequence, List, Tuple 4 | from ... import Tensor 5 | 6 | T_co = TypeVar('T_co', covariant=True) 7 | T = TypeVar('T') 8 | class Dataset(Generic[T_co]): 9 | def __getitem__(self, index: int) -> T_co: ... 10 | def __len__(self) -> int: ... 11 | def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ... 12 | 13 | class IterableDataset(Dataset[T_co]): 14 | def __iter__(self) -> Iterable[T_co]: ... 15 | 16 | 17 | class TensorDataset(Dataset[Tuple[Tensor, ...]]): 18 | tensors: List[Tensor] 19 | 20 | def __init__(self, *tensors: Tensor) -> None: ... 21 | 22 | class ConcatDataset(Dataset[T_co]): 23 | datasets: List[Dataset[T_co]] 24 | cumulative_sizes: List[int] 25 | 26 | def __init__(self, datasets: Iterable[Dataset]) -> None: ... 27 | 28 | class Subset(Dataset[T_co]): 29 | dataset: Dataset[T_co] 30 | indices: Sequence[int] 31 | 32 | def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: ... 33 | 34 | def random_split(dataset: Dataset[T], lengths: Sequence[int]) -> List[Subset[T]]: ... 35 | -------------------------------------------------------------------------------- /stubs/torch/utils/data/distributed.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
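# Illustrative sketch (comments only, not part of the stub): the sampler below
# is re-seeded each epoch via set_epoch() so every replica gets a fresh shuffle.
#
#   sampler = DistributedSampler(dataset)
#   loader = DataLoader(dataset, sampler=sampler)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)
#       for batch in loader:
#           ...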
2 | 3 | from typing import TypeVar, Optional, Iterator 4 | from . import Sampler, Dataset 5 | 6 | T_co = TypeVar('T_co', covariant=True) 7 | class DistributedSampler(Sampler[T_co]): 8 | def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ... 9 | def __iter__(self) -> Iterator[T_co]: ... 10 | def __len__(self) -> int: ... 11 | def set_epoch(self, epoch: int) -> None: ... 12 | -------------------------------------------------------------------------------- /stubs/torch/utils/data/sampler.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized 4 | 5 | T_co = TypeVar('T_co', covariant=True) 6 | class Sampler(Generic[T_co]): 7 | def __init__(self, data_source: Sized) -> None: ... 8 | def __iter__(self) -> Iterator[T_co]: ... 9 | def __len__(self) -> int: ... 10 | 11 | class SequentialSampler(Sampler[int]): 12 | pass 13 | 14 | class RandomSampler(Sampler[int]): 15 | num_samples: int 16 | 17 | def __init__(self, data_source: Sized, replacement: bool=..., num_samples: Optional[int]=...) -> None: ... 18 | 19 | class SubsetRandomSampler(Sampler[int]): 20 | def __init__(self, indices: Sequence[int]) -> None: ... 21 | 22 | class WeightedRandomSampler(Sampler[int]): 23 | def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool=...) -> None: ... 24 | 25 | class BatchSampler(Sampler[List[int]]): 26 | def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ... 27 | -------------------------------------------------------------------------------- /stubs/torch/version.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #MODIFIED BY TORCHGPIPE 4 | debug: bool = ... 5 | cuda: str = ... 6 | git_version: str = ... 7 | #END 8 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | # 7 | # We need an __init__.py in the tests dir due to a pytest issue. 8 | # 9 | # If you have: 10 | # tests/ 11 | # aa/test_name.py 12 | # bb/test_name.py 13 | # 14 | # running `pytest tests` will give an error like "import file mismatch" 15 | # because pytest can't distinguish between the files in `aa` and `bb` that 16 | # share the same file name. Adding an __init__.py file fixes it. 17 | # 18 | # However, `pytest tests/__init__.py` triggers running tests that are 19 | # not related. So we just don't include any __init__.py files in the test 20 | # list files.
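A small sketch reproducing the collision described above, using a hypothetical aa/bb layout built in a temporary directory (the exit-code checks are assumptions about pytest's collection-error behavior):

import pathlib
import tempfile

import pytest

with tempfile.TemporaryDirectory() as root:
    tests = pathlib.Path(root) / "tests"
    for pkg in ("aa", "bb"):
        (tests / pkg).mkdir(parents=True)
        (tests / pkg / "test_name.py").write_text("def test_ok():\n    assert True\n")
    # Two files named test_name.py and no __init__.py: collection fails
    # with "import file mismatch", so the exit code is nonzero.
    assert pytest.main([str(tests)]) != 0
    for pkg in ("aa", "bb"):
        (tests / pkg / "__init__.py").touch()
    # With __init__.py files the modules get distinct package names
    # (aa.test_name vs. bb.test_name) and collection succeeds.
    assert pytest.main([str(tests)]) == 0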
21 | -------------------------------------------------------------------------------- /tests/ci_test_list_1.txt: -------------------------------------------------------------------------------- 1 | tests/nn/data_parallel/test_fsdp_memory.py 2 | tests/nn/data_parallel/test_fsdp_multiple_wrapping.py 3 | tests/nn/data_parallel/test_fsdp_freezing_weights.py 4 | tests/nn/data_parallel/test_fsdp_regnet.py 5 | tests/nn/data_parallel/test_fsdp_uneven.py 6 | tests/nn/data_parallel/test_fsdp_grad_acc.py 7 | tests/nn/data_parallel/test_fsdp_summon_full_params.py 8 | tests/nn/data_parallel/test_fsdp_input.py 9 | tests/nn/data_parallel/test_fsdp_optimizer_utils.py 10 | tests/nn/data_parallel/test_fsdp.py 11 | tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py 12 | tests/optim/test_layerwise_gradient_scaler.py 13 | -------------------------------------------------------------------------------- /tests/ci_test_list_2.txt: -------------------------------------------------------------------------------- 1 | tests/experimental/tooling/test_layer_memory_tracker.py 2 | tests/experimental/nn/test_mevo.py 3 | tests/experimental/nn/test_multiprocess_pipe.py 4 | tests/experimental/nn/test_sync_batchnorm.py 5 | tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py 6 | tests/experimental/nn/test_offload.py 7 | tests/experimental/nn/test_auto_shard.py 8 | tests/experimental/optim/test_dynamic_loss_scaler.py 9 | tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py 10 | tests/nn/data_parallel/test_fsdp_shared_weights.py 11 | tests/nn/data_parallel/test_fsdp_pre_backward_hook.py 12 | tests/nn/data_parallel/test_fsdp_overlap.py 13 | tests/nn/data_parallel/test_fsdp_multiple_forward.py 14 | tests/nn/data_parallel/test_fsdp_apply.py 15 | tests/nn/data_parallel/test_fsdp_state_dict.py 16 | tests/nn/data_parallel/test_fsdp_metadata.py 17 | tests/utils/test_reduce_scatter_bucketer.py 18 | tests/utils/test_containers.py 19 | tests/utils/test_parallel.py 20 | tests/utils/test_state_dict.py 21 | tests/utils/test_version.py 22 | tests/nn/misc/test_grad_bucket.py 23 | tests/nn/misc/test_param_bucket.py 24 | tests/nn/wrap/test_wrap.py 25 | tests/nn/pipe_process/test_pipe.py 26 | tests/nn/pipe_process/test_transparency.py 27 | tests/nn/pipe_process/test_inplace.py 28 | tests/nn/pipe_process/test_bugs.py 29 | tests/nn/pipe_process/conftest.py 30 | tests/nn/pipe_process/test_rpc.py 31 | tests/nn/model_parallel/test_initialize.py 32 | tests/nn/model_parallel/test_random.py 33 | tests/nn/model_parallel/test_cross_entropy.py 34 | tests/nn/model_parallel/test_layers.py 35 | tests/nn/pipe/test_microbatch.py 36 | tests/nn/pipe/test_checkpoint.py 37 | tests/nn/pipe/test_worker.py 38 | tests/nn/pipe/test_balance.py 39 | tests/nn/pipe/test_pipe.py 40 | tests/nn/pipe/test_transparency.py 41 | tests/nn/pipe/test_inplace.py 42 | tests/nn/pipe/test_copy.py 43 | tests/nn/pipe/test_bugs.py 44 | tests/nn/pipe/conftest.py 45 | tests/nn/pipe/test_pipeline.py 46 | tests/nn/pipe/test_phony.py 47 | tests/nn/pipe/test_deferred_batch_norm.py 48 | tests/nn/pipe/test_dependency.py 49 | tests/nn/pipe/test_stream.py 50 | tests/nn/moe/test_moe_layer.py 51 | tests/nn/moe/test_top2gating.py 52 | tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py 53 | -------------------------------------------------------------------------------- /tests/ci_test_list_3.txt: -------------------------------------------------------------------------------- 1 | tests/nn/checkpoint/test_checkpoint_activations.py 2 | 
tests/nn/checkpoint/test_checkpoint_activations_norm.py 3 | tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py 4 | tests/nn/misc/test_grad_bucket.py 5 | tests/nn/misc/test_param_bucket.py 6 | tests/nn/misc/test_flatten_params_wrapper.py 7 | tests/nn/data_parallel/test_sharded_ddp_features.py 8 | tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py 9 | tests/nn/pipe/test_parity.py 10 | tests/nn/pipe/skip/test_gpipe.py 11 | tests/nn/pipe/skip/test_verify_skippables.py 12 | tests/nn/pipe/skip/test_stash_pop.py 13 | tests/nn/pipe/skip/test_api.py 14 | tests/nn/pipe/skip/test_leak.py 15 | tests/nn/pipe/skip/test_portal.py 16 | tests/nn/pipe/skip/test_tracker.py 17 | tests/nn/pipe/skip/test_inspect_skip_layout.py 18 | tests/nn/pipe/test_checkpoint_ddp.py 19 | tests/optim/test_single_node_adascale.py 20 | tests/optim/test_adam.py 21 | tests/optim/test_oss.py 22 | tests/optim/test_oss_adascale.py 23 | tests/optim/test_ddp_adascale.py 24 | tests/experimental/nn/data_parallel/test_gossip.py 25 | tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py 26 | tests/experimental/wgit/test_cli.py 27 | tests/experimental/wgit/test_api.py 28 | tests/experimental/wgit/test_pygit.py 29 | tests/experimental/wgit/test_sha1_store.py 30 | tests/experimental/wgit/test_signal_sparsity.py 31 | tests/experimental/wgit/test_signal_sparsity_profiling.py 32 | -------------------------------------------------------------------------------- /tests/ci_test_list_check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Verify that we don't miss any tests. 4 | 5 | find tests -name \*.py -type f| grep -v __init__.py | sort | uniq > /tmp/find.out 6 | cat tests/ci_test_list*.txt | sort | uniq > /tmp/cat.out 7 | 8 | if ! 
diff /tmp/find.out /tmp/cat.out ; then 9 | echo "A unit test is missing from the CI test lists" 10 | echo "See the diff above to fix it" 11 | exit 1 12 | fi 13 | -------------------------------------------------------------------------------- /tests/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/experimental/__init__.py -------------------------------------------------------------------------------- /tests/experimental/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/experimental/nn/__init__.py -------------------------------------------------------------------------------- /tests/experimental/nn/ampnet_pipe_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/experimental/nn/ampnet_pipe_process/__init__.py -------------------------------------------------------------------------------- /tests/experimental/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/experimental/optim/__init__.py -------------------------------------------------------------------------------- /tests/experimental/optim/test_dynamic_loss_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree.
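Before the test body below, a minimal sketch of the scaler call pattern it exercises (mirroring the torch.cuda.amp.GradScaler-style interface used in the test; model, optimizer, and loss_fn are assumed to be set up as in the test that follows):

def step(scaler, model, optimizer, loss_fn, x, y):
    # One optimization step with dynamic loss scaling, in miniature.
    optimizer.zero_grad()
    loss = loss_fn(y, model(x))
    scaler.scale(loss).backward()  # scale the loss up to avoid FP16 underflow
    scaler.step(optimizer)         # unscale grads, then step if they are finite
    scaler.update()                # grow/shrink the scale for the next step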
5 | 6 | """ 7 | Testing scaler 8 | """ 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | 15 | from fairscale.experimental.optim.dynamic_loss_scaler import DynamicLossScaler 16 | 17 | 18 | class ManualLinearRegression(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.linear = nn.Linear(1, 1) 22 | 23 | def forward(self, x): 24 | return self.linear(x) 25 | 26 | 27 | device = "cuda" if torch.cuda.is_available() else "cpu" 28 | 29 | 30 | def _init_dataset(): 31 | np.random.seed(42) 32 | x = np.random.rand(100, 1) 33 | y = 1 + 2 * x + 0.1 * np.random.randn(100, 1) 34 | # Shuffles the indices 35 | idx = np.arange(100) 36 | np.random.shuffle(idx) 37 | # Generates train sets 38 | x_train, y_train = x[idx], y[idx] 39 | x_train_tensor = torch.tensor([x_train]).float().to(device) 40 | y_train_tensor = torch.tensor([y_train]).float().to(device) 41 | return x_train_tensor, y_train_tensor 42 | 43 | 44 | def _train_with_dls(x, y): 45 | scaler = DynamicLossScaler() 46 | torch.manual_seed(42) 47 | lr = 1e-1 48 | n_epochs = 1000 49 | loss_fn = nn.MSELoss(reduction="mean") 50 | model = ManualLinearRegression().to(device) 51 | optimizer = optim.SGD(model.parameters(), lr=lr) 52 | for epoch in range(n_epochs): 53 | optimizer.zero_grad() 54 | model.train() 55 | yhat = model(x) 56 | loss = loss_fn(y, yhat) 57 | scaler.scale(loss).backward() 58 | scaler.step(optimizer) 59 | scaler.update() 60 | return model 61 | 62 | 63 | def test_dls_without_overflow(): 64 | x, y = _init_dataset() 65 | model = _train_with_dls(x, y) 66 | for name, param in model.named_parameters(): 67 | if param.requires_grad: 68 | print(name, param.data) 69 | if name == "linear.weight": 70 | assert (param.data.item() - 2) <= 0.05 71 | if name == "linear.bias": 72 | assert (param.data.item() - 1) <= 0.03 73 | 74 | 75 | # TODO(tmarkstrum): add test case covering check_overflow function 76 | # TODO(tmarkstrum): add test case covering the state_dict, FP16 77 | -------------------------------------------------------------------------------- /tests/experimental/tooling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/experimental/tooling/__init__.py -------------------------------------------------------------------------------- /tests/experimental/wgit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /tests/experimental/wgit/test_signal_sparsity_profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import time 7 | 8 | import pytest 9 | import torch 10 | 11 | from fairscale.experimental.wgit.signal_sparsity_profiling import EnergyConcentrationProfile as ECP 12 | from fairscale.fair_dev.testing.testing import objects_are_equal, skip_if_no_cuda 13 | 14 | # Our own tolerance 15 | ATOL = 1e-6 16 | RTOL = 1e-5 17 | 18 | # enable this for debugging. 
19 | # torch.set_printoptions(precision=20) 20 | 21 | 22 | @skip_if_no_cuda 23 | def test_nonblocking(): 24 | """Tests cpu runs ahead of the GPU in the measuring process.""" 25 | big = torch.rand(10, 1000, 1000).cuda() 26 | ecp = ECP(dim=2, top_k_percents=[1, 5, 10, 50, 90]) 27 | start = time.time() 28 | out = ecp.measure(big) 29 | out_fft = ecp.measure_fft(big) 30 | cpu_time = time.time() - start 31 | torch.cuda.synchronize() 32 | gpu_time = time.time() - start 33 | assert cpu_time * 5 < gpu_time, f"GPU time should dominate {cpu_time} vs. {gpu_time}" 34 | for o in [out, out_fft]: 35 | # validate the output 36 | p = [x.item() for x in o] 37 | for n, n1 in zip(p, p[1:]): 38 | assert n <= n1 and n >= 0 and n <= 100, f"n={n} n1={n1}" 39 | 40 | 41 | def get_ones(): 42 | """Return test data with ones tensor""" 43 | return ( 44 | 0, 45 | [1, 5, 10, 100], 46 | torch.ones(100), 47 | [torch.tensor(0.01), torch.tensor(0.05), torch.tensor(0.1), torch.tensor(1.0)], 48 | ) 49 | 50 | 51 | def get_dim_0(): 52 | """Test case for dim=0 for 2D input.""" 53 | return ( 54 | 0, 55 | [1, 3, 33, 66, 90], 56 | torch.tensor([0.1, 0.2, 0.1, 0.45]).repeat(100, 1), 57 | [torch.tensor(0.01), torch.tensor(0.03), torch.tensor(0.33), torch.tensor(0.66), torch.tensor(0.9)], 58 | ) 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "dim, percents, in_tensor, out_tensors", 63 | [ 64 | get_ones(), 65 | get_dim_0(), 66 | ], 67 | ) 68 | def test_expected_output(dim, percents, in_tensor, out_tensors): 69 | """Test with a few expected input & outputs.""" 70 | ecp = ECP(dim, percents) 71 | out = ecp.measure(in_tensor) 72 | objects_are_equal(out, out_tensors, raise_exception=True, rtol=RTOL, atol=ATOL) 73 | out_fft = ecp.measure_fft(torch.fft.ifft(in_tensor, dim=dim)) 74 | objects_are_equal(out_fft, out_tensors, raise_exception=True, rtol=RTOL, atol=ATOL) 75 | -------------------------------------------------------------------------------- /tests/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/__init__.py -------------------------------------------------------------------------------- /tests/nn/checkpoint/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/checkpoint/__init__.py -------------------------------------------------------------------------------- /tests/nn/data_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/data_parallel/__init__.py -------------------------------------------------------------------------------- /tests/nn/data_parallel/test_fsdp_apply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
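The test below checks that standard weight-init functions still work through FSDP's .apply(). A minimal sketch of that call pattern (the FSDP lines are commented out because they require an initialized process group; init_weights is a simplified stand-in for the init_bert_params_ function defined in the test):

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def init_weights(module: nn.Module) -> None:
    # Visit each submodule and initialize its weights, exactly as a
    # plain nn.Module.apply(fn) call would.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

# Inside an initialized distributed process group one would write:
# model = FSDP(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)))
# model.apply(init_weights)  # same call pattern as plain nn.Module.apply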
5 | 6 | import functools 7 | import unittest 8 | 9 | from parameterized import parameterized 10 | import pytest 11 | import torch.nn as nn 12 | 13 | from fairscale.internal import torch_version 14 | 15 | from .test_fsdp import ( 16 | CONFIG_OPTIONS, 17 | DistributedTest, 18 | NestedWrappedModule, 19 | TransformerWithSharedParams, 20 | rename_test, 21 | spawn_and_init, 22 | ) 23 | 24 | 25 | @pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required") 26 | class TestApply(DistributedTest): 27 | @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) 28 | def test_transformer_weight_init(self, config): 29 | model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, TransformerWithSharedParams) 30 | test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01) 31 | spawn_and_init(test_fn) 32 | 33 | @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) 34 | def test_nested_wrapped_weight_init(self, config): 35 | model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, NestedWrappedModule) 36 | test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01) 37 | spawn_and_init(test_fn) 38 | 39 | 40 | def model_init_and_apply_custom_weight_init(model_init_fn, *args, **kwargs): 41 | model = model_init_fn(*args, **kwargs) 42 | model.apply(init_bert_params_) 43 | return model 44 | 45 | 46 | def init_bert_params_(module): 47 | """ 48 | Initialize the weights specific to the BERT Model. 49 | """ 50 | 51 | def normal_(data): 52 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 53 | # so that the RNG is consistent with and without FSDP 54 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02)) 55 | 56 | if isinstance(module, nn.Linear): 57 | normal_(module.weight.data) 58 | if module.bias is not None: 59 | module.bias.data.zero_() 60 | if isinstance(module, nn.Embedding): 61 | normal_(module.weight.data) 62 | if module.padding_idx is not None: 63 | module.weight.data[module.padding_idx].zero_() 64 | if isinstance(module, nn.MultiheadAttention): 65 | normal_(module.in_proj_weight.data) 66 | 67 | 68 | if __name__ == "__main__": 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
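The pattern this test exercises, in miniature: run all forwards first, then all backwards, keeping every backward but the last inside no_sync() so gradient reduction happens only once (a sketch; model is assumed to be an FSDP-wrapped module inside an initialized process group):

from typing import List

import torch

def fwd_fwd_bwd_bwd(model, batches: List[torch.Tensor]) -> None:
    # All forwards first...
    losses = [model(x).sum() for x in batches]
    # ...then all backwards; only the last one reduces gradients.
    with model.no_sync():
        for loss in losses[:-1]:
            loss.backward()
    losses[-1].backward()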
5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from fairscale.fair_dev.testing.testing import skip_if_single_gpu, temp_files_ctx 10 | from fairscale.nn import enable_wrap, wrap 11 | from fairscale.nn.data_parallel import FullyShardedDataParallel 12 | 13 | 14 | class FFN(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | self.fc1 = nn.Linear(10, 10) 18 | self.fc2 = nn.Linear(10, 10) 19 | self.relu = nn.ReLU() 20 | 21 | def forward(self, x): 22 | return self.fc2(self.relu(self.fc1(x))) 23 | 24 | 25 | def main(rank, sync_file): 26 | torch.manual_seed(0) 27 | torch.cuda.manual_seed(0) 28 | torch.cuda.set_device(rank) 29 | torch.distributed.init_process_group( 30 | backend="nccl", 31 | init_method=f"file://{sync_file}", 32 | world_size=2, 33 | rank=rank, 34 | ) 35 | ffn = FFN().cuda().half() 36 | 37 | with enable_wrap(wrapper_cls=FullyShardedDataParallel): 38 | model = wrap( 39 | ffn, 40 | process_group=torch.distributed.new_group(), 41 | flatten_parameters=True, 42 | compute_dtype=torch.float16, 43 | ) 44 | 45 | model = model.train() 46 | 47 | # We test this behavior because it might be used by pipelining. 48 | # However, we don't check if the speed (compute/comm overlapping) 49 | # and memory (necessary all-gather & free) are optimal. 50 | losses = [] 51 | for _ in range(3): 52 | x = torch.rand((10, 10)).cuda().half() 53 | out = model(x) 54 | loss = out.sum() 55 | losses.append(loss) 56 | 57 | # Only the last bwd can be outside of no_sync context. 58 | with model.no_sync(): 59 | losses[0].backward() 60 | losses[1].backward() 61 | losses[2].backward() 62 | 63 | 64 | @skip_if_single_gpu 65 | def test_fwd_fwd_bwd_bwd(): 66 | with temp_files_ctx(num=1) as temp_files: 67 | torch.multiprocessing.spawn( 68 | fn=main, 69 | nprocs=2, 70 | args=(temp_files[0],), 71 | join=True, 72 | ) 73 | -------------------------------------------------------------------------------- /tests/nn/data_parallel/test_fsdp_multiple_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pylint: disable=missing-module-docstring 7 | # pylint: disable=missing-class-docstring 8 | # pylint: disable=missing-function-docstring 9 | 10 | """ Test FSDP with multiple forward passes of the same module. """ 11 | 12 | import tempfile 13 | 14 | import pytest 15 | import torch 16 | import torch.multiprocessing as mp 17 | from torch.nn import Linear, Module 18 | from torch.optim import SGD 19 | 20 | from fairscale.fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown 21 | from fairscale.internal import torch_version 22 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP 23 | from fairscale.nn.data_parallel import TrainingState 24 | 25 | 26 | def _test_func(rank, world_size, fsdp_config, tempfile_name, unused): 27 | result = dist_init(rank, world_size, tempfile_name, unused) 28 | assert result, "Dist init failed" 29 | 30 | assert isinstance(fsdp_config, dict), str(fsdp_config) 31 | 32 | class Model(Module): 33 | def __init__(self): 34 | super().__init__() 35 | self.inner = FSDP(Linear(4, 4), **fsdp_config) 36 | self.outer = Linear(4, 5) 37 | 38 | def forward(self, x): 39 | # Forward twice.
40 | i = self.inner(x) 41 | j = self.inner(x) 42 | return self.outer(i + j) 43 | 44 | model = FSDP(Model(), **fsdp_config).cuda() 45 | optim = SGD(model.parameters(), lr=0.1) 46 | 47 | for _ in range(3): 48 | in_data = torch.rand(64, 4).cuda() 49 | in_data.requires_grad = True 50 | out = model(in_data) 51 | out.sum().backward() 52 | optim.step() 53 | optim.zero_grad() 54 | 55 | model.assert_state(TrainingState.IDLE) 56 | teardown() 57 | 58 | 59 | # We use strings for precision and flatten instead of bool to 60 | # make the pytest output more readable. 61 | @skip_if_single_gpu 62 | @pytest.mark.parametrize("precision", ["full", "mixed"]) 63 | @pytest.mark.parametrize("flatten", ["flatten", "no_flatten"]) 64 | def test1(precision, flatten): 65 | if torch_version() < (1, 6, 0): 66 | pytest.skip("older pytorch doesn't support reduce_scatter") 67 | 68 | temp_file_name = tempfile.mkstemp()[1] 69 | unused = tempfile.mkstemp()[1] 70 | 71 | fsdp_config = {} 72 | fsdp_config["mixed_precision"] = precision == "mixed" 73 | fsdp_config["flatten_parameters"] = flatten == "flatten" 74 | 75 | # Some bugs only show up when we are in world_size > 1 due to sharding changing 76 | # the tensor dimensions. 77 | world_size = 2 78 | mp.spawn( 79 | _test_func, 80 | args=(world_size, fsdp_config, temp_file_name, unused), 81 | nprocs=world_size, 82 | join=True, 83 | ) 84 | -------------------------------------------------------------------------------- /tests/nn/data_parallel/test_fsdp_pre_backward_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pylint: disable=missing-module-docstring 7 | # pylint: disable=missing-class-docstring 8 | # pylint: disable=missing-function-docstring 9 | 10 | """ Test FSDP with pre-backward hook bug. """ 11 | 12 | import pytest 13 | import torch 14 | from torch.nn import Linear, Module 15 | 16 | from fairscale.fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx 17 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP 18 | 19 | 20 | # A fixture to get tempfiles and ensure they are cleaned up. 
21 | @pytest.fixture() 22 | def temp_files(): 23 | # dist_init needs 2 files 24 | with temp_files_ctx(2) as files: 25 | yield files 26 | 27 | 28 | @skip_if_no_cuda 29 | def test_pre_backward_hook(temp_files): 30 | """Test FSDP with a model that triggers a pre_backward hook bug.""" 31 | 32 | result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1]) 33 | assert result, "Dist init failed" 34 | 35 | class Model(Module): 36 | def __init__(self): 37 | super().__init__() 38 | self.l1 = Linear(4, 4).cuda() 39 | self.l2 = FSDP(Linear(4, 4).cuda()) 40 | self.l3 = Linear(4, 4).cuda() 41 | 42 | def forward(self, x): 43 | x = self.l1(x) 44 | x = self.l2(x) 45 | inner_result = x 46 | x = self.l3(x) 47 | return x, inner_result 48 | 49 | def assert_and_clear_grad(self): 50 | for p in self.parameters(): 51 | assert p.shape in [(4, 4), (4,), (4 * 4 + 4,)], p.shape 52 | assert p.grad is not None 53 | p.grad = None 54 | 55 | model = FSDP(Model(), flatten_parameters=False).cuda() 56 | in_data = torch.rand(1, 4).cuda() 57 | for _ in range(3): 58 | out, _ = model(in_data) 59 | out.sum().backward() 60 | model.assert_and_clear_grad() 61 | 62 | teardown() 63 | -------------------------------------------------------------------------------- /tests/nn/data_parallel/test_fsdp_summon_full_params.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import functools 7 | import gc 8 | import unittest 9 | 10 | from parameterized import parameterized 11 | import pytest 12 | import torch 13 | 14 | from fairscale.internal.version import torch_version 15 | 16 | from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init 17 | 18 | 19 | def get_cuda_mem(): 20 | torch.cuda.synchronize() 21 | gc.collect() 22 | return torch.cuda.memory_allocated() 23 | 24 | 25 | @pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required") 26 | class TestMemory(DistributedTest): 27 | @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) 28 | def test_memory(self, config): 29 | spawn_and_init(functools.partial(self._test_memory, config)) 30 | 31 | @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) 32 | def test_memory_volatile(self, config): 33 | spawn_and_init(functools.partial(self._test_memory, config, volatile=True)) 34 | 35 | @classmethod 36 | def _test_memory(self, config, rank, group, volatile=False): 37 | model = self.get_wrapped_model(group, cuda_first=False, config=config) 38 | self._train_for_several_steps(model, 1, autocast=model.mixed_precision) 39 | 40 | mems = [get_cuda_mem()] 41 | 42 | with model.summon_full_params(volatile=volatile): 43 | mems.append(get_cuda_mem()) 44 | assert mems[1] >= mems[0] 45 | 46 | state_dict = model.state_dict() 47 | mems.append(get_cuda_mem()) 48 | assert mems[2] >= mems[1] 49 | 50 | mems.append(get_cuda_mem()) 51 | assert mems[3] <= mems[2] 52 | 53 | del state_dict 54 | mems.append(get_cuda_mem()) 55 | # mems[4] should equal mems[0]; any difference indicates a memory leak. 56 | # If mems[4] > mems[0], we're not cleaning up params properly in 57 | # summon_full_params. If mems[4] < mems[0], there's a 58 | # memory leak in _train_for_several_steps.
59 | assert mems[4] == mems[0], f"memory leak detected, {mems[4]} != {mems[0]}" 60 | 61 | 62 | class TestPersistence(DistributedTest): 63 | @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) 64 | def test_non_volatile(self, config): 65 | spawn_and_init(functools.partial(self._test_persistence, config)) 66 | 67 | @classmethod 68 | def _test_persistence(self, config, rank, group, volatile=False): 69 | model = self.get_wrapped_model(group, cuda_first=False, config=config) 70 | 71 | with model.summon_full_params(volatile=False): 72 | model.module.embed_tokens.weight.data.fill_(42) 73 | with model.summon_full_params(): 74 | # non-volatile changes are persisted 75 | assert torch.all(model.module.embed_tokens.weight.data == 42.0) 76 | 77 | 78 | if __name__ == "__main__": 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /tests/nn/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/misc/__init__.py -------------------------------------------------------------------------------- /tests/nn/misc/test_grad_bucket.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import pytest 8 | import torch 9 | 10 | from fairscale.nn.misc import GradBucket 11 | 12 | 13 | def test_grad_values_conserved(): 14 | with torch.no_grad(): # remove a warning 15 | param = torch.rand((2, 3), requires_grad=True) 16 | param.grad = torch.rand(2, 3) 17 | 18 | bucket = GradBucket(10, param.dtype, param.device, -1) 19 | param_ = param.clone() 20 | param_.grad = param.grad.clone() # clone() does not copy .grad 21 | bucket.add_grad(param_) 22 | assert torch.allclose(param.grad, param_.grad) 23 | 24 | 25 | def test_memory_leak(): 26 | with torch.no_grad(): # remove a warning 27 | param = torch.rand((2, 3), requires_grad=True) 28 | param.grad = torch.rand(2, 3) 29 | 30 | bucket = GradBucket(300, param.dtype, param.device, -1) 31 | bucket.add_grad(param) 32 | bucket.shrink() 33 | 34 | storage = bucket.buffer.storage() 35 | # See https://github.com/pytorch/pytorch/pull/59671/ 36 | if hasattr(storage, "nbytes"): 37 | assert storage.nbytes() == 6 * bucket.buffer.element_size() 38 | else: 39 | assert len(storage) == 6 40 | 41 | 42 | def test_max_size(): 43 | with torch.no_grad(): # remove a warning 44 | param = torch.rand((20, 30), requires_grad=True) 45 | param.grad = torch.rand(20, 30) 46 | 47 | bucket = GradBucket(5, param.dtype, param.device, -1) 48 | with pytest.raises(AssertionError): 49 | bucket.add_grad(param) 50 | 51 | 52 | def test_collapse(): 53 | with torch.no_grad(): # remove a warning 54 | size = (5, 6) 55 | param = torch.rand(size, requires_grad=True) 56 | param.grad = torch.rand(size) 57 | 58 | bucket = GradBucket(300, param.dtype, param.device, -1) 59 | bucket.add_grad(param) 60 | bucket.shrink() 61 | bucket.collapse() 62 | 63 | assert bucket.buffer.numel() == 0 64 | assert param.grad is None 65 | bucket.rebuild() 66 | 67 | assert param.grad is not None 68 | assert torch.allclose(param.grad, torch.zeros(size)) 69 | -------------------------------------------------------------------------------- /tests/nn/misc/test_param_bucket.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import pytest 8 | import torch 9 | 10 | from fairscale.nn.misc import ParamBucket 11 | 12 | 13 | def test_param_values_conserved(): 14 | param = torch.rand((2, 3)) 15 | 16 | bucket = ParamBucket(10, param.dtype, param.device) 17 | param_ = param.clone() 18 | 19 | bucket.add_param(param_) 20 | assert torch.allclose(param, param_) 21 | 22 | 23 | def test_max_size(): 24 | param = torch.rand((20, 30)) 25 | 26 | bucket = ParamBucket(5, param.dtype, param.device) 27 | with pytest.raises(AssertionError): 28 | bucket.add_param(param) 29 | 30 | 31 | def test_double_check_int(): 32 | param = torch.rand((5, 6)) 33 | 34 | bucket = ParamBucket(300, param.dtype, param.device) 35 | bucket.add_param(param) 36 | 37 | with pytest.raises(AssertionError): 38 | bucket.add_param(param) 39 | 40 | 41 | def test_type_change(): 42 | size = (5, 6) 43 | param = torch.rand(size, requires_grad=True) 44 | param_ = param.clone() 45 | 46 | bucket = ParamBucket(30, param.dtype, param.device) 47 | bucket.add_param(param) 48 | 49 | # Move the bucket to fp16 and back 50 | bucket.to(dtype=torch.float16, device=param.device) 51 | assert bucket.buffer.dtype == torch.float16 52 | 53 | bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True) 54 | assert bucket.buffer.dtype == torch.float32 55 | 56 | # Same with the reference tensor (Tensor.to is not in-place) 57 | param_ = param_.to(dtype=torch.float16) 58 | param_ = param_.to(dtype=torch.float32) 59 | 60 | assert torch.allclose(param, param_) 61 | -------------------------------------------------------------------------------- /tests/nn/model_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/model_parallel/__init__.py -------------------------------------------------------------------------------- /tests/nn/moe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/moe/__init__.py -------------------------------------------------------------------------------- /tests/nn/moe/test_top2gating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree.
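A minimal sketch of the Top2Gate call pattern the tests below exercise (shapes follow the asserts in do_test_forward):

import torch

from fairscale.nn import Top2Gate

gate = Top2Gate(4, 6)  # model_dim=4, num_experts=6
tokens = torch.randn(12, 4)
l_aux, combine_weights, dispatch_mask = gate(tokens)
# combine_weights and dispatch_mask are (tokens, experts, capacity),
# here (12, 6, 4) with capacity = 2 * 12 // 6 = 4.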
5 | 6 | import pytest 7 | import torch 8 | 9 | from fairscale.nn import Top2Gate 10 | from fairscale.nn.moe.top2gate import top2gating 11 | 12 | skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") 13 | 14 | 15 | def test_create(): 16 | gate = Top2Gate(4, 8) 17 | 18 | 19 | @skip_if_no_cuda 20 | def test_create_cuda(): 21 | gate = Top2Gate(4, 8).cuda() 22 | 23 | 24 | def do_test_forward(device): 25 | torch.manual_seed(3) 26 | input = torch.randn(12, 4).to(device) 27 | gate = Top2Gate(4, 6).to(device) 28 | capacity = 2 * 12 // 6 29 | l_aux, combine_weights, dispatch_mask = gate(input) 30 | assert pytest.approx(l_aux.item(), rel=0.01) == 0.0267, l_aux 31 | assert combine_weights.shape == (12, 6, 4) 32 | assert dispatch_mask.shape == (12, 6, 4) 33 | assert torch.equal(combine_weights.bool(), dispatch_mask) 34 | assert torch.all(torch.sum(dispatch_mask, axis=(0, 2)) <= capacity) 35 | assert torch.all(combine_weights >= 0.0) 36 | assert torch.all(combine_weights <= 1.0) 37 | weights_sum = torch.sum(combine_weights).item() 38 | assert round(weights_sum) == pytest.approx(weights_sum), weights_sum 39 | # For this random seed, we get 12 slots filled. 40 | assert weights_sum == pytest.approx(12.0), weights_sum 41 | 42 | 43 | def test_forward_cpu(): 44 | do_test_forward("cpu") 45 | 46 | 47 | @skip_if_no_cuda 48 | def test_forward_cuda(): 49 | do_test_forward("cuda") 50 | 51 | 52 | # Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper. 53 | def test_expert1_overflow(): 54 | num_tokens = 8 55 | num_experts = 4 56 | logits = torch.randn(num_tokens, num_experts) 57 | logits[:, 0] = torch.max(logits, dim=1).values + 1 # Force overflow 58 | top1s = torch.argmax(logits, dim=1) 59 | assert top1s.eq(0).all(), top1s 60 | _, __, dispatch_mask = top2gating(logits) 61 | capacity = 2 * num_tokens // num_experts 62 | 63 | for i in range(num_tokens): 64 | if i < capacity: 65 | assert dispatch_mask[i][0][i] 66 | else: 67 | assert not dispatch_mask[i][0].any() 68 | -------------------------------------------------------------------------------- /tests/nn/pipe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | # tests/__init__.py lets pytest import the application without custom sys.path or PYTHONPATH. 21 | # See also: https://docs.pytest.org/en/latest/goodpractices.html 22 | -------------------------------------------------------------------------------- /tests/nn/pipe/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import functools 21 | import os 22 | from typing import Any, Callable 23 | 24 | import pytest 25 | import torch 26 | 27 | from fairscale.nn.model_parallel import destroy_model_parallel 28 | 29 | 30 | @pytest.fixture(autouse=True) 31 | def manual_seed_zero() -> None: 32 | torch.manual_seed(0) 33 | 34 | 35 | def cuda_sleep_impl(seconds, cycles_per_ms): 36 | torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) 37 | 38 | 39 | @pytest.fixture(scope="session") 40 | def cuda_sleep() -> Callable: 41 | # Warm-up CUDA. 42 | torch.empty(1, device="cuda") 43 | 44 | # From test/test_cuda.py in PyTorch. 45 | start = torch.cuda.Event(enable_timing=True) 46 | end = torch.cuda.Event(enable_timing=True) 47 | start.record() 48 | torch.cuda._sleep(1000000) 49 | end.record() 50 | end.synchronize() 51 | cycles_per_ms = 1000000 / start.elapsed_time(end) 52 | 53 | return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms) 54 | 55 | 56 | def pytest_report_header() -> str: 57 | return f"torch: {torch.__version__}" 58 | 59 | 60 | def pytest_runtest_setup(item: Any) -> None: 61 | print("setup mpi function called") 62 | 63 | 64 | def pytest_runtest_teardown(item: Any) -> None: 65 | if "OMPI_COMM_WORLD_RANK" in os.environ: 66 | destroy_model_parallel() 67 | if torch.distributed.is_initialized(): 68 | torch.distributed.destroy_process_group() 69 | try: 70 | torch.distributed.rpc.shutdown() 71 | except Exception: 72 | pass 73 | -------------------------------------------------------------------------------- /tests/nn/pipe/skip/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | -------------------------------------------------------------------------------- /tests/nn/pipe/skip/test_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
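The test below demonstrates only the stash side of the skip API. For context, a sketch of the stash/pop pairing across two modules (assuming pop is exported alongside stash, as in the torchgpipe-derived skip API):

from torch import nn

from fairscale.nn.pipe.skip import pop, skippable, stash

@skippable(stash=["skip"])
class Stash(nn.Module):
    def forward(self, x):
        yield stash("skip", x)  # save x for a later partition
        return x + 1

@skippable(pop=["skip"])
class Pop(nn.Module):
    def forward(self, x):
        skip = yield pop("skip")  # retrieve the stashed tensor
        return x + skip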
2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import copy 21 | 22 | from torch import nn 23 | 24 | from fairscale.nn.pipe.skip import Namespace, skippable, stash 25 | 26 | 27 | def test_namespace_difference(): 28 | ns1 = Namespace() 29 | ns2 = Namespace() 30 | assert ns1 != ns2 31 | 32 | 33 | def test_namespace_copy(): 34 | ns = Namespace() 35 | assert copy.copy(ns) == ns 36 | assert copy.copy(ns) is not ns 37 | 38 | 39 | def test_skippable_repr(): 40 | @skippable(stash=["hello"]) 41 | class Hello(nn.Module): 42 | def __init__(self): 43 | super().__init__() 44 | self.conv = nn.Conv2d(1, 1, 1) 45 | 46 | def forward(self, x): 47 | yield stash("hello", x) 48 | return self.conv(x) 49 | 50 | m = Hello() 51 | assert ( 52 | repr(m) 53 | == """ 54 | @skippable(Hello( 55 | (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) 56 | )) 57 | """.strip() 58 | ) 59 | -------------------------------------------------------------------------------- /tests/nn/pipe/test_copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
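Before the test body: Copy moves a tensor between streams inside autograd, and Wait makes the consumer stream block on the producer. A minimal CPU-only sketch mirroring the calls in _test_copy_wait below (on CPU streams both are effectively pass-throughs):

import torch

from fairscale.nn.pipe.copy import Copy, Wait
from fairscale.nn.pipe.stream import CPUStream

x = torch.ones(8, requires_grad=True)
(y,) = Copy.apply(CPUStream, CPUStream, x)
(y,) = Wait.apply(CPUStream, CPUStream, y)
y.sum().backward()
assert x.grad is not None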
19 | 20 | import pytest 21 | import torch 22 | 23 | from fairscale.nn.pipe.copy import Copy, Wait 24 | from fairscale.nn.pipe.stream import CPUStream, current_stream, get_device, is_cuda, new_stream, use_stream 25 | 26 | skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") 27 | 28 | 29 | def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): 30 | device = get_device(prev_stream) 31 | 32 | with use_stream(prev_stream): 33 | if is_cuda(prev_stream): 34 | cuda_sleep(0.5) 35 | x = torch.ones(100, device=device, requires_grad=True) 36 | 37 | (y,) = Copy.apply(prev_stream, next_stream, x) 38 | (y,) = Wait.apply(prev_stream, next_stream, x) 39 | 40 | with use_stream(next_stream): 41 | assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) 42 | y.norm().backward() 43 | with use_stream(prev_stream): 44 | assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) 45 | 46 | 47 | def test_copy_wait_cpu_cpu(): 48 | prev_stream = CPUStream 49 | next_stream = CPUStream 50 | _test_copy_wait(prev_stream, next_stream) 51 | 52 | 53 | @skip_if_no_cuda 54 | def test_copy_wait_cpu_cuda(cuda_sleep): 55 | prev_stream = CPUStream 56 | next_stream = current_stream(torch.device("cuda")) 57 | _test_copy_wait(prev_stream, next_stream, cuda_sleep) 58 | 59 | 60 | @skip_if_no_cuda 61 | def test_copy_wait_cuda_cpu(cuda_sleep): 62 | prev_stream = current_stream(torch.device("cuda")) 63 | next_stream = CPUStream 64 | _test_copy_wait(prev_stream, next_stream, cuda_sleep) 65 | 66 | 67 | @skip_if_no_cuda 68 | def test_copy_wait_cuda_cuda(cuda_sleep): 69 | prev_stream = current_stream(torch.device("cuda")) 70 | next_stream = new_stream(torch.device("cuda")) 71 | _test_copy_wait(prev_stream, next_stream, cuda_sleep) 72 | 73 | 74 | def test_wait_multiple_tensors(): 75 | a = torch.rand(1, requires_grad=True) 76 | b = torch.rand(1, requires_grad=True) 77 | 78 | a, b = Wait.apply(CPUStream, CPUStream, a, b) 79 | 80 | assert a.grad_fn is b.grad_fn 81 | assert a.grad_fn.__class__ is Wait._backward_cls 82 | -------------------------------------------------------------------------------- /tests/nn/pipe/test_phony.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
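For context before the tests: get_phony returns a tiny cached placeholder tensor used to build autograd dependencies without carrying real data. A minimal sketch:

import torch

from fairscale.nn.pipe.phony import get_phony

p = get_phony(torch.device("cpu"), requires_grad=True)
assert p.size() == (1,)
# Phonies are cached per (device, requires_grad) pair:
assert p is get_phony(torch.device("cpu"), requires_grad=True)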
19 | 20 | import torch 21 | 22 | from fairscale.nn.pipe.phony import get_phony 23 | 24 | 25 | def test_phony_size(): 26 | p = get_phony(torch.device("cpu"), requires_grad=False) 27 | assert p.size() == (1,) 28 | 29 | 30 | def test_phony_requires_grad(): 31 | p1 = get_phony(torch.device("cpu"), requires_grad=True) 32 | p2 = get_phony(torch.device("cpu"), requires_grad=False) 33 | assert p1.requires_grad 34 | assert not p2.requires_grad 35 | 36 | 37 | def test_cached_phony(): 38 | p1 = get_phony(torch.device("cpu"), requires_grad=True) 39 | p2 = get_phony(torch.device("cpu"), requires_grad=True) 40 | assert p1 is p2 41 | 42 | p3 = get_phony(torch.device("cpu"), requires_grad=False) 43 | p4 = get_phony(torch.device("cpu"), requires_grad=False) 44 | assert p3 is p4 45 | 46 | assert p1 is not p3 47 | 48 | 49 | def test_phony_in_autograd_function(): 50 | class Phonify(torch.autograd.Function): 51 | @staticmethod 52 | def forward(ctx, input): 53 | phony = get_phony(input.device, requires_grad=False) 54 | return phony.detach() 55 | 56 | x = torch.rand(1, requires_grad=True) 57 | 58 | p1 = Phonify.apply(x) 59 | p2 = get_phony(torch.device("cpu"), requires_grad=True) 60 | 61 | assert p1 is not p2 62 | assert p1.grad_fn is not None 63 | assert p2.grad_fn is None 64 | -------------------------------------------------------------------------------- /tests/nn/pipe/test_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from fairscale.nn.pipe.pipeline import clock_cycles 21 | 22 | 23 | def test_clock_cycles(): 24 | assert list(clock_cycles(1, 1)) == [[(0, 0)]] 25 | assert list(clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] 26 | assert list(clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] 27 | 28 | assert list(clock_cycles(3, 3)) == [ # noqa 29 | [(0, 0)], 30 | [(1, 0), (0, 1)], 31 | [(2, 0), (1, 1), (0, 2)], 32 | [(2, 1), (1, 2)], 33 | [(2, 2)], 34 | ] 35 | 36 | assert list(clock_cycles(4, 2)) == [ # noqa 37 | [(0, 0)], 38 | [(1, 0), (0, 1)], 39 | [(2, 0), (1, 1)], 40 | [(3, 0), (2, 1)], 41 | [(3, 1)], 42 | ] 43 | -------------------------------------------------------------------------------- /tests/nn/pipe/test_transparency.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import torch 21 | from torch import nn 22 | 23 | from fairscale.nn import Pipe 24 | 25 | 26 | def test_simple_linears(): 27 | def sum_grad(parameters): 28 | return sum([p.grad.sum() for p in parameters if p.grad is not None]) 29 | 30 | def zero_grad(parameters): 31 | for p in parameters: 32 | p.grad = None 33 | 34 | inputs = torch.rand(8, 1) 35 | model = nn.Sequential( 36 | nn.Linear(1, 2), 37 | nn.Linear(2, 4), 38 | nn.Linear(4, 2), 39 | nn.Linear(2, 1), 40 | ) 41 | 42 | # Without Pipe 43 | outputs = model(inputs) 44 | loss = outputs.mean() 45 | loss.backward() 46 | 47 | grad_without_pipe = sum_grad(model.parameters()) 48 | 49 | zero_grad(model.parameters()) 50 | 51 | # With Pipe 52 | model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4) 53 | 54 | outputs = model(inputs) 55 | loss = outputs.mean() 56 | loss.backward() 57 | 58 | grad_with_pipe = sum_grad(model.parameters()) 59 | 60 | # Both grads should be identical. 61 | assert torch.allclose(grad_with_pipe, grad_without_pipe) 62 | -------------------------------------------------------------------------------- /tests/nn/pipe_process/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | # tests/__init__.py lets pytest import the application without custom sys.path or PYTHONPATH. 21 | # See also: https://docs.pytest.org/en/latest/goodpractices.html 22 | -------------------------------------------------------------------------------- /tests/nn/pipe_process/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import functools 21 | import os 22 | from typing import Any, Callable 23 | 24 | import pytest 25 | import torch 26 | 27 | from fairscale.nn.model_parallel import destroy_model_parallel 28 | 29 | 30 | @pytest.fixture(autouse=True) 31 | def manual_seed_zero() -> None: 32 | torch.manual_seed(0) 33 | 34 | 35 | def cuda_sleep_impl(seconds, cycles_per_ms): 36 | torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) 37 | 38 | 39 | @pytest.fixture(scope="session") 40 | def cuda_sleep() -> Callable: 41 | # Warm-up CUDA. 42 | torch.empty(1, device="cuda") 43 | 44 | # From test/test_cuda.py in PyTorch. 45 | start = torch.cuda.Event(enable_timing=True) 46 | end = torch.cuda.Event(enable_timing=True) 47 | start.record() 48 | torch.cuda._sleep(1000000) 49 | end.record() 50 | end.synchronize() 51 | cycles_per_ms = 1000000 / start.elapsed_time(end) 52 | 53 | return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms) 54 | 55 | 56 | def pytest_report_header() -> str: 57 | return f"torch: {torch.__version__}" 58 | 59 | 60 | def pytest_runtest_setup(item: Any) -> None: 61 | print("setup mpi function called") 62 | 63 | 64 | def pytest_runtest_teardown(item: Any) -> None: 65 | if "OMPI_COMM_WORLD_RANK" in os.environ: 66 | destroy_model_parallel() 67 | if torch.distributed.is_initialized(): 68 | torch.distributed.destroy_process_group() 69 | try: 70 | torch.distributed.rpc.shutdown() 71 | except Exception: 72 | pass 73 | -------------------------------------------------------------------------------- /tests/nn/pipe_process/skip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/pipe_process/skip/__init__.py -------------------------------------------------------------------------------- /tests/nn/pipe_process/test_transparency.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Copyright 2019 Kakao Brain 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
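The multiprocess test below relies on a rank-dependent backward pattern: only the final pipeline stage holds the output and computes the loss, while earlier stages join the backward via back_helper. In miniature (a sketch; model is assumed to be an AsyncPipe instance and outputs its forward result):

def run_backward(model, outputs):
    if model.group.rank() == model.group.size() - 1:
        outputs.mean().backward()   # the loss exists only on the last stage
    else:
        model.back_helper(outputs)  # earlier stages join the backward pass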

import pytest
import torch
from torch import nn

from fairscale.fair_dev.testing.testing import get_worker_map, set_random_seed, torch_spawn
from fairscale.nn.pipe import AsyncPipe


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def simple_linears(pipe_class):
    def sum_grad(parameters):
        return sum([p.grad.sum() for p in parameters if p.grad is not None])

    def zero_grad(parameters):
        for p in parameters:
            p.grad = None

    set_random_seed(12345)
    inputs = torch.rand(8, 1)
    model = nn.Sequential(
        nn.Linear(1, 2),
        nn.Linear(2, 4),
        nn.Linear(4, 2),
        nn.Linear(2, 1),
    )

    # Without the pipe (plain sequential baseline)
    outputs = model(inputs)
    loss = outputs.mean()
    loss.backward()

    # Reference gradient sums, one per pipeline partition (layers 0-1 and 2-3).
    grad_without_pipe = [
        sum_grad([*model[0].parameters(), *model[1].parameters()]),
        sum_grad([*model[2].parameters(), *model[3].parameters()]),
    ]

    # Per-parameter gradients kept for reference (not asserted against below).
    ref_without_pipe = [p.grad for p in model.parameters()]

    zero_grad(model.parameters())

    model = pipe_class(model, [2, 2], worker_map=get_worker_map(), chunks=4)

    outputs = model(inputs)
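    # With the pipe spawned across two processes, rank 0 holds the first
    # partition (layers 0-1) and rank 1 the second (layers 2-3). Only the last
    # stage sees the model output and can compute the loss; the earlier stage
    # drives its backward pass through back_helper() instead.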
    if model.group.rank() == 1:
        loss = outputs.mean()
        loss.backward()
        grad_with_pipe = sum_grad(model.partition.parameters())

        # Both grads should be identical.
        assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
    else:
        model.back_helper(outputs)
        grad_with_pipe = sum_grad(model.partition.parameters())

        # Both grads should be identical.
        assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
    torch.distributed.barrier()

--------------------------------------------------------------------------------
/tests/nn/wrap/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/nn/wrap/__init__.py

--------------------------------------------------------------------------------
/tests/optim/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

--------------------------------------------------------------------------------
/tests/run_mpi_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

# Collect the names of all rpc-based tests so they can be run individually below.
rpc_tests=$(pytest --collect-only | grep 'Function.*rpc' | cut -d' ' -f 6 | tr -d '>')

for WORKERS in {1..6}; do
    # Run the non-rpc pipe_process tests with $WORKERS MPI processes, then the
    # rpc tests one at a time.
    mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k "not rpc"
    for test_name in $rpc_tests; do
        mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k $test_name
    done
done

--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/fairscale/146f160241651e1211c4247979f159a4ef43b54a/tests/utils/__init__.py

--------------------------------------------------------------------------------
/tests/utils/test_parallel.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test utility classes from fairscale.internal.parallel """

from parameterized import parameterized
import torch

from fairscale.internal.parallel import chunk_and_pad


@parameterized.expand([[num_chunks] for num_chunks in range(1, 33)])
def test_chunk_and_pad(num_chunks):
    max_tensor_size = 256
    tensor = torch.zeros(max_tensor_size)
    for tensor_size in range(1, max_tensor_size + 1):
        tensor_i = tensor[:tensor_size]
        chunks = chunk_and_pad(tensor_i, num_chunks)
        # chunk_and_pad must always return num_chunks chunks of equal length,
        # zero-padding the tail as needed.
        assert len(chunks) == num_chunks
        assert all(len(chunks[0]) == len(chunk) for chunk in chunks)

--------------------------------------------------------------------------------
/tests/utils/test_state_dict.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test utility classes from state_dict.py. """
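
# As pinned down by the asserts below: find_module_instances returns
# (state_dict key prefix, module) pairs for each submodule of the given type,
# and replace_by_prefix_ rewrites matching key prefixes of a state_dict in
# place.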
""" 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from fairscale.internal.state_dict import find_module_instances, replace_by_prefix_ 16 | 17 | 18 | def test_find_module_instances(): 19 | net = nn.Sequential( 20 | nn.Linear(1, 1), nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}), nn.LayerNorm(1) 21 | ) 22 | assert find_module_instances(net, nn.LayerNorm) == [("1.ln.", net[1]["ln"]), ("2.", net[2])] 23 | assert find_module_instances(net, nn.Linear) == [("0.", net[0]), ("1.linear.", net[1]["linear"])] 24 | assert find_module_instances(net, nn.Dropout) == [] 25 | assert find_module_instances(net, nn.Sequential) == [("", net)] 26 | 27 | 28 | def test_replace_by_prefix(): 29 | state_dict = {"layer.a": torch.tensor(1), "abc.layer.def": torch.tensor(2), "layer.b": torch.tensor(3)} 30 | replace_by_prefix_(state_dict, "layer.", "module.layer.") 31 | assert state_dict == { 32 | "module.layer.a": torch.tensor(1), 33 | "abc.layer.def": torch.tensor(2), 34 | "module.layer.b": torch.tensor(3), 35 | } 36 | -------------------------------------------------------------------------------- /tests/utils/test_version.py: -------------------------------------------------------------------------------- 1 | from fairscale.internal import torch_version 2 | 3 | 4 | def test_torch_version(): 5 | assert torch_version("") == tuple() 6 | assert torch_version("bad format") == tuple() 7 | assert torch_version("1.9.0") == (1, 9, 0) 8 | assert torch_version("1.10.0a0+gitbc6fc3e") == (1, 10, 0) 9 | assert torch_version("1.7.0+cu102") == (1, 7, 0) 10 | assert torch_version("1.10.0a0+fb") == (1, 10, 0) 11 | --------------------------------------------------------------------------------