├── .coverage ├── .flake8 ├── .github └── workflows │ ├── model_tests.yaml │ └── wait_for_ssh_to_drain.sh ├── .gitignore ├── .gitmodules ├── ARCHITECTURE.md ├── CITATION ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── REFERENCE.md ├── binaries ├── build.py └── upload.py ├── check.sh ├── docs ├── .gitignore ├── Makefile ├── README.md ├── requirements.txt └── source │ ├── _static │ └── css │ │ └── custom.css │ ├── _templates │ ├── autosummary │ │ └── class.rst │ ├── classtemplate.rst │ └── layout.html │ ├── conf.py │ ├── docutils.conf │ └── index.rst ├── examples ├── basic │ ├── example.py │ ├── example_manual_stage.py │ └── example_train.py ├── checkpoint │ └── toy_model.py ├── cpu_init │ ├── README.md │ └── gpt2_cpu_init.py ├── huggingface │ ├── hf_utils.py │ ├── pippy_bert.py │ ├── pippy_blenderbot.py │ ├── pippy_camemBert.py │ ├── pippy_convBert.py │ ├── pippy_deberta.py │ ├── pippy_debertaV2.py │ ├── pippy_distilBert.py │ ├── pippy_electra.py │ ├── pippy_fnet.py │ ├── pippy_gpt2.py │ ├── pippy_gptNeo.py │ ├── pippy_layoutLM.py │ ├── pippy_mbart.py │ ├── pippy_megatronBert.py │ ├── pippy_mobileBert.py │ ├── pippy_opt.py │ ├── pippy_trOCR.py │ ├── pippy_unet.py │ └── pippy_xlnet.py ├── llama │ ├── README.md │ └── pippy_llama.py ├── mixture_of_experts │ └── dist_moe.py ├── profiling │ └── mlp_profiling.py ├── tp+pp │ └── pippy_tp.py └── unrolling │ ├── README.md │ └── pippy_unroll.py ├── format.sh ├── pippy ├── ManualPipelineStage.py ├── ModelSplit.py ├── PipelineSchedule.py ├── _IR.py ├── _PipelineStage.py ├── __init__.py ├── _backward.py ├── _debug.py ├── _unflatten.py ├── _utils.py ├── graphsplit.py ├── microbatch.py └── utilities │ ├── __init__.py │ └── hf_checkpoint.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── test ├── __init__.py ├── hf_test.py ├── multinode_trainer.slurm ├── test_autosplit.py ├── test_bwd.py ├── test_chunkspec.py ├── test_composability.py ├── test_cpu_init.py ├── test_fwd.py ├── test_grad.py ├── test_interleave.py ├── test_ir.py ├── test_microbatch.py ├── test_optim.py ├── test_pipe.py ├── test_pipe_bwd.py ├── test_pipeline_schedule.py ├── test_pipeline_schedule_e2e.py ├── test_pipeline_stage.py ├── test_skip_conn.py ├── test_stage_backward.py ├── test_transformer.py └── test_unflatten.py └── version.txt /.coverage: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/PiPPy/1bcb2bfb2d6cc4ac2125c0edb37c35585bb9695f/.coverage -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,T4,W,B9 3 | max-line-length = 120 4 | # C408 ignored because we like the dict keyword argument syntax 5 | # E501 is not flexible enough, we're using B950 instead 6 | ignore = 7 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 9 | # to line this up with executable bit 10 | EXE001, 11 | # these ignores are from flake8-bugbear; please fix! 12 | B007,B008, 13 | # these ignores are from flake8-comprehensions; please fix! 
14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 15 | per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 16 | optional-ascii-coding = True 17 | exclude = 18 | ./.git, 19 | ./build_test_custom_build, 20 | ./build, 21 | ./caffe2, 22 | ./docs/caffe2, 23 | ./docs/cpp/src, 24 | ./docs/src, 25 | ./scripts, 26 | ./test/generated_type_hints_smoketest.py, 27 | ./third_party, 28 | ./torch/include, 29 | ./torch/lib, 30 | ./venv, 31 | ./pippy/unflatten.py, 32 | *.pyi 33 | -------------------------------------------------------------------------------- /.github/workflows/model_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Model Tests 2 | # Run models in `examples` folder 3 | 4 | on: 5 | # Run when any example is changed 6 | pull_request: 7 | paths: 8 | - '.github/workflows/model_tests.yaml' 9 | - 'examples/**' 10 | # Nightly run against pytorch nightly build 11 | schedule: 12 | - cron: "30 11 * * *" # Everyday 11:30 am UTC, i.e. 4:30 am PST 13 | 14 | concurrency: 15 | # Cancel CI on previous commit when a new commit is pushed to the same branch 16 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 17 | cancel-in-progress: true 18 | 19 | defaults: 20 | run: 21 | shell: bash -l -eo pipefail {0} 22 | 23 | jobs: 24 | model_tests_4gpu: 25 | runs-on: linux.g5.12xlarge.nvidia.gpu 26 | strategy: 27 | matrix: 28 | python-version: ['3.10'] 29 | steps: 30 | - name: Check out repo 31 | uses: actions/checkout@v3 32 | - name: Setup conda env 33 | uses: conda-incubator/setup-miniconda@v2 34 | with: 35 | auto-update-conda: true 36 | miniconda-version: "latest" 37 | activate-environment: test 38 | python-version: ${{ matrix.python-version }} 39 | - name: Activate conda env 40 | run: conda activate test 41 | - name: Install dependencies 42 | run: | 43 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 44 | - name: Install Transformers for getting models 45 | run: pip install transformers 46 | # - name: Install Diffusers for getting models 47 | # run: pip install diffusers 48 | - name: Run GPT2 49 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_gpt2.py 50 | - name: Run BERT 51 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_bert.py 52 | - name: Run blenderbot 53 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_blenderbot.py 54 | - name: Run camemBert 55 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_camemBert.py 56 | - name: Run convBert 57 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_convBert.py 58 | - name: Run deberta 59 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_deberta.py 60 | # - name: Run debertaV2 61 | # run: torchrun --nproc-per-node 4 examples/huggingface/pippy_debertaV2.py 62 | - name: Run distilBert 63 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_distilBert.py 64 | - name: Run electra 65 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_electra.py 66 | - name: Run fnet 67 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_fnet.py 68 | - name: Run gptNeo 69 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_gptNeo.py 70 | - name: Run layoutLM 71 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_layoutLM.py 72 | - name: Run mbart 73 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_mbart.py 74 | - name: Run megatronBert 75 | run: torchrun --nproc-per-node 4 
examples/huggingface/pippy_megatronBert.py 76 | - name: Run mobileBert 77 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_mobileBert.py 78 | # - name: Run opt 79 | # run: torchrun --nproc-per-node 2 examples/huggingface/pippy_opt.py 80 | # - name: Run trOCR 81 | # run: torchrun --nproc-per-node 4 examples/huggingface/pippy_trOCR.py 82 | # - name: Run unet 83 | # run: torchrun --nproc-per-node 2 examples/huggingface/pippy_unet.py 84 | - name: Run xlnet 85 | run: torchrun --nproc-per-node 4 examples/huggingface/pippy_xlnet.py 86 | - name: Test CPU init + GPU run 87 | run: torchrun --nproc-per-node 4 examples/cpu_init/gpt2_cpu_init.py 88 | -------------------------------------------------------------------------------- /.github/workflows/wait_for_ssh_to_drain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eou pipefail 4 | 5 | echo "Holding runner for 2 hours until all ssh sessions have logged out" 6 | for _ in $(seq 1440); do 7 | # Break if no ssh session exists anymore 8 | if [ "$(who)" = "" ]; then 9 | break 10 | fi 11 | echo "." 12 | sleep 5 13 | done 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | build 3 | pippy.egg-info 4 | torchpippy.egg-info 5 | pippy/version.py 6 | dist 7 | .idea/ 8 | .pyre/ 9 | **/*.json 10 | **/*.out 11 | **/.DS_STORE 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "docs/src/pytorch-sphinx-theme"] 2 | path = docs/src/pytorch-sphinx-theme 3 | url = https://github.com/pytorch/pytorch_sphinx_theme.git 4 | -------------------------------------------------------------------------------- /CITATION: -------------------------------------------------------------------------------- 1 | @Misc{pippy2022, 2 | author = {James Reed, Pavel Belevich, Ke Wen}, 3 | title = {PiPPy: Pipeline Parallelism for PyTorch}, 4 | howpublished = {\url{https://github.com/pytorch/PiPPy}}, 5 | year = {2022} 6 | } 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | This project has been upstreamed to [PyTorch](https://github.com/pytorch/pytorch) under [`torch.distributed.pipelining`](https://github.com/pytorch/pytorch/tree/main/torch/distributed/pipelining). All future development will happen in PyTorch. Please file issues and pull requests directly to PyTorch following these [contributing guidelines](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md). 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Meta Platforms, Inc. and its affiliates. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, 7 | this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software 15 | without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /REFERENCE.md: -------------------------------------------------------------------------------- 1 | # Advanced: Pipeline Schedules 2 | 3 | Pipeline parallel training of deep neural networks is _bidirectional_ since training requires running both forward- and back-propagation of the network. As a result, multiple items of work may be ready to run on a pipeline stage at a given time. The problem of selecting between these work items is known as _scheduling_, and a specific policy for selecting work-items is known as a _pipeline schedule_. 
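As a purely illustrative aside (this is not PiPPy code or its API), a schedule can be thought of as nothing more than the ordered list of (phase, microbatch) work items that each stage executes. The sketch below generates such orderings for the GPipe-style fill-drain policy and a synchronous 1F1B policy described next; the warm-up convention and the "F"/"B" labels are assumptions chosen only for illustration.

```python
from typing import List, Tuple

# Illustrative only: a "schedule" viewed as an ordered list of (phase, microbatch)
# work items for one pipeline stage. Not PiPPy's implementation.

def fill_drain_order(num_microbatches: int) -> List[Tuple[str, int]]:
    # GPipe-style fill-drain: run every forward microbatch, then every backward.
    order = [("F", mb) for mb in range(num_microbatches)]
    order += [("B", mb) for mb in range(num_microbatches)]
    return order

def one_f_one_b_order(stage: int, num_stages: int, num_microbatches: int) -> List[Tuple[str, int]]:
    # Synchronous 1F1B: a short warm-up of forwards (deeper stages warm up less,
    # under an assumed convention), then alternate one backward with one forward,
    # then drain the remaining backwards.
    assert 0 <= stage < num_stages
    warmup = min(num_stages - stage, num_microbatches)
    order = [("F", mb) for mb in range(warmup)]
    fwd, bwd = warmup, 0
    while bwd < num_microbatches:
        order.append(("B", bwd))
        bwd += 1
        if fwd < num_microbatches:
            order.append(("F", fwd))
            fwd += 1
    return order

if __name__ == "__main__":
    print("fill-drain      :", fill_drain_order(4))
    print("1F1B (stage 0/4):", one_f_one_b_order(0, 4, 8))
```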
4 | 5 | PiPPy provides both off-the-shelf pipeline schedules described in the research literature and a programmable interface for creating new schedules. The schedules include: 6 | 7 | * Fill-Drain. Fill-drain is a schedule that executes all forward microbatches before executing any backward microbatches. This is the "standard" schedule used in GPipe (Huang, 2018). 8 | 9 | * 1F1B (one forward, one backward) is a schedule that provides good hardware utilization while limiting the amount of memory needed on a stage. At steady state, a pipeline stage alternates between processing forward and backward micro-batches. 1F1B was introduced in its asynchronous form in (Harlap, 2018) and in its synchronous form in (Narayanan, 2021). 10 | 11 | * Interleaved 1F1B. Interleaved 1F1B is a variant of 1F1B that divides the program into smaller chunks and assigns multiple chunks per stage in a wrap-around fashion. Interleaving improves pipeline throughput with memory characteristics similar to 1F1B. Interleaved 1F1B was introduced by (Narayanan, 2021). 12 | 13 | # Future Work 14 | 15 | Future work on PiPPy includes: 16 | 17 | * Increasing automation. We aim to develop automated systems that can relieve the user of the burden of specifying things such as the batch dimension or pipeline split points. Automatic, optimal splitting of a program into balanced pipeline stages is an interesting research area, with advances in the deep learning systems field (e.g. Zheng, 2022) and in adjacent fields such as high-level synthesis for digital design (e.g. Zaretsky, 2007). 18 | * Expanding to more forms of parallelism. PiPPy is our first foray into compiler-mediated distribution of PyTorch programs. We would like to explore expanding the analysis and partitioning capabilities enabled by a compiler stack to other forms of parallelism, including data parallelism, model parallelism, and MoE parallelism. Such automation is a rich area of research that we would like to contribute to. 19 | 20 | # References 21 | 22 | * Chi-Chung Chen, Chia-Lin Yang, & Hsiang-Yun Cheng (2018). Efficient and Robust Parallel DNN Training through Model Parallelism on Multi-GPU Platform. CoRR, abs/1809.02839. 23 | * Geng, J., Li, D., & Wang, S. (2019). ElasticPipe: An Efficient and Dynamic Model-Parallel Solution to DNN Training. In Proceedings of the 10th Workshop on Scientific Cloud Computing (pp. 5–9). Association for Computing Machinery. 24 | * Lei Guan and Wotao Yin and Dongsheng Li and Xicheng Lu (2019). XPipe: Efficient Pipeline Model Parallelism for Multi-GPU DNN Training. CoRR, abs/1911.04610. 25 | * Aaron Harlap and Deepak Narayanan and Amar Phanishayee and Vivek Seshadri and Nikhil R. Devanur and Gregory R. Ganger and Phillip B. Gibbons (2018). PipeDream: Fast and Efficient Pipeline Parallel DNN Training. CoRR, abs/1806.03377. 26 | * Yanping Huang and Youlong Cheng and Dehao Chen and HyoukJoong Lee and Jiquan Ngiam and Quoc V. Le and Zhifeng Chen (2018). GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism. CoRR, abs/1811.06965. 27 | * Chiheon Kim and Heungsub Lee and Myungryong Jeong and Woonhyuk Baek and Boogeon Yoon and Ildoo Kim and Sungbin Lim and Sungwoong Kim (2020). torchgpipe: On-the-fly Pipeline Parallelism for Training Giant Models. CoRR, abs/2004.09910. 28 | * Atli Kosson and Vitaliy Chiley and Abhinav Venigalla and Joel Hestness and Urs Köster (2020). Pipelined Backpropagation at Scale: Training Large Models without Batches. CoRR, abs/2003.11666.
29 | * Deepak Narayanan and Amar Phanishayee and Kaiyu Shi and Xie Chen and Matei Zaharia (2020). Memory-Efficient Pipeline-Parallel DNN Training. CoRR, abs/2006.09503. 30 | * Deepak Narayanan and Mohammad Shoeybi and Jared Casper and Patrick LeGresley and Mostofa Patwary and Vijay Korthikanti and Dmitri Vainbrand and Prethvi Kashinkunti and Julie Bernauer and Bryan Catanzaro and Amar Phanishayee and Matei Zaharia (2021). Efficient Large-Scale Language Model Training on GPU Clusters. CoRR, abs/2104.04473. 31 | * Petrowski, A., Dreyfus, G., & Girault, C. (1993). Performance analysis of a pipelined backpropagation parallel algorithm. IEEE Transactions on Neural Networks, 4(6), 970-981. 32 | * Bowen Yang and Jian Zhang and Jonathan Li and Christopher Ré and Christopher R. Aberger and Christopher De Sa (2019). PipeMare: Asynchronous Pipeline Parallel DNN Training. CoRR, abs/1910.05124. 33 | * Lianmin Zheng, Zhuohan Li, Hao Zhang, Yonghao Zhuang, Zhifeng Chen, Yanping Huang, Yida Wang, Yuanzhong Xu, Danyang Zhuo, Joseph E. Gonzalez, & Ion Stoica (2022). Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning. CoRR, abs/2201.12023. 34 | * D. C. Zaretsky, G. Mittal, R. P. Dick and P. Banerjee, "Balanced Scheduling and Operation Chaining in High-Level Synthesis for FPGA Designs," 8th International Symposium on Quality Electronic Design (ISQED'07), 2007, pp. 595-601, doi: 10.1109/ISQED.2007.41. 35 | * Lai, Z., Li, S., Tang, X., Ge, K., Liu, W., Duan, Y., Qiao, L., & Li, D. (2022). Merak: An Efficient Distributed DNN Training Framework with Automated 3D Parallelism for Giant Foundation Models. arXiv preprint arXiv:2206.04959. 36 | -------------------------------------------------------------------------------- /binaries/build.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import sys 5 | 6 | 7 | """ 8 | Instructions: 9 | 10 | WARNING: 11 | Please make sure the "build" folder and the "dist" folder are cleaned before building. 12 | You can achieve that by running: 13 | `python setup.py clean` 14 | 15 | To build a wheel file with the Git hash, please run: 16 | `python build.py` 17 | 18 | To build a wheel file with the release version only, please run: 19 | `VERSION_NO_GIT=1 python build.py` 20 | """ 21 | 22 | # To help discover local modules 23 | REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") 24 | sys.path.append(REPO_ROOT) 25 | 26 | 27 | def build_dist_whl(args): 28 | """ 29 | Function to build the wheel files for PiPPy 30 | """ 31 | 32 | print("## Started pippy build") 33 | create_wheel_cmd = "python setup.py bdist_wheel " 34 | 35 | os.chdir(REPO_ROOT) 36 | 37 | # Build wheel 38 | print( 39 | f"## In directory: {os.getcwd()} | Executing command: {create_wheel_cmd}" 40 | ) 41 | 42 | if not args.dry_run: 43 | build_exit_code = os.system(create_wheel_cmd) 44 | # If any one of the steps fails, exit with error 45 | if build_exit_code != 0: 46 | sys.exit("## PiPPy build failed!") 47 | 48 | 49 | def build(args): 50 | dist_dir = os.path.join(REPO_ROOT, "dist") 51 | 52 | # Detect whether an old build exists 53 | # If any, stop 54 | if os.path.exists(dist_dir): 55 | raise RuntimeError( 56 | f"dist folder already exists at {dist_dir}. Please run: " 57 | "`python setup.py clean` " 58 | "to clean existing builds."
59 | ) 60 | 61 | # Build dist wheel files 62 | build_dist_whl(args) 63 | 64 | pippy_wheel_path = os.path.join(dist_dir, "*.whl") 65 | if not args.dry_run: 66 | # `glob.glob` returns a list of files that matches the path having wildcards 67 | pippy_wheel_path = glob.glob(pippy_wheel_path) 68 | 69 | print(f"## PiPPy wheel location: {pippy_wheel_path}") 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser( 74 | description="Build wheel package for pippy" 75 | ) 76 | 77 | parser.add_argument( 78 | "--dry-run", 79 | action="store_true", 80 | help="print the commands that will be run without running them", 81 | ) 82 | 83 | args = parser.parse_args() 84 | 85 | build(args) 86 | -------------------------------------------------------------------------------- /binaries/upload.py: -------------------------------------------------------------------------------- 1 | #! /usr/env/bin 2 | import argparse 3 | import glob 4 | import os 5 | import sys 6 | import subprocess 7 | 8 | 9 | """ 10 | Instructions: 11 | 12 | Make sure you have installed the following packages before running this script: 13 | `pip install twine` 14 | 15 | Make sure you have cleaned and then built wheel files locally: 16 | see instructions in `build.py` 17 | 18 | To upload to pypi, run: 19 | `python upload.py --upload` 20 | Then copy and paste the pypi token when prompted. 21 | """ 22 | 23 | 24 | # To help discover local modules 25 | REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") 26 | sys.path.append(REPO_ROOT) 27 | 28 | 29 | def exe_cmd(cmd, dry_run=True): 30 | if dry_run: 31 | print(f"Executing command: {cmd}") 32 | else: 33 | try: 34 | subprocess.run([cmd], shell=True, check=True) 35 | except subprocess.CalledProcessError as e: 36 | raise (e) 37 | 38 | 39 | def upload_pypi_packages(args, WHL_PATHS): 40 | """ 41 | Takes a list of path values and uploads them to pypi using twine, using token stored in environment variable 42 | """ 43 | dry_run = not args.upload 44 | 45 | # Note: TWINE_USERNAME and TWINE_PASSWORD are expected to be set in the environment 46 | options = "--username __token__ " 47 | 48 | if args.test_pypi: 49 | options += "--repository-url https://test.pypi.org/legacy/ " 50 | # TODO: 51 | # maybe "--repository testpypi " works the same (and shorter)? 
52 | # Ref: https://packaging.python.org/en/latest/tutorials/packaging-projects/#uploading-the-distribution-archives 53 | 54 | for dist_path in WHL_PATHS: 55 | cmd = "twine upload " + options + f" {dist_path}/*" 56 | exe_cmd(cmd, dry_run) 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser( 61 | description="Upload pypi packages for PiPPy" 62 | ) 63 | 64 | parser.add_argument( 65 | "--upload", 66 | action="store_true", 67 | required=False, 68 | help="Actually upload packages; otherwise dry run", 69 | ) 70 | 71 | parser.add_argument( 72 | "--test-pypi", 73 | action="store_true", 74 | help="Upload to test.pypi instead of pypi", 75 | ) 76 | 77 | args = parser.parse_args() 78 | 79 | PACKAGES = ["pippy"] 80 | 81 | if args.upload: 82 | PiPPY_WHEEL_PATH = glob.glob(os.path.join(REPO_ROOT, "dist"))[0] 83 | else: 84 | PiPPY_WHEEL_PATH = os.path.join(REPO_ROOT, "dist") 85 | 86 | WHL_PATHS = [PiPPY_WHEEL_PATH] 87 | 88 | upload_pypi_packages(args, WHL_PATHS) 89 | -------------------------------------------------------------------------------- /check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function usage() { 4 | echo 2>&1 <&1 37 | usage 38 | exit 1; 39 | ;; 40 | esac 41 | done 42 | 43 | if (( KEEP_GOING == 0 )); then 44 | set -e 45 | fi 46 | 47 | 48 | RETVAL=0 49 | 50 | if (( SKIP_FORMAT == 0 )); then 51 | echo; echo "Running format check ..." 52 | ufmt diff pippy/*.py test/*.py 53 | (( RETVAL |= $? )) 54 | fi 55 | 56 | if (( SKIP_PYRE == 0 )); then 57 | echo; echo "Running pyre ..." 58 | pyre check 59 | (( RETVAL |= $? )) 60 | fi 61 | 62 | echo; echo "Running flake8 ..." 63 | flake8 pippy 64 | (( RETVAL |= $? )) 65 | 66 | echo; echo "Running mypy ..." 67 | mypy --follow-imports=skip pippy 68 | (( RETVAL |= $? )) 69 | 70 | echo; echo "Running pylint ..." 71 | pylint --disable=all --enable=unused-import $(git ls-files '*.py') 72 | (( RETVAL |= $? )) 73 | 74 | exit $RETVAL 75 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | package-lock.json 3 | package.json 4 | yarn.lock 5 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS ?= -j auto -WT --keep-going 6 | SPHINXBUILD ?= sphinx-build 7 | SPHINXPROJ ?= functorch 8 | SOURCEDIR ?= source 9 | BUILDDIR ?= build 10 | PYCMD ?= python 11 | 12 | # Put it first so that "make" without argument is like "make help". 13 | help: 14 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 15 | 16 | docset: html 17 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url https://pytorch.org/docs/ --force $(BUILDDIR)/html/ 18 | 19 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. 20 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png 21 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png 22 | 23 | html-stable: 24 | # stable differs from `make html` in two ways: 25 | # 1) The stable logo is used instead of the unstable logo 26 | # 2) There will not be a link to the stable docs. 
27 | # See conf.py for more details. 28 | RELEASE=1 make html 29 | 30 | .PHONY: help Makefile docset 31 | 32 | # Catch-all target: route all unknown targets to Sphinx using the new 33 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 34 | %: Makefile 35 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 36 | 37 | clean: 38 | @echo "Removing everything under 'build' and 'source/generated'.." 39 | @rm -rf $(BUILDDIR)/html/ $(BUILDDIR)/doctrees $(SOURCEDIR)/generated 40 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | pippy docs build 2 | -------------------- 3 | 4 | ## Build Locally 5 | 6 | Install requirements: 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | One may also need to install [pandoc](https://pandoc.org/installing.html). On Linux we can use: `sudo apt-get install pandoc`. Or using `conda` we can use: `conda install -c conda-forge pandoc`. 12 | 13 | To run the docs build: 14 | ``` 15 | make html 16 | ``` 17 | 18 | Check out the output files in `build/html`. 19 | 20 | ## Deploy 21 | 22 | The pippy docs website is not updated automatically. We need to regenerate it periodically. 23 | 24 | You need write permissions to pippy to do this. We use GitHub Pages to serve docs. 25 | 26 | 1. Build the docs 27 | 2. Save the build/html folder somewhere 28 | 3. Check out the branch `gh-pages`. 29 | 4. Delete the contents of the branch and replace them with the build/html folder. `index.html` should be at the root. 30 | 5. Commit the changes and push them to the `gh-pages` branch. 31 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==3.5.4 2 | docutils==0.16 3 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 4 | sphinxcontrib.katex 5 | sphinx_copybutton>=0.3.1 6 | nbsphinx 7 | IPython 8 | # Required for nbsphinx: I don't think these can be installed via pip 9 | # conda install -c conda-forge pandoc 10 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | .codeblock-height-limiter { 2 | max-height: 500px; 3 | overflow: scroll; 4 | } 5 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :inherited-members: 10 | :members: 11 | 12 | .. autogenerated from source/_templates/autosummary/class.rst 13 | -------------------------------------------------------------------------------- /docs/source/_templates/classtemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | 12 | ..
13 | autogenerated from source/_templates/classtemplate.rst 14 | note it does not have :inherited-members: 15 | -------------------------------------------------------------------------------- /docs/source/docutils.conf: -------------------------------------------------------------------------------- 1 | [html writers] 2 | table_style: colwidths-auto # Necessary for the table generated by autosummary to look decent 3 | -------------------------------------------------------------------------------- /examples/basic/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # Minimal effort to run this code: 3 | # $ torchrun --nproc-per-node 3 example.py 4 | 5 | import os 6 | import torch 7 | from pippy import pipeline, SplitPoint, ScheduleGPipe, PipelineStage 8 | 9 | in_dim = 512 10 | layer_dims = [512, 1024, 256] 11 | out_dim = 10 12 | 13 | # Single layer definition 14 | class MyNetworkBlock(torch.nn.Module): 15 | def __init__(self, in_dim, out_dim): 16 | super().__init__() 17 | self.lin = torch.nn.Linear(in_dim, out_dim) 18 | 19 | def forward(self, x): 20 | x = self.lin(x) 21 | x = torch.relu(x) 22 | return x 23 | 24 | 25 | # Full model definition 26 | class MyNetwork(torch.nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | self.num_layers = len(layer_dims) 30 | 31 | prev_dim = in_dim 32 | # Add layers one by one 33 | for i, dim in enumerate(layer_dims): 34 | super().add_module(f"layer{i}", MyNetworkBlock(prev_dim, dim)) 35 | prev_dim = dim 36 | 37 | # Final output layer (with OUT_DIM projection classes) 38 | self.output_proj = torch.nn.Linear(layer_dims[-1], out_dim) 39 | 40 | def forward(self, x): 41 | for i in range(self.num_layers): 42 | layer = getattr(self, f"layer{i}") 43 | x = layer(x) 44 | 45 | return self.output_proj(x) 46 | 47 | 48 | # To run a distributed training job, we must launch the script in multiple 49 | # different processes. We are using `torchrun` to do so in this example. 50 | # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`, 51 | # which represent the index of this process within the set of processes and 52 | # the total number of processes, respectively. 53 | # 54 | # To learn more about `torchrun`, see 55 | # https://pytorch.org/docs/stable/elastic/run.html 56 | 57 | torch.manual_seed(0) 58 | rank = int(os.environ["RANK"]) 59 | world_size = int(os.environ["WORLD_SIZE"]) 60 | 61 | # Figure out device to use 62 | if torch.cuda.is_available(): 63 | device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") 64 | else: 65 | device = torch.device("cpu") 66 | 67 | # Create the model 68 | mn = MyNetwork().to(device) 69 | 70 | split_spec = { 71 | "layer0": SplitPoint.END, 72 | "layer1": SplitPoint.END, 73 | } 74 | 75 | batch_size = 32 76 | example_input = torch.randn(batch_size, in_dim, device=device) 77 | chunks = 4 78 | 79 | pipe = pipeline(mn, chunks, example_args=(example_input,), split_spec=split_spec) 80 | 81 | if rank == 0: 82 | print(" pipe ".center(80, "*")) 83 | print(pipe) 84 | print(" stage 0 ".center(80, "*")) 85 | print(pipe.split_gm.submod_0) 86 | print(" stage 1 ".center(80, "*")) 87 | print(pipe.split_gm.submod_1) 88 | print(" stage 2 ".center(80, "*")) 89 | print(pipe.split_gm.submod_2) 90 | 91 | 92 | # Initialize distributed environment 93 | import torch.distributed as dist 94 | 95 | dist.init_process_group(rank=rank, world_size=world_size) 96 | 97 | # Pipeline stage is our main pipeline runtime. 
It takes in the pipe object, 98 | # the rank of this process, and the device. 99 | stage = PipelineStage(pipe, rank, device) 100 | 101 | # Attach to a schedule 102 | schedule = ScheduleGPipe(stage, chunks) 103 | 104 | # Input data 105 | x = torch.randn(batch_size, in_dim, device=device) 106 | 107 | # Run the pipeline with input `x`. Divide the batch into 4 micro-batches 108 | # and run them in parallel on the pipeline 109 | if rank == 0: 110 | schedule.step(x) 111 | else: 112 | output = schedule.step() 113 | 114 | if rank == world_size - 1: 115 | # Run the original code and get the output for comparison 116 | reference_output = mn(x) 117 | # Compare numerics of pipeline and original model 118 | torch.testing.assert_close(output, reference_output) 119 | print(" Pipeline parallel model ran successfully! ".center(80, "*")) 120 | -------------------------------------------------------------------------------- /examples/basic/example_manual_stage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # Minimal effort to run this code: 3 | # $ torchrun --nproc-per-node 3 example_manual_stage.py 4 | 5 | import os 6 | import torch 7 | from pippy import ScheduleGPipe, ManualPipelineStage 8 | 9 | in_dim = 512 10 | layer_dims = [512, 1024, 256] 11 | out_dim = 10 12 | 13 | # Single layer definition 14 | class MyNetworkBlock(torch.nn.Module): 15 | def __init__(self, in_dim, out_dim): 16 | super().__init__() 17 | self.lin = torch.nn.Linear(in_dim, out_dim) 18 | 19 | def forward(self, x): 20 | x = self.lin(x) 21 | x = torch.relu(x) 22 | return x 23 | 24 | 25 | # Model chunk definition 26 | class ModelChunk0(torch.nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | self.layer0 = MyNetworkBlock(in_dim, layer_dims[0]) 30 | 31 | def forward(self, x): 32 | return self.layer0(x) 33 | 34 | class ModelChunk1(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.layer1 = MyNetworkBlock(layer_dims[0], layer_dims[1]) 38 | 39 | def forward(self, x): 40 | return self.layer1(x) 41 | 42 | class ModelChunk2(torch.nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | self.layer2 = MyNetworkBlock(layer_dims[1], layer_dims[2]) 46 | # Final output layer (with OUT_DIM projection classes) 47 | self.output_proj = torch.nn.Linear(layer_dims[2], out_dim) 48 | 49 | def forward(self, x): 50 | x = self.layer2(x) 51 | return self.output_proj(x) 52 | 53 | # To run a distributed training job, we must launch the script in multiple 54 | # different processes. We are using `torchrun` to do so in this example. 55 | # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`, 56 | # which represent the index of this process within the set of processes and 57 | # the total number of processes, respectively. 
58 | # 59 | # To learn more about `torchrun`, see 60 | # https://pytorch.org/docs/stable/elastic/run.html 61 | 62 | torch.manual_seed(0) 63 | rank = int(os.environ["RANK"]) 64 | world_size = int(os.environ["WORLD_SIZE"]) 65 | 66 | # Initialize distributed environment 67 | import torch.distributed as dist 68 | 69 | dist.init_process_group(rank=rank, world_size=world_size) 70 | 71 | # Figure out device to use 72 | if torch.cuda.is_available(): 73 | device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") 74 | else: 75 | device = torch.device("cpu") 76 | 77 | # Create the model chunks 78 | batch_size = 32 79 | example_input_stage_0 = torch.randn(batch_size, in_dim, device=device) 80 | example_input_stage_1 = torch.randn(batch_size, layer_dims[0], device=device) 81 | example_input_stage_2 = torch.randn(batch_size, layer_dims[1], device=device) 82 | chunks = 4 83 | 84 | rank_model_and_input = { 85 | 0: (ModelChunk0(), example_input_stage_0), 86 | 1: (ModelChunk1(), example_input_stage_1), 87 | 2: (ModelChunk2(), example_input_stage_2), 88 | } 89 | 90 | # Pipeline stage is our main pipeline runtime. It takes in the pipe object, 91 | # the rank of this process, and the device. 92 | if rank in rank_model_and_input: 93 | model, example_input = rank_model_and_input[rank] 94 | stage = ManualPipelineStage( 95 | model, 96 | rank, 97 | world_size, 98 | device, 99 | chunks, 100 | example_input, 101 | ) 102 | print(f"Rank {rank} initialized") 103 | else: 104 | raise RuntimeError("Invalid rank") 105 | 106 | # Attach to a schedule 107 | schedule = ScheduleGPipe(stage, chunks) 108 | 109 | # Input data 110 | x = torch.randn(batch_size, in_dim, device=device) 111 | 112 | # Run the pipeline with input `x`. Divide the batch into 4 micro-batches 113 | # and run them in parallel on the pipeline 114 | if rank == 0: 115 | schedule.step(x) 116 | else: 117 | output = schedule.step() 118 | 119 | print(f"Rank {rank} finished") 120 | -------------------------------------------------------------------------------- /examples/basic/example_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # Minimal effort to run this code: 3 | # $ torchrun --nproc-per-node 3 example_train.py 4 | 5 | import os 6 | import torch 7 | from pippy import SplitPoint, ScheduleGPipe, PipelineStage 8 | 9 | in_dim = 512 10 | layer_dims = [512, 1024, 256] 11 | out_dim = 10 12 | 13 | # Single layer definition 14 | class MyNetworkBlock(torch.nn.Module): 15 | def __init__(self, in_dim, out_dim): 16 | super().__init__() 17 | self.lin = torch.nn.Linear(in_dim, out_dim) 18 | 19 | def forward(self, x): 20 | x = self.lin(x) 21 | x = torch.relu(x) 22 | return x 23 | 24 | 25 | # Full model definition 26 | class MyNetwork(torch.nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | self.num_layers = len(layer_dims) 30 | 31 | prev_dim = in_dim 32 | # Add layers one by one 33 | for i, dim in enumerate(layer_dims): 34 | super().add_module(f"layer{i}", MyNetworkBlock(prev_dim, dim)) 35 | prev_dim = dim 36 | 37 | # Final output layer (with OUT_DIM projection classes) 38 | self.output_proj = torch.nn.Linear(layer_dims[-1], out_dim) 39 | 40 | def forward(self, x): 41 | for i in range(self.num_layers): 42 | layer = getattr(self, f"layer{i}") 43 | x = layer(x) 44 | 45 | return self.output_proj(x) 46 | 47 | 48 | # To run a distributed training job, we must launch the script in multiple 49 | # different processes. 
We are using `torchrun` to do so in this example. 50 | # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`, 51 | # which represent the index of this process within the set of processes and 52 | # the total number of processes, respectively. 53 | # 54 | # To learn more about `torchrun`, see 55 | # https://pytorch.org/docs/stable/elastic/run.html 56 | 57 | torch.manual_seed(0) 58 | rank = int(os.environ["RANK"]) 59 | world_size = int(os.environ["WORLD_SIZE"]) 60 | 61 | # Figure out device to use 62 | if torch.cuda.is_available(): 63 | device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") 64 | else: 65 | device = torch.device("cpu") 66 | 67 | # Create the model 68 | mn = MyNetwork().to(device) 69 | 70 | split_spec = { 71 | "layer0": SplitPoint.END, 72 | "layer1": SplitPoint.END, 73 | } 74 | 75 | batch_size = 32 76 | example_input = torch.randn(batch_size, in_dim, device=device) 77 | chunks = 4 78 | 79 | from pippy import pipeline 80 | pipe = pipeline(mn, chunks, example_args=(example_input,), split_spec=split_spec) 81 | 82 | if rank == 0: 83 | print(" pipe ".center(80, "*")) 84 | print(pipe) 85 | print(" stage 0 ".center(80, "*")) 86 | print(pipe.split_gm.submod_0) 87 | print(" stage 1 ".center(80, "*")) 88 | print(pipe.split_gm.submod_1) 89 | print(" stage 2 ".center(80, "*")) 90 | print(pipe.split_gm.submod_2) 91 | 92 | 93 | # Initialize distributed environment 94 | import torch.distributed as dist 95 | 96 | dist.init_process_group(rank=rank, world_size=world_size) 97 | 98 | # Pipeline stage is our main pipeline runtime. It takes in the pipe object, 99 | # the rank of this process, and the device. 100 | stage = PipelineStage(pipe, rank, device) 101 | 102 | # Define a loss function 103 | loss_fn=torch.nn.MSELoss(reduction="sum") 104 | 105 | # Attach to a schedule 106 | schedule = ScheduleGPipe(stage, chunks, loss_fn=loss_fn) 107 | 108 | # Input data 109 | x = torch.randn(batch_size, in_dim, device=device) 110 | target = torch.randn(batch_size, out_dim, device=device) 111 | 112 | # Run the pipeline with input `x`. Divide the batch into 4 micro-batches 113 | # and run them in parallel on the pipeline 114 | if rank == 0: 115 | schedule.step(x) 116 | elif rank == world_size - 1: 117 | losses = [] 118 | output = schedule.step(target=target, losses=losses) 119 | else: 120 | schedule.step() 121 | 122 | if rank == world_size - 1: 123 | # Run the original code and get the output for comparison 124 | reference_output = mn(x) 125 | # Compare numerics of pipeline and original model 126 | torch.testing.assert_close(output, reference_output) 127 | print(f"Loss of microbatches: {losses}") 128 | print(" Pipeline parallel model ran successfully! ".center(80, "*")) 129 | -------------------------------------------------------------------------------- /examples/cpu_init/README.md: -------------------------------------------------------------------------------- 1 | This example demonstrates how to create a pipeline based on a model on CPU, move different parts of the model to GPU 2 | and run the model with data on GPU. This technique can help when a model is too large to materialize on a single GPU. 3 | 4 | Run command: 5 | ``` 6 | $ torchrun --nproc-per-node 4 gpt2_cpu_init.py 7 | ``` 8 | -------------------------------------------------------------------------------- /examples/cpu_init/gpt2_cpu_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 gpt2_cpu_init.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import GPT2ForSequenceClassification, GPT2Config 14 | 15 | 16 | def run(args): 17 | # Model configs 18 | config = GPT2Config() 19 | 20 | # Create model on CPU 21 | model_class = GPT2ForSequenceClassification 22 | model_name = "GPT2ForSequenceClassification" 23 | gpt2 = model_class(config) 24 | gpt2.eval() 25 | if args.rank == 0: 26 | print(gpt2.config) 27 | print(gpt2) 28 | 29 | # Example input on CPU 30 | example_input = torch.randint( 31 | low=0, 32 | high=config.vocab_size, 33 | size=(args.batch_size, 512), # bs x seq_len 34 | device="cpu", 35 | dtype=torch.int64, 36 | requires_grad=False, 37 | ) 38 | 39 | # Split spec 40 | decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size 41 | print(f"decoders_per_rank = {decoders_per_rank}") 42 | split_spec = { 43 | f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING 44 | for i in range(1, args.world_size) 45 | } 46 | 47 | # Create pipeline 48 | pipe = pipeline( 49 | gpt2, 50 | num_chunks=args.chunks, 51 | example_args=(example_input,), 52 | split_spec=split_spec, 53 | ) 54 | 55 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 56 | 57 | # Create schedule runtime 58 | stage = PipelineStage( 59 | pipe, 60 | args.rank, 61 | device=args.device, 62 | ) 63 | 64 | # Attach to a schedule 65 | schedule = ScheduleGPipe(stage, args.chunks) 66 | 67 | # Real input on GPU 68 | real_input = torch.randint( 69 | low=0, 70 | high=config.vocab_size, 71 | size=(args.batch_size, 512), # bs x seq_len 72 | device=args.device, 73 | dtype=torch.int64, 74 | requires_grad=False, 75 | ) 76 | 77 | # Run 78 | if args.rank == 0: 79 | schedule.step(real_input) 80 | else: 81 | out = schedule.step() 82 | 83 | print(f"Rank {args.rank} completes") 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 89 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 90 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 91 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 92 | parser.add_argument('--schedule', type=str, default="FillDrain") 93 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 94 | parser.add_argument("--chunks", type=int, default=4) 95 | parser.add_argument('--batch_size', type=int, default=4) 96 | parser.add_argument('--batches', type=int, default=1) 97 | 98 | args = parser.parse_args() 99 | 100 | if args.cuda: 101 | dev_id = args.rank % torch.cuda.device_count() 102 | args.device = torch.device(f"cuda:{dev_id}") 103 | else: 104 | args.device = torch.device("cpu") 105 | 106 | # Init process group 107 | backend = "nccl" if args.cuda else "gloo" 108 | dist.init_process_group( 109 | backend=backend, 110 | rank=args.rank, 111 | world_size=args.world_size, 112 | ) 113 | 114 | run(args) 115 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_bert.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_bert.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import BertModel, BertConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = BertConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = BertModel 25 | model_name = "BertModel" 26 | bert = model_class(config) 27 | bert.to(args.device) 28 | bert.eval() 29 | if args.rank == 0: 30 | print(bert.config) 31 | print(f"Total number of params = {get_number_of_params(bert) // 10 ** 6}M") 32 | print(bert) 33 | 34 | # Example microbatch inputs 35 | example_mb = generate_inputs_for_model( 36 | model_class, bert, model_name, args.batch_size // args.chunks, args.device) 37 | 38 | # Split points 39 | layers_per_rank = bert.config.num_hidden_layers // args.world_size 40 | split_spec = { 41 | f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 42 | for i in range(1, args.world_size) 43 | } 44 | 45 | # Create pipeline 46 | pipe = pipeline( 47 | bert, 48 | mb_args=(), 49 | mb_kwargs=example_mb, 50 | split_spec=split_spec, 51 | ) 52 | 53 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 54 | smod = pipe.get_stage_module(args.rank) 55 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 56 | 57 | # Create schedule runtime 58 | stage = pipe.build_stage( 59 | args.rank, 60 | device=args.device, 61 | ) 62 | 63 | # Attach to a schedule 64 | schedule = ScheduleGPipe(stage, args.chunks) 65 | 66 | # Full batch inputs as in single-worker case 67 | inputs = generate_inputs_for_model( 68 | model_class, bert, model_name, args.batch_size, args.device) 69 | 70 | # Run 71 | if args.rank == 0: 72 | schedule.step(**inputs) 73 | else: 74 | out = schedule.step() 75 | 76 | dist.barrier() 77 | dist.destroy_process_group() 78 | print(f"Rank {args.rank} completes") 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 84 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 85 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 86 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 87 | parser.add_argument('--schedule', type=str, default="FillDrain") 88 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 89 | parser.add_argument("--chunks", type=int, default=4) 90 | parser.add_argument('--batch_size', type=int, default=4) 91 | parser.add_argument('--batches', type=int, default=1) 92 | 93 | args = parser.parse_args() 94 | 95 | if args.cuda: 96 | dev_id = args.rank % torch.cuda.device_count() 97 | args.device = torch.device(f"cuda:{dev_id}") 98 | else: 99 | args.device = torch.device("cpu") 100 | 101 | # Init process group 102 | backend = "nccl" if args.cuda else "gloo" 103 | dist.init_process_group( 104 | backend=backend, 105 | rank=args.rank, 106 | world_size=args.world_size, 107 | ) 108 | 109 | run(args) 110 | -------------------------------------------------------------------------------- 
/examples/huggingface/pippy_blenderbot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_blenderbot.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import BlenderbotForCausalLM, BlenderbotConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = BlenderbotConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = BlenderbotForCausalLM 25 | model_name = "BlenderbotForCausalLM" 26 | blenderbot = model_class(config) 27 | blenderbot.to(args.device) 28 | blenderbot.eval() 29 | if args.rank == 0: 30 | print(blenderbot.config) 31 | print(f"Total number of params = {get_number_of_params(blenderbot) // 10 ** 6}M") 32 | print(blenderbot) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, blenderbot, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | layers_per_rank = blenderbot.config.decoder_layers // args.world_size 41 | split_spec = { 42 | f"model.decoder.layers.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | blenderbot, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=32) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | 
world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_camemBert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_camemBert.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import CamembertModel, CamembertConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = CamembertConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = CamembertModel 25 | model_name = "CamembertModel" 26 | camembert = model_class(config) 27 | camembert.to(args.device) 28 | camembert.eval() 29 | if args.rank == 0: 30 | print(camembert.config) 31 | print(f"Total number of params = {get_number_of_params(camembert) // 10 ** 6}M") 32 | print(camembert) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, camembert, model_name, args.batch_size, args.device) 37 | 38 | # Split points 39 | layers_per_rank = camembert.config.num_hidden_layers // args.world_size 40 | split_spec = { 41 | f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 42 | for i in range(1, args.world_size) 43 | } 44 | 45 | # Create pipeline 46 | pipe = pipeline( 47 | camembert, 48 | num_chunks=args.chunks, 49 | example_args=(), 50 | example_kwargs=example_inputs, 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(**example_inputs) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=4) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | 
dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_convBert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_convBert.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import ConvBertForMaskedLM, ConvBertConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = ConvBertConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = ConvBertForMaskedLM 25 | model_name = "ConvBertForMaskedLM" 26 | convbert = model_class(config) 27 | convbert.to(args.device) 28 | convbert.eval() 29 | if args.rank == 0: 30 | print(convbert.config) 31 | print(f"Total number of params = {get_number_of_params(convbert) // 10 ** 6}M") 32 | print(convbert) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, convbert, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | # The first rank takes embedding 41 | split_spec = {} 42 | split_spec["convbert.embeddings"] = SplitPoint.END 43 | # The last rank takes generation 44 | split_spec["generator_predictions"] = SplitPoint.BEGINNING 45 | # The rest ranks divide encoder layers 46 | layers_per_rank = convbert.config.num_hidden_layers // (args.world_size - 2) 47 | for i in range(1, args.world_size - 2): 48 | split_spec[f"convbert.encoder.layer.{i * layers_per_rank}"] = SplitPoint.BEGINNING 49 | 50 | # Create pipeline 51 | pipe = pipeline( 52 | convbert, 53 | num_chunks=args.chunks, 54 | example_args=(input_ids, ), 55 | split_spec=split_spec, 56 | ) 57 | 58 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 59 | smod = pipe.get_stage_module(args.rank) 60 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 61 | 62 | # Create schedule runtime 63 | stage = PipelineStage( 64 | pipe, 65 | args.rank, 66 | device=args.device, 67 | ) 68 | 69 | # Attach to a schedule 70 | schedule = ScheduleGPipe(stage, args.chunks) 71 | 72 | # Run 73 | if args.rank == 0: 74 | schedule.step(input_ids) 75 | else: 76 | out = schedule.step() 77 | 78 | dist.barrier() 79 | dist.destroy_process_group() 80 | print(f"Rank {args.rank} completes") 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 86 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 87 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 88 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 89 | parser.add_argument('--schedule', type=str, default="FillDrain") 90 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 91 | parser.add_argument("--chunks", type=int, default=4) 92 | parser.add_argument('--batch_size', type=int, default=32) 
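# The schedule later splits each batch into `chunks` microbatches, so the
# defaults above (batch_size=32, chunks=4) give 8 samples per microbatch;
# these examples assume batch_size is divisible by chunks.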
93 | parser.add_argument('--batches', type=int, default=1) 94 | 95 | args = parser.parse_args() 96 | 97 | if args.cuda: 98 | dev_id = args.rank % torch.cuda.device_count() 99 | args.device = torch.device(f"cuda:{dev_id}") 100 | else: 101 | args.device = torch.device("cpu") 102 | 103 | # Init process group 104 | backend = "nccl" if args.cuda else "gloo" 105 | dist.init_process_group( 106 | backend=backend, 107 | rank=args.rank, 108 | world_size=args.world_size, 109 | ) 110 | 111 | run(args) 112 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_deberta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_deberta.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import DebertaModel, DebertaConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = DebertaConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = DebertaModel 25 | model_name = "DebertaModel" 26 | deberta = model_class(config) 27 | deberta.to(args.device) 28 | deberta.eval() 29 | if args.rank == 0: 30 | print(deberta.config) 31 | print(f"Total number of params = {get_number_of_params(deberta) // 10 ** 6}M") 32 | print(deberta) 33 | 34 | # Example microbatch inputs 35 | mb_inputs = generate_inputs_for_model( 36 | model_class, deberta, model_name, args.batch_size // args.chunks, args.device) 37 | 38 | # Split points 39 | layers_per_rank = deberta.config.num_hidden_layers // args.world_size 40 | split_spec = { 41 | f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 42 | for i in range(1, args.world_size) 43 | } 44 | 45 | # Create pipeline 46 | pipe = pipeline( 47 | deberta, 48 | mb_args=(), 49 | mb_kwargs=mb_inputs, 50 | split_spec=split_spec, 51 | ) 52 | 53 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 54 | smod = pipe.get_stage_module(args.rank) 55 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 56 | 57 | # Create schedule runtime 58 | stage = pipe.build_stage( 59 | args.rank, 60 | device=args.device, 61 | ) 62 | 63 | # Attach to a schedule 64 | schedule = ScheduleGPipe(stage, args.chunks) 65 | 66 | # Full batch inputs as in single-worker case 67 | inputs = generate_inputs_for_model( 68 | model_class, deberta, model_name, args.batch_size, args.device) 69 | 70 | # Run 71 | if args.rank == 0: 72 | schedule.step(**inputs) 73 | else: 74 | out = schedule.step() 75 | 76 | dist.barrier() 77 | dist.barrier() 78 | dist.destroy_process_group() 79 | print(f"Rank {args.rank} completes") 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 85 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 86 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 87 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 88 | parser.add_argument('--schedule', type=str, default="FillDrain") 89 | 
parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 90 | parser.add_argument("--chunks", type=int, default=4) 91 | parser.add_argument('--batch_size', type=int, default=32) 92 | parser.add_argument('--batches', type=int, default=1) 93 | 94 | args = parser.parse_args() 95 | 96 | if args.cuda: 97 | dev_id = args.rank % torch.cuda.device_count() 98 | args.device = torch.device(f"cuda:{dev_id}") 99 | else: 100 | args.device = torch.device("cpu") 101 | 102 | # Init process group 103 | backend = "nccl" if args.cuda else "gloo" 104 | dist.init_process_group( 105 | backend=backend, 106 | rank=args.rank, 107 | world_size=args.world_size, 108 | ) 109 | 110 | run(args) 111 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_debertaV2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_debertaV2.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import DebertaV2Model, DebertaConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def add_split_points(deberta, nranks): 19 | layers_per_rank = deberta.config.num_hidden_layers // nranks 20 | for i in range(1, nranks): 21 | annotate_split_points( 22 | deberta, {f"deberta.encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING}) 23 | 24 | 25 | def run(args): 26 | # Model configs 27 | config = DebertaConfig() 28 | print("Using device:", args.device) 29 | 30 | # Create model 31 | model_class = DebertaV2Model 32 | model_name = "DebertaV2Model" 33 | deberta = model_class(config) 34 | deberta.to(args.device) 35 | deberta.eval() 36 | if args.rank == 0: 37 | print(deberta.config) 38 | print(f"Total number of params = {get_number_of_params(deberta) // 10 ** 6}M") 39 | print(deberta) 40 | 41 | # Input configs 42 | example_inputs = generate_inputs_for_model( 43 | model_class, deberta, model_name, args.batch_size, args.device) 44 | input_ids = example_inputs["input_ids"] 45 | 46 | # Split points 47 | layers_per_rank = deberta.config.num_hidden_layers // args.world_size 48 | split_spec = { 49 | f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 50 | for i in range(1, args.world_size) 51 | } 52 | 53 | # Create pipeline 54 | pipe = pipeline( 55 | deberta, 56 | num_chunks=args.chunks, 57 | example_args=(input_ids, ), 58 | split_spec=split_spec, 59 | ) 60 | 61 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 62 | smod = pipe.get_stage_module(args.rank) 63 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 64 | 65 | # Create schedule runtime 66 | stage = PipelineStage( 67 | pipe, 68 | args.rank, 69 | device=args.device, 70 | ) 71 | 72 | # Attach to a schedule 73 | schedule = ScheduleGPipe(stage, args.chunks) 74 | 75 | # Run 76 | if args.rank == 0: 77 | schedule.step(input_ids) 78 | else: 79 | out = schedule.step() 80 | 81 | dist.barrier() 82 | dist.destroy_process_group() 83 | print(f"Rank {args.rank} completes") 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 89 | 
parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 90 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 91 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 92 | parser.add_argument('--schedule', type=str, default="FillDrain") 93 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 94 | parser.add_argument("--chunks", type=int, default=4) 95 | parser.add_argument('--batch_size', type=int, default=8) 96 | parser.add_argument('--batches', type=int, default=1) 97 | 98 | args = parser.parse_args() 99 | 100 | if args.cuda: 101 | dev_id = args.rank % torch.cuda.device_count() 102 | args.device = torch.device(f"cuda:{dev_id}") 103 | else: 104 | args.device = torch.device("cpu") 105 | 106 | # Init process group 107 | backend = "nccl" if args.cuda else "gloo" 108 | dist.init_process_group( 109 | backend=backend, 110 | rank=args.rank, 111 | world_size=args.world_size, 112 | ) 113 | 114 | run(args) 115 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_distilBert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_distilBert.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import DistilBertForMaskedLM, DistilBertConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = DistilBertConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = DistilBertForMaskedLM 25 | model_name = "DistilBertForMaskedLM" 26 | distilbert = model_class(config) 27 | distilbert.to(args.device) 28 | distilbert.eval() 29 | if args.rank == 0: 30 | print(distilbert.config) 31 | print(f"Total number of params = {get_number_of_params(distilbert) // 10 ** 6}M") 32 | print(distilbert) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, distilbert, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | split_spec = {} 41 | # The first rank carries the embedding layer 42 | split_spec[f"distilbert.embeddings"] = SplitPoint.END 43 | # 6 Transformer layers divided over the rest 3 ranks 44 | layers_per_rank = distilbert.config.num_hidden_layers // (args.world_size - 1) 45 | for i in range(1, args.world_size - 1): 46 | split_spec[f"distilbert.transformer.layer.{i * layers_per_rank}"] = SplitPoint.BEGINNING 47 | 48 | # Create pipeline 49 | pipe = pipeline( 50 | distilbert, 51 | num_chunks=args.chunks, 52 | example_args=(input_ids, ), 53 | split_spec=split_spec, 54 | ) 55 | 56 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 57 | smod = pipe.get_stage_module(args.rank) 58 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 59 | 60 | # Create schedule runtime 61 | stage = PipelineStage( 62 | pipe, 63 | args.rank, 64 | device=args.device, 65 | ) 66 | 67 | # Attach to a schedule 68 | schedule = ScheduleGPipe(stage, args.chunks) 69 | 70 | # Run 71 | if args.rank == 0: 72 | 
schedule.step(input_ids) 73 | else: 74 | out = schedule.step() 75 | 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=256) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_electra.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_electra.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import ElectraForCausalLM, ElectraConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = ElectraConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = ElectraForCausalLM 25 | model_name = "ElectraForCausalLM" 26 | electra = model_class(config) 27 | electra.to(args.device) 28 | electra.eval() 29 | if args.rank == 0: 30 | print(electra.config) 31 | print(f"Total number of params = {get_number_of_params(electra) // 10 ** 6}M") 32 | print(electra) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, electra, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | layers_per_rank = electra.config.num_hidden_layers // args.world_size 41 | split_spec = { 42 | f"electra.encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | electra, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 
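# Only the first stage feeds real data into the schedule; intermediate stages
# exchange activations with their neighbors, and schedule.step() returns the
# model output only on the last stage (None elsewhere).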
69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=64) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_fnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_fnet.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import FNetModel, FNetConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = FNetConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = FNetModel 25 | model_name = "FNetModel" 26 | fnet = model_class(config) 27 | fnet.to(args.device) 28 | fnet.eval() 29 | if args.rank == 0: 30 | print(fnet.config) 31 | print(f"Total number of params = {get_number_of_params(fnet) // 10 ** 6}M") 32 | print(fnet) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, fnet, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | layers_per_rank = fnet.config.num_hidden_layers // args.world_size 41 | split_spec = { 42 | f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | fnet, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 
| # Run 69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=32) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_gpt2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_gpt2.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import GPT2ForSequenceClassification, GPT2Config 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = GPT2Config() 21 | config.n_embd = args.n_embd or config.n_embd 22 | config.n_layer = args.n_layer or config.n_layer 23 | config.n_head = args.n_head or config.n_head 24 | print("[Rank {}] Using device: {}".format(args.rank, args.device)) 25 | 26 | # Create model 27 | model_class = GPT2ForSequenceClassification 28 | model_name = "GPT2ForSequenceClassification" 29 | gpt2 = model_class(config) 30 | gpt2.to(args.device) 31 | gpt2.eval() 32 | if args.rank == 0: 33 | print(gpt2.config) 34 | print(f"GPT-2 total number of params = {get_number_of_params(gpt2) // 10 ** 6}M") 35 | print(gpt2) 36 | 37 | # Example microbatch inputs 38 | mb_inputs = generate_inputs_for_model( 39 | model_class, gpt2, model_name, args.batch_size // args.chunks, args.device) 40 | 41 | # Pipeline split spec 42 | decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size 43 | print(f"decoders_per_rank = {decoders_per_rank}") 44 | split_spec = { 45 | f'transformer.h.{i * decoders_per_rank}': SplitPoint.BEGINNING 46 | for i in range(1, args.world_size) 47 | } 48 | 49 | # Create pipeline representation 50 | pipe = pipeline( 51 | gpt2, 52 | mb_args=(), 53 | mb_kwargs=mb_inputs, 54 | split_spec=split_spec, 55 | ) 56 | 57 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 58 | smod = pipe.get_stage_module(args.rank) 59 | 
print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 60 | 61 | # Create schedule runtime 62 | stage = pipe.build_stage( 63 | args.rank, 64 | device=args.device, 65 | ) 66 | 67 | # Attach to a schedule 68 | schedule = ScheduleGPipe(stage, args.chunks) 69 | 70 | # Full batch inputs as in single-worker case 71 | inputs = generate_inputs_for_model( 72 | model_class, gpt2, model_name, args.batch_size, args.device) 73 | 74 | # Run 75 | if args.rank == 0: 76 | schedule.step(**inputs) 77 | else: 78 | out = schedule.step() 79 | 80 | dist.barrier() 81 | dist.destroy_process_group() 82 | print(f"Rank {args.rank} completes") 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 88 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 89 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 90 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 91 | parser.add_argument('--schedule', type=str, default="FillDrain") 92 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 93 | parser.add_argument("--chunks", type=int, default=4) 94 | # Note: this specific example requires: 1) a batch size that is divisible by 95 | # the number of chunks; 2) the division result (i.e. chunk size) must be 1, 96 | # otherwise padding token must be provided too (see GPT-2's forward function) 97 | parser.add_argument('--batch_size', type=int, default=4) 98 | parser.add_argument('--batches', type=int, default=1) 99 | parser.add_argument('--n_embd', type=int, default=None) 100 | parser.add_argument('--n_layer', type=int, default=None) 101 | parser.add_argument('--n_head', type=int, default=None) 102 | 103 | args = parser.parse_args() 104 | 105 | if args.cuda: 106 | dev_id = args.rank % torch.cuda.device_count() 107 | args.device = torch.device(f"cuda:{dev_id}") 108 | else: 109 | args.device = torch.device("cpu") 110 | 111 | # Init process group 112 | backend = "nccl" if args.cuda else "gloo" 113 | dist.init_process_group( 114 | backend=backend, 115 | rank=args.rank, 116 | world_size=args.world_size, 117 | ) 118 | 119 | run(args) 120 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_gptNeo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_gptNeo.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import GPTNeoForCausalLM, GPTNeoConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = GPTNeoConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = GPTNeoForCausalLM 25 | model_name = "GPTNeoForCausalLM" 26 | gptneo = model_class(config) 27 | gptneo.to(args.device) 28 | gptneo.eval() 29 | if args.rank == 0: 30 | print(gptneo.config) 31 | print(f"Total number of params = {get_number_of_params(gptneo) // 10 ** 6}M") 32 | print(gptneo) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, gptneo, model_name, args.batch_size, args.device) 37 | 38 | # Annotate split points 39 | layers_per_rank = (gptneo.config.num_layers + args.world_size - 1) // args.world_size 40 | print(f"decoders_per_rank = {layers_per_rank}") 41 | split_spec = { 42 | f'transformer.h.{i * layers_per_rank}': SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | gptneo, 49 | num_chunks=args.chunks, 50 | example_args=(), 51 | example_kwargs=example_inputs, 52 | split_spec=split_spec, 53 | ) 54 | 55 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 56 | smod = pipe.get_stage_module(args.rank) 57 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 58 | 59 | # Create schedule runtime 60 | stage = PipelineStage( 61 | pipe, 62 | args.rank, 63 | device=args.device, 64 | ) 65 | 66 | # Attach to a schedule 67 | schedule = ScheduleGPipe(stage, args.chunks) 68 | 69 | # Run 70 | if args.rank == 0: 71 | schedule.step(**example_inputs) 72 | else: 73 | out = schedule.step() 74 | 75 | dist.barrier() 76 | dist.destroy_process_group() 77 | print(f"Rank {args.rank} completes") 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 83 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 84 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 85 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 86 | parser.add_argument('--schedule', type=str, default="FillDrain") 87 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 88 | parser.add_argument("--chunks", type=int, default=4) 89 | parser.add_argument('--batch_size', type=int, default=32) 90 | parser.add_argument('--batches', type=int, default=1) 91 | 92 | args = parser.parse_args() 93 | 94 | if args.cuda: 95 | dev_id = args.rank % torch.cuda.device_count() 96 | args.device = torch.device(f"cuda:{dev_id}") 97 | else: 98 | args.device = torch.device("cpu") 99 | 100 | # Init process group 101 | backend = "nccl" if args.cuda else "gloo" 102 | dist.init_process_group( 103 | backend=backend, 104 | rank=args.rank, 105 | world_size=args.world_size, 106 | ) 107 | 108 | run(args) 109 | -------------------------------------------------------------------------------- 
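All of the Hugging Face examples above build their split points the same way: a `SplitPoint.BEGINNING` marker is placed at every `layers_per_rank`-th transformer block so that the model splits into `world_size` stages. A minimal sketch of that pattern as a standalone helper is shown below; the helper name `even_split_spec` and its `prefix` argument are illustrative, not part of the repo.

from torch.distributed.pipelining import SplitPoint

def even_split_spec(num_layers: int, world_size: int, prefix: str) -> dict:
    # Cut a stack of `num_layers` blocks into `world_size` equal groups by
    # starting a new stage at every boundary layer.
    layers_per_rank = num_layers // world_size
    return {
        f"{prefix}.{i * layers_per_rank}": SplitPoint.BEGINNING
        for i in range(1, world_size)
    }

# e.g. 12 encoder layers over 4 ranks -> cuts before layers 3, 6 and 9:
# even_split_spec(12, 4, "encoder.layer")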
/examples/huggingface/pippy_layoutLM.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_layoutLM.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import LayoutLMModel, LayoutLMConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = LayoutLMConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = LayoutLMModel 25 | model_name = "LayoutLMModel" 26 | layoutlm = model_class(config) 27 | layoutlm.to(args.device) 28 | layoutlm.eval() 29 | if args.rank == 0: 30 | print(layoutlm.config) 31 | print(f"Total number of params = {get_number_of_params(layoutlm) // 10 ** 6}M") 32 | print(layoutlm) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, layoutlm, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | split_spec = {} 41 | # First stage carries the embedding layer 42 | split_spec["embeddings"] = SplitPoint.END 43 | # 12 Transformer layers divided over the rest 3 ranks 44 | layers_per_rank = layoutlm.config.num_hidden_layers // (args.world_size - 1) 45 | for i in range(1, args.world_size - 1): 46 | split_spec[f"encoder.layer.{i * layers_per_rank}"] = SplitPoint.BEGINNING 47 | 48 | # Create pipeline 49 | pipe = pipeline( 50 | layoutlm, 51 | num_chunks=args.chunks, 52 | example_args=(input_ids, ), 53 | split_spec=split_spec, 54 | ) 55 | 56 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 57 | smod = pipe.get_stage_module(args.rank) 58 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 59 | 60 | # Create schedule runtime 61 | stage = PipelineStage( 62 | pipe, 63 | args.rank, 64 | device=args.device, 65 | ) 66 | 67 | # Attach to a schedule 68 | schedule = ScheduleGPipe(stage, args.chunks) 69 | 70 | # Run 71 | if args.rank == 0: 72 | schedule.step(input_ids) 73 | else: 74 | out = schedule.step() 75 | 76 | dist.barrier() 77 | dist.destroy_process_group() 78 | print(f"Rank {args.rank} completes") 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 84 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 85 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 86 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 87 | parser.add_argument('--schedule', type=str, default="FillDrain") 88 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 89 | parser.add_argument("--chunks", type=int, default=4) 90 | parser.add_argument('--batch_size', type=int, default=32) 91 | parser.add_argument('--batches', type=int, default=1) 92 | 93 | args = parser.parse_args() 94 | 95 | if args.cuda: 96 | dev_id = args.rank % torch.cuda.device_count() 97 | args.device = torch.device(f"cuda:{dev_id}") 98 | else: 99 | args.device = torch.device("cpu") 100 | 101 | # Init process group 102 | backend = "nccl" if 
args.cuda else "gloo" 103 | dist.init_process_group( 104 | backend=backend, 105 | rank=args.rank, 106 | world_size=args.world_size, 107 | ) 108 | 109 | run(args) 110 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_mbart.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_mbart.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import MBartForCausalLM, MBartConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = MBartConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = MBartForCausalLM 25 | model_name = "MBartForCausalLM" 26 | mbart = model_class(config) 27 | mbart.to(args.device) 28 | mbart.eval() 29 | if args.rank == 0: 30 | print(mbart.config) 31 | print(f"Total number of params = {get_number_of_params(mbart) // 10 ** 6}M") 32 | print(mbart) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, mbart, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | layers_per_rank = mbart.config.num_hidden_layers // args.world_size 41 | split_spec = { 42 | f"model.decoder.layers.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | mbart, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=4) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | 
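# torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every
# worker, so the argparse defaults above resolve to the launcher's settings and
# dist.init_process_group() picks the rendezvous address up from the environment.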
# Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_megatronBert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_megatronBert.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import MegatronBertModel, MegatronBertConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = MegatronBertConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = MegatronBertModel 25 | model_name = "MegatronBertModel" 26 | bert = model_class(config) 27 | bert.to(args.device) 28 | bert.eval() 29 | if args.rank == 0: 30 | print(bert.config) 31 | print(f"Total number of params = {get_number_of_params(bert) // 10 ** 6}M") 32 | print(bert) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, bert, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Annotate split points 40 | layers_per_rank = bert.config.num_hidden_layers // args.world_size 41 | split_spec = { 42 | f"encoder.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | bert, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=16) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = 
torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_mobileBert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_mobileBert.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import MobileBertModel, MobileBertConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = MobileBertConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = MobileBertModel 25 | model_name = "MobileBertModel" 26 | mobilebert = model_class(config) 27 | mobilebert.to(args.device) 28 | mobilebert.eval() 29 | if args.rank == 0: 30 | print(mobilebert.config) 31 | print(f"Total number of params = {get_number_of_params(mobilebert) // 10 ** 6}M") 32 | print(mobilebert) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, mobilebert, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | split_spec = {} 41 | layers_per_rank = mobilebert.config.num_hidden_layers // args.world_size 42 | for i in range(1, args.world_size): 43 | split_spec[f"encoder.layer.{i * layers_per_rank}"] = SplitPoint.BEGINNING 44 | 45 | # Create pipeline 46 | pipe = pipeline( 47 | mobilebert, 48 | num_chunks=args.chunks, 49 | example_args=(input_ids, ), 50 | split_spec=split_spec, 51 | ) 52 | 53 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 54 | smod = pipe.get_stage_module(args.rank) 55 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 56 | 57 | # Create schedule runtime 58 | stage = PipelineStage( 59 | pipe, 60 | args.rank, 61 | device=args.device, 62 | ) 63 | 64 | # Attach to a schedule 65 | schedule = ScheduleGPipe(stage, args.chunks) 66 | 67 | # Run 68 | if args.rank == 0: 69 | schedule.step(input_ids) 70 | else: 71 | out = schedule.step() 72 | 73 | dist.barrier() 74 | dist.destroy_process_group() 75 | print(f"Rank {args.rank} completes") 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 81 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 82 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 83 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 84 | parser.add_argument('--schedule', type=str, default="FillDrain") 85 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 86 | parser.add_argument("--chunks", type=int, default=4) 87 | parser.add_argument('--batch_size', type=int, default=256) 88 | parser.add_argument('--batches', type=int, default=1) 89 | 90 | args = 
parser.parse_args() 91 | 92 | if args.cuda: 93 | dev_id = args.rank % torch.cuda.device_count() 94 | args.device = torch.device(f"cuda:{dev_id}") 95 | else: 96 | args.device = torch.device("cpu") 97 | 98 | # Init process group 99 | backend = "nccl" if args.cuda else "gloo" 100 | dist.init_process_group( 101 | backend=backend, 102 | rank=args.rank, 103 | world_size=args.world_size, 104 | ) 105 | 106 | run(args) 107 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_opt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 2 pippy_opt.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import OPTForCausalLM, OPTConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = OPTConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = OPTForCausalLM 25 | model_name = "OPTForCausalLM" 26 | opt = model_class(config) 27 | opt.to(args.device) 28 | opt.eval() 29 | if args.rank == 0: 30 | print(opt.config) 31 | print(f"Total number of params = {get_number_of_params(opt) // 10 ** 6}M") 32 | print(opt) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, opt, model_name, args.batch_size, args.device) 37 | 38 | # Split points 39 | layers_per_rank = opt.config.num_hidden_layers // args.world_size 40 | split_spec = { 41 | f"model.decoder.layers.{i * layers_per_rank}": SplitPoint.BEGINNING 42 | for i in range(1, args.world_size) 43 | } 44 | 45 | # Create pipeline 46 | pipe = pipeline( 47 | opt, 48 | num_chunks=args.chunks, 49 | example_args=(), 50 | example_kwargs=example_inputs, 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(**example_inputs) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=4) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args 
= parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_trOCR.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_trOCR.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import TrOCRForCausalLM, TrOCRConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = TrOCRConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = TrOCRForCausalLM 25 | model_name = "TrOCRForCausalLM" 26 | trocr = model_class(config) 27 | trocr.to(args.device) 28 | trocr.eval() 29 | if args.rank == 0: 30 | print(trocr.config) 31 | print(f"Total number of params = {get_number_of_params(trocr) // 10 ** 6}M") 32 | print(trocr) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, trocr, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | layers_per_rank = trocr.config.num_hidden_layers // args.world_size 41 | split_spec = { 42 | f"model.decoder.layers.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | trocr, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=64) 89 | 
parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 2 pippy_unet.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from diffusers import UNet2DModel 14 | 15 | from hf_utils import get_number_of_params 16 | 17 | 18 | def run(args): 19 | print("Using device:", args.device) 20 | 21 | # Create model 22 | # See https://github.com/huggingface/diffusers?tab=readme-ov-file#quickstart 23 | unet = UNet2DModel.from_pretrained("google/ddpm-cat-256") 24 | unet.to(args.device) 25 | unet.eval() 26 | if args.rank == 0: 27 | print(f"Total number of params = {get_number_of_params(unet) // 10 ** 6}M") 28 | print(unet) 29 | 30 | # Input configs 31 | sample_size = unet.config.sample_size 32 | noise = torch.randn((args.batch_size, 3, sample_size, sample_size), device=args.device) 33 | timestep = 1 34 | 35 | # Split model into two stages: 36 | # Stage 0: down_blocks + mid_block 37 | # Stage 1: up_blocks 38 | split_spec = {"mid_block": SplitPoint.END} 39 | 40 | # Create pipeline 41 | pipe = pipeline( 42 | unet, 43 | num_chunks=args.chunks, 44 | example_args=(noise, timestep), 45 | split_spec=split_spec, 46 | ) 47 | 48 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 49 | smod = pipe.get_stage_module(args.rank) 50 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 51 | 52 | # Create schedule runtime 53 | stage = PipelineStage( 54 | pipe, 55 | args.rank, 56 | device=args.device, 57 | ) 58 | 59 | # Attach to a schedule 60 | schedule = ScheduleGPipe(stage, args.chunks) 61 | 62 | # Run 63 | if args.rank == 0: 64 | schedule.step(noise) 65 | else: 66 | out = schedule.step() 67 | 68 | dist.barrier() 69 | dist.destroy_process_group() 70 | print(f"Rank {args.rank} completes") 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 2))) 75 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 76 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 77 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 78 | parser.add_argument('--schedule', type=str, default="FillDrain") 79 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 80 | parser.add_argument("--chunks", type=int, default=2) 81 | parser.add_argument('--batch_size', type=int, default=2) 82 | parser.add_argument('--batches', type=int, default=1) 83 | 84 | args = parser.parse_args() 85 | 86 | if args.cuda: 87 | dev_id =
args.rank % torch.cuda.device_count() 88 | args.device = torch.device(f"cuda:{dev_id}") 89 | else: 90 | args.device = torch.device("cpu") 91 | 92 | # Init process group 93 | backend = "nccl" if args.cuda else "gloo" 94 | dist.init_process_group( 95 | backend=backend, 96 | rank=args.rank, 97 | world_size=args.world_size, 98 | ) 99 | 100 | run(args) 101 | -------------------------------------------------------------------------------- /examples/huggingface/pippy_xlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # $ torchrun --nproc-per-node 4 pippy_xlnet.py 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe, SplitPoint 12 | 13 | from transformers import XLNetLMHeadModel, XLNetConfig 14 | 15 | from hf_utils import generate_inputs_for_model, get_number_of_params 16 | 17 | 18 | def run(args): 19 | # Model configs 20 | config = XLNetConfig() 21 | print("Using device:", args.device) 22 | 23 | # Create model 24 | model_class = XLNetLMHeadModel 25 | model_name = "XLNetLMHeadModel" 26 | xlnet = model_class(config) 27 | xlnet.to(args.device) 28 | xlnet.eval() 29 | if args.rank == 0: 30 | print(xlnet.config) 31 | print(f"Total number of params = {get_number_of_params(xlnet) // 10 ** 6}M") 32 | print(xlnet) 33 | 34 | # Input configs 35 | example_inputs = generate_inputs_for_model( 36 | model_class, xlnet, model_name, args.batch_size, args.device) 37 | input_ids = example_inputs["input_ids"] 38 | 39 | # Split points 40 | layers_per_rank = xlnet.config.num_hidden_layers // args.world_size 41 | split_spec = { 42 | f"transformer.layer.{i * layers_per_rank}": SplitPoint.BEGINNING 43 | for i in range(1, args.world_size) 44 | } 45 | 46 | # Create pipeline 47 | pipe = pipeline( 48 | xlnet, 49 | num_chunks=args.chunks, 50 | example_args=(input_ids, ), 51 | split_spec=split_spec, 52 | ) 53 | 54 | assert pipe.num_stages == args.world_size, f"nstages = {pipe.num_stages} nranks = {args.world_size}" 55 | smod = pipe.get_stage_module(args.rank) 56 | print(f"Pipeline stage {args.rank} {get_number_of_params(smod) // 10 ** 6}M params") 57 | 58 | # Create schedule runtime 59 | stage = PipelineStage( 60 | pipe, 61 | args.rank, 62 | device=args.device, 63 | ) 64 | 65 | # Attach to a schedule 66 | schedule = ScheduleGPipe(stage, args.chunks) 67 | 68 | # Run 69 | if args.rank == 0: 70 | schedule.step(input_ids) 71 | else: 72 | out = schedule.step() 73 | 74 | dist.barrier() 75 | dist.destroy_process_group() 76 | print(f"Rank {args.rank} completes") 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) 82 | parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) 83 | parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) 84 | parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) 85 | parser.add_argument('--schedule', type=str, default="FillDrain") 86 | parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) 87 | parser.add_argument("--chunks", type=int, default=4) 88 | parser.add_argument('--batch_size', type=int, default=16) 89 | parser.add_argument('--batches', type=int, default=1) 90 | 91 | args = parser.parse_args() 
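# One process per GPU is assumed (torchrun --nproc-per-node <num_gpus>), so
# rank % device_count() below maps each rank to a distinct local device.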
92 | 93 | if args.cuda: 94 | dev_id = args.rank % torch.cuda.device_count() 95 | args.device = torch.device(f"cuda:{dev_id}") 96 | else: 97 | args.device = torch.device("cpu") 98 | 99 | # Init process group 100 | backend = "nccl" if args.cuda else "gloo" 101 | dist.init_process_group( 102 | backend=backend, 103 | rank=args.rank, 104 | world_size=args.world_size, 105 | ) 106 | 107 | run(args) 108 | -------------------------------------------------------------------------------- /examples/llama/README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | $ torchrun --nproc-per-node 2 pippy_llama.py 3 | ``` 4 | ``` 5 | $ torchrun --nproc-per-node 4 pippy_llama.py 6 | ``` 7 | ``` 8 | $ torchrun --nproc-per-node 8 pippy_llama.py 9 | ``` 10 | ``` 11 | prompts = ( 12 | "How do you", "I like to", "Can I help", "You need to", 13 | "The weather is", "I found a", "What is your", "You are so", 14 | ) 15 | Outputs: 16 | ['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right'] 17 | ``` 18 | -------------------------------------------------------------------------------- /examples/llama/pippy_llama.py: -------------------------------------------------------------------------------- 1 | # $ torchrun --nproc-per-node 4 pippy_llama.py 2 | import os 3 | import torch 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from torch.distributed.pipelining import SplitPoint, pipeline, ScheduleGPipe 6 | 7 | # Grab the model 8 | llama = AutoModelForCausalLM.from_pretrained( 9 | "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True 10 | ) 11 | print(llama) 12 | 13 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") 14 | tokenizer.pad_token = tokenizer.eos_token 15 | mb_prompts = ( 16 | "How do you", "I like to", 17 | ) # microbatch size = 2 18 | 19 | rank = int(os.environ["RANK"]) 20 | world_size = int(os.environ["WORLD_SIZE"]) 21 | device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") 22 | torch.distributed.init_process_group(rank=rank, world_size=world_size) 23 | 24 | llama.to(device).eval() 25 | 26 | # Cut model by equal number of layers per rank 27 | layers_per_rank = llama.config.num_hidden_layers // world_size 28 | print(f"layers_per_rank = {layers_per_rank}") 29 | split_spec = { 30 | f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING 31 | for i in range(1, world_size) 32 | } 33 | 34 | # Create a pipeline representation from the model 35 | mb_inputs = tokenizer(mb_prompts, return_tensors="pt", padding=True).to(device) 36 | pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],)) 37 | 38 | # Create pipeline stage for each rank 39 | stage = pipe.build_stage(rank, device=device) 40 | 41 | # Run time inputs 42 | full_batch_prompts = ( 43 | "How do you", "I like to", "Can I help", "You need to", 44 | "The weather is", "I found a", "What is your", "You are so", 45 | ) # full batch size = 8 46 | inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True).to(device) 47 | 48 | # Attach to a schedule 49 | # number of microbatches = 8 // 2 = 4 50 | num_mbs = 4 51 | schedule = ScheduleGPipe(stage, num_mbs) 52 | 53 | # Run 54 | if rank == 0: 55 | args = inputs["input_ids"] 56 | else: 57 | args = None 58 | 59 | output = schedule.step(args) 60 | 61 | # Decode 62 | if output is not None: 63 | next_token_logits = output[0][:, -1, :] 64 | next_token = torch.argmax(next_token_logits, dim=-1) 65 | print(tokenizer.batch_decode(next_token)) 66 | 
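A note on the split arithmetic used in `pippy_llama.py` above: placing a `SplitPoint.BEGINNING` at every `layers_per_rank`-th decoder layer gives each rank a contiguous block of layers, with the leading embedding lookup staying on the first stage and the trailing norm/LM head on the last. The sketch below is illustrative only; it assumes 32 hidden layers (the Llama-2-7B configuration) and a 4-rank run, and it uses a plain string in place of the `SplitPoint` enum so it runs without `torch` installed.

```python
# Illustrative sketch (not part of the repo): how the split_spec in
# pippy_llama.py maps decoder layers to pipeline stages.
# Assumptions for this example: 32 hidden layers (Llama-2-7B) and 4 ranks;
# the script reads both from llama.config and WORLD_SIZE at runtime.
num_hidden_layers = 32
world_size = 4

layers_per_rank = num_hidden_layers // world_size  # 8

# Same dict comprehension as the script, with a placeholder string standing
# in for SplitPoint.BEGINNING so this snippet runs without torch.
split_spec = {
    f"model.layers.{i * layers_per_rank}": "SplitPoint.BEGINNING"
    for i in range(1, world_size)
}
print(split_spec)
# {'model.layers.8': ..., 'model.layers.16': ..., 'model.layers.24': ...}

# Resulting per-stage layer ranges:
for stage in range(world_size):
    start = stage * layers_per_rank
    end = start + layers_per_rank - 1
    print(f"stage {stage}: layers {start}-{end}")
# stage 0: layers 0-7   (plus the embedding layer)
# stage 1: layers 8-15
# stage 2: layers 16-23
# stage 3: layers 24-31 (plus final norm and LM head)
```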
-------------------------------------------------------------------------------- /examples/mixture_of_experts/dist_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | 3 | # Minimum effort to run this example: 4 | # torchrun --nproc-per-node 5 dist_moe.py 5 | # You need use 5 ranks because there are 3 experts, one pre-processor and one gatherer. 6 | 7 | """ 8 | pre-proc 9 | / | \ 10 | expert 0 expert 1 expert 2 11 | \ | / 12 | gatherer 13 | """ 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | from pippy import annotate_split_points, pipeline, PipelineStage, SplitPoint 19 | from pippy.PipelineSchedule import ScheduleGPipe 20 | 21 | 22 | d_hid = 16 23 | n_experts = 3 24 | batch_size = 4 25 | 26 | torch.manual_seed(0) 27 | 28 | # Each expert is a MLP 29 | class ExpertLayer(torch.nn.Module): 30 | def __init__(self, d_hid) -> None: 31 | super(ExpertLayer, self).__init__() 32 | self.net1 = torch.nn.Linear(d_hid, d_hid) 33 | self.relu = torch.nn.ReLU() 34 | self.net2 = torch.nn.Linear(d_hid, d_hid) 35 | 36 | def forward(self, x) -> torch.Tensor: 37 | x = self.net1(x) 38 | x = self.relu(x) 39 | x = self.net2(x) 40 | return x 41 | 42 | # Full model comprising n experts 43 | class MoE(torch.nn.Module): 44 | def __init__(self, n_experts: int) -> None: 45 | super().__init__() 46 | self.pre_proc = torch.nn.Linear(d_hid, d_hid) 47 | self.experts = torch.nn.ModuleList( 48 | [ 49 | ExpertLayer(d_hid) 50 | for _ in range(n_experts) 51 | ] 52 | ) 53 | 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | x = self.pre_proc(x) 56 | outputs = [] 57 | for expert in self.experts: 58 | outputs.append(expert(x)) 59 | return torch.cat(outputs, dim=1) 60 | 61 | 62 | dist.init_process_group() 63 | rank = dist.get_rank() 64 | world_size = dist.get_world_size() 65 | device = torch.device(f"cuda:{rank}") 66 | 67 | model = MoE(n_experts) 68 | x = torch.randn(batch_size, d_hid) 69 | 70 | # Mark the split point for each expert 71 | annotate_split_points(model, {f"pre_proc": SplitPoint.END}) 72 | for i in range(n_experts): 73 | annotate_split_points( 74 | model, {f"experts.{i}": SplitPoint.END} 75 | ) 76 | 77 | pippy_model = pipeline(model, 1, (x,)) 78 | 79 | assert pippy_model.num_stages == world_size 80 | if rank == 0: 81 | print("Original model:\n", model) 82 | print("PiPPy model:") 83 | pippy_model.print_readable() 84 | 85 | # Check representation equivalence 86 | ref_out = model(x) 87 | pippy_out = pippy_model(x)[0] 88 | torch.testing.assert_close(pippy_out, ref_out) 89 | print(f"PiPPy model equivalent: {torch.sum(pippy_out)} ref {torch.sum(ref_out)}") 90 | 91 | # Create distributed runtime 92 | expert = PipelineStage(pippy_model, rank, device=device) 93 | 94 | # Attach to a schedule 95 | # Use a microbatch of 1, i.e. no pipelining 96 | schedule = ScheduleGPipe(expert, 1) 97 | if rank == 0: 98 | x = x.to(device) 99 | schedule.step(x) 100 | else: 101 | dist_out = schedule.step() 102 | 103 | # Check equivalence 104 | if rank == dist.get_world_size() - 1: 105 | print(f"Distributed model equivalent: {torch.sum(dist_out)} ref {torch.sum(ref_out)}") 106 | print(f"dist_out: {dist_out.shape}") 107 | print(f"ref_out: {ref_out.shape}") 108 | -------------------------------------------------------------------------------- /examples/profiling/mlp_profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | # Run command: 3 | # torchrun --nproc-per-node 4 mlp_profiling.py 4 | 5 | import argparse 6 | import os 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch.profiler import profile, ProfilerActivity 11 | 12 | from pippy.compile import compile_stage 13 | from pippy import pipe_split 14 | 15 | 16 | d_hid = 1024 17 | chunk_size = 1024 18 | 19 | torch.manual_seed(0) 20 | 21 | 22 | class MLPModule(torch.nn.Module): 23 | def __init__(self, d_hid): 24 | super(MLPModule, self).__init__() 25 | self.net1 = torch.nn.Linear(d_hid, d_hid) 26 | self.relu = torch.nn.ReLU() 27 | self.net2 = torch.nn.Linear(d_hid, d_hid) 28 | 29 | def forward(self, x): 30 | x = self.net1(x) 31 | x = self.relu(x) 32 | x = self.net2(x) 33 | return x 34 | 35 | 36 | class ExampleCode(torch.nn.Module): 37 | def __init__(self, d_hid): 38 | super().__init__() 39 | self.mlp0 = MLPModule(d_hid) 40 | self.mlp1 = MLPModule(d_hid) 41 | self.mlp2 = MLPModule(d_hid) 42 | self.mlp3 = MLPModule(d_hid) 43 | self.mse_loss = torch.nn.MSELoss(reduction="sum") 44 | 45 | def forward(self, x, target): 46 | x = self.mlp0(x) 47 | pipe_split() 48 | x = self.mlp1(x) 49 | pipe_split() 50 | x = self.mlp2(x) 51 | pipe_split() 52 | x = self.mlp3(x) 53 | loss = self.mse_loss(x, target) 54 | return {"logits": x, "loss": loss} 55 | 56 | 57 | def run_worker(args): 58 | ec = ExampleCode(d_hid) 59 | ec.to(args.device) 60 | ec.train() 61 | 62 | ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) 63 | target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) 64 | 65 | stage = compile_stage( 66 | ec, 67 | args.rank, 68 | args.world_size, 69 | args.chunks, 70 | args.device, 71 | None, 72 | [ec_x, target], 73 | ) 74 | 75 | # Run 76 | for _ in range(10): 77 | if args.rank == 0: 78 | out = stage(ec_x) 79 | elif args.rank == args.world_size - 1: 80 | out = stage(target) 81 | else: 82 | stage() 83 | 84 | dist.barrier() 85 | print(f"Rank {args.rank} warmup completes") 86 | 87 | with profile( 88 | activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 89 | ) as prof: 90 | for _ in range(20): 91 | if args.rank == 0: 92 | out = stage(ec_x) 93 | elif args.rank == args.world_size - 1: 94 | out = stage(target) 95 | else: 96 | stage() 97 | 98 | print(f"Rank {args.rank} profiling run completed") 99 | prof.export_chrome_trace( 100 | f"{os.path.splitext(os.path.basename(__file__))[0]}_{args.rank}.json" 101 | ) 102 | 103 | 104 | def main(args=None): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument( 107 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 108 | ) 109 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 110 | parser.add_argument( 111 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 112 | ) 113 | parser.add_argument( 114 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 115 | ) 116 | parser.add_argument( 117 | "--cuda", type=int, default=int(torch.cuda.is_available()) 118 | ) 119 | parser.add_argument( 120 | "--chunks", 121 | type=int, 122 | default=4, 123 | ) 124 | args = parser.parse_args(args) 125 | 126 | if args.cuda: 127 | dev_id = args.rank % torch.cuda.device_count() 128 | args.device = torch.device(f"cuda:{dev_id}") 129 | else: 130 | args.device = torch.device("cpu") 131 | 132 | # Init process group 133 | backend = "nccl" if args.cuda else "gloo" 134 | dist.init_process_group( 135 | backend=backend, 136 | rank=args.rank, 137 | world_size=args.world_size, 138 | ) 139 | 140 | 
run_worker(args) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /examples/tp+pp/pippy_tp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import torch 7 | 8 | import pippy 9 | import pippy.fx 10 | from pippy import pipe_split 11 | from pippy.compile import compile_stage 12 | 13 | import torch.distributed as dist 14 | from torch.distributed._tensor import ( 15 | DeviceMesh, 16 | ) 17 | from torch.distributed.tensor.parallel import ( 18 | PairwiseParallel, 19 | parallelize_module, 20 | ) 21 | 22 | 23 | pippy.fx.Tracer.proxy_buffer_attributes = True 24 | 25 | 26 | class MLPModule(torch.nn.Module): 27 | def __init__(self, d_hid): 28 | super(MLPModule, self).__init__() 29 | self.net1 = torch.nn.Linear(d_hid, d_hid) 30 | self.relu = torch.nn.ReLU() 31 | self.net2 = torch.nn.Linear(d_hid, d_hid) 32 | 33 | def forward(self, x): 34 | x = self.net1(x) 35 | x = self.relu(x) 36 | x = self.net2(x) 37 | return x 38 | 39 | 40 | class ExampleCode(torch.nn.Module): 41 | def __init__(self, d_hid): 42 | super().__init__() 43 | self.mlp0 = MLPModule(d_hid) 44 | self.mlp1 = MLPModule(d_hid) 45 | self.mlp2 = MLPModule(d_hid) 46 | self.mlp3 = MLPModule(d_hid) 47 | 48 | def forward(self, x): 49 | x = self.mlp0(x) 50 | pipe_split() 51 | x = self.mlp1(x) 52 | pipe_split() 53 | x = self.mlp2(x) 54 | pipe_split() 55 | x = self.mlp3(x) 56 | return x 57 | 58 | 59 | d_hid = 256 60 | batch_size_per_chunk = 8 61 | 62 | 63 | def run_all(args): 64 | # The seed here has two purposes: 65 | # - Ensure all TP ranks have same input 66 | # - Ensure the model (ec) created are the same, as if it comes from a 67 | # single, big model before partitioning 68 | torch.manual_seed(0) 69 | 70 | # Create original model 71 | ec = ExampleCode(d_hid) 72 | ec.to(args.device) 73 | 74 | # Create input 75 | inp_size = [args.chunks * batch_size_per_chunk, d_hid] 76 | device_type = args.device.type 77 | inp = torch.rand(*inp_size, device=args.device) 78 | 79 | # Create global DeviceMesh 80 | ranks = torch.arange(args.world_size) 81 | rank_mesh = ranks.reshape(args.pp_group_size, args.tp_group_size) 82 | pp_dim = 0 83 | tp_dim = 1 84 | dm = DeviceMesh( 85 | device_type, 86 | rank_mesh, 87 | ) 88 | 89 | # Figure out my PP and TP rank 90 | pp_rank = args.rank // args.tp_group_size 91 | tp_rank = args.rank % args.tp_group_size 92 | print(f"Global rank {args.rank}, pp rank: {pp_rank}, tp rank: {tp_rank}") 93 | 94 | # Get pp group 95 | # `tp_rank` can serve as pipeline id 96 | print(f"Rank {args.rank} Instantiating pipeline with ranks {dm.mesh[:, tp_rank]}") 97 | pp_group = dm.get_dim_groups()[pp_dim] 98 | 99 | # Get stage module (on all pp ranks) 100 | stage = compile_stage( 101 | ec, 102 | pp_rank, 103 | args.pp_group_size, 104 | args.chunks, 105 | args.device, 106 | pp_group, 107 | example_inputs=[inp], 108 | ) 109 | 110 | # Tensor parallelize submodules 111 | print(f"Rank {args.rank} TP-lize submodule with {dm.mesh[pp_rank]}") 112 | parallelize_module(stage.submod, dm, PairwiseParallel(), tp_mesh_dim = tp_dim) 113 | 114 | if pp_rank == 0: 115 | out = stage(inp) 116 | elif pp_rank == args.pp_group_size - 1: 117 | out = stage() 118 | else: 119 | stage() 120 | 121 | dist.barrier() 122 | print(f"Rank {args.rank} completes") 123 | 124 | # Last rank checks result 125 | if pp_rank == args.pp_group_size - 1: 
126 | ref_out = ec(inp) 127 | torch.testing.assert_close(out, ref_out) 128 | print( 129 | f"Pipeline {tp_rank} equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" 130 | ) 131 | 132 | 133 | def main(args=None): 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument( 136 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 8)) 137 | ) 138 | # ExampleCode has 4 stages 139 | parser.add_argument( 140 | "--pp_group_size", type=int, default=4, 141 | ) 142 | # in row-major 143 | # TP ranks are contiguous rows of size `args.tp_group_size` 144 | # PP ranks are non-contiguous columns of size `args.pp_group_size` 145 | # 146 | # if tp_group_size = 4 and pp_group_size = 3 147 | # 148 | # 0 1 2 3 149 | # 4 5 6 7 150 | # 8 9 10 11 151 | # 152 | # TP ranks are [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] 153 | # PP ranks are [0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11] 154 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 155 | parser.add_argument( 156 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 157 | ) 158 | parser.add_argument( 159 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 160 | ) 161 | parser.add_argument( 162 | "--cuda", type=int, default=int(torch.cuda.is_available()) 163 | ) 164 | parser.add_argument( 165 | "--chunks", type=int, default=4, 166 | ) 167 | args = parser.parse_args(args) 168 | 169 | # Use world size to determine TP group size 170 | assert args.world_size % args.pp_group_size == 0 171 | args.tp_group_size = args.world_size // args.pp_group_size 172 | if args.rank == 0: 173 | print( 174 | f"Pipeline parallel size: {args.pp_group_size}\n" 175 | f"Tensor parallel size: {args.tp_group_size}" 176 | ) 177 | 178 | if args.cuda: 179 | dev_id = args.rank % torch.cuda.device_count() 180 | args.device = torch.device(f"cuda:{dev_id}") 181 | # HACK: we need to pin device here because `DeviceMesh` currently does 182 | # an all_gather with device_type only, without device id 183 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/device_mesh.py#L191-L192 184 | torch.cuda.set_device(args.device) 185 | else: 186 | args.device = torch.device("cpu") 187 | 188 | # Init process group 189 | backend = "nccl" if args.cuda else "gloo" 190 | dist.init_process_group( 191 | backend=backend, 192 | rank=args.rank, 193 | world_size=args.world_size, 194 | ) 195 | 196 | run_all(args) 197 | 198 | 199 | if __name__ == "__main__": 200 | main() 201 | 202 | 203 | class LocalTestPiPPyTP(unittest.TestCase): 204 | def test_pp_tp(self): 205 | import random 206 | 207 | port = random.randint(29500, 30000) 208 | args = [ 209 | "--master_port", 210 | str(port), 211 | ] 212 | main(args) 213 | -------------------------------------------------------------------------------- /examples/unrolling/README.md: -------------------------------------------------------------------------------- 1 | ## What does this example do? 2 | 3 | This is a synthetic example used to demonstrate PiPPy's functionality in unrolling iterative blocks in a model. 
4 | 5 | We create a model that runs an iteration block in a for loop: 6 | ```python 7 | class IterationBlock(torch.nn.Module): 8 | def __init__(self, d_hid): 9 | super().__init__() 10 | self.lin = torch.nn.Linear(d_hid, d_hid) 11 | 12 | def forward(self, x): 13 | x = self.lin(x) 14 | x = torch.relu(x) 15 | return x 16 | 17 | 18 | class IterativeNetwork(torch.nn.Module): 19 | def __init__(self, d_hid, num_iters): 20 | super().__init__() 21 | self.num_iters = num_iters 22 | self.iter_block = IterationBlock(d_hid) 23 | # 10 output classes 24 | self.output_proj = torch.nn.Linear(d_hid, 10) 25 | 26 | def forward(self, x): 27 | for i in range(self.num_iters): 28 | x = self.iter_block(x) 29 | return self.output_proj(x) 30 | ``` 31 | 32 | If we annotate the model as follows, we will create a pipeline stage per 33 | iteration block: 34 | 35 | ```python 36 | # Add a split point after each iter_block 37 | annotate_split_points( 38 | model, 39 | {"iter_block": SplitPoint.END}, 40 | ) 41 | ``` 42 | 43 | That is, PiPPy would create a split point every time it sees "self.iter_block". 44 | 45 | Run it with 4 ranks: 46 | ``` 47 | $ torchrun --nproc-per-node 4 pippy_unroll.py 48 | ``` 49 | 50 | Print-out of the pipe: 51 | ``` 52 | ************************************* pipe ************************************* 53 | GraphModule( 54 | (submod_0): PipeStageModule( 55 | (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) 56 | ) 57 | (submod_1): PipeStageModule( 58 | (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) 59 | ) 60 | (submod_2): PipeStageModule( 61 | (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) 62 | ) 63 | (submod_3): PipeStageModule( 64 | (L__self___output_proj): Linear(in_features=512, out_features=10, bias=True) 65 | ) 66 | ) 67 | 68 | def forward(self, arg0): 69 | submod_0 = self.submod_0(arg0); arg0 = None 70 | submod_1 = self.submod_1(submod_0); submod_0 = None 71 | submod_2 = self.submod_2(submod_1); submod_1 = None 72 | submod_3 = self.submod_3(submod_2); submod_2 = None 73 | return [submod_3] 74 | ``` 75 | We can see 4 stages as expected (3 iterations plus 1 output projection). 76 | 77 | If we print one of the stages, we can see that it contains the code of one iteration: 78 | ``` 79 | *********************************** submod0 ************************************ 80 | PipeStageModule( 81 | (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) 82 | ) 83 | 84 | def forward(self, l_x_): 85 | l__self___iter_block_mod_lin = self.L__self___iter_block_mod_lin(l_x_); l_x_ = None 86 | relu = torch.relu(l__self___iter_block_mod_lin); l__self___iter_block_mod_lin = None 87 | return relu 88 | ``` 89 | 90 | ## How can this functionality help? 91 | Increase throughput of your model. 92 | 93 | Imagine your for loop needs to iterate on the data for `n` times, and it takes time `t` to process 1 sample (yielding a throughput of `1/t`). If we were to unroll the for loop onto `n` devices, then we can push `n` microbatches into the pipeline, each microbatch containing 1 sample. Then at any timeslot, the pipeline is processing `n` samples, yielding a throughput of `n/t`. 94 | -------------------------------------------------------------------------------- /examples/unrolling/pippy_unroll.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | # Minimal effort to run this code: 3 | # $ torchrun --nproc-per-node 4 pippy_unroll.py 4 | 5 | import os 6 | import torch 7 | import torch.distributed as dist 8 | 9 | from pippy import pipeline, PipelineStage, annotate_split_points, SplitPoint 10 | 11 | 12 | class IterationBlock(torch.nn.Module): 13 | def __init__(self, d_hid): 14 | super().__init__() 15 | self.lin = torch.nn.Linear(d_hid, d_hid) 16 | 17 | def forward(self, x): 18 | x = self.lin(x) 19 | x = torch.relu(x) 20 | return x 21 | 22 | 23 | class IterativeNetwork(torch.nn.Module): 24 | def __init__(self, d_hid, num_iters): 25 | super().__init__() 26 | self.num_iters = num_iters 27 | self.iter_block = IterationBlock(d_hid) 28 | # 10 output classes 29 | self.output_proj = torch.nn.Linear(d_hid, 10) 30 | 31 | def forward(self, x): 32 | for i in range(self.num_iters): 33 | x = self.iter_block(x) 34 | return self.output_proj(x) 35 | 36 | 37 | # We are using `torchrun` to run this example with multiple processes. 38 | # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. 39 | torch.manual_seed(0) 40 | rank = int(os.environ["RANK"]) 41 | world_size = int(os.environ["WORLD_SIZE"]) 42 | 43 | # Figure out device to use 44 | if torch.cuda.is_available(): 45 | device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") 46 | else: 47 | device = torch.device("cpu") 48 | 49 | # Create the model 50 | d_hid = 512 51 | # (n-1) iterations + 1 output projection 52 | num_iters = world_size - 1 53 | model = IterativeNetwork(d_hid, num_iters).to(device) 54 | 55 | # Add a split point after each iter_block 56 | annotate_split_points( 57 | model, 58 | {"iter_block": SplitPoint.END}, 59 | ) 60 | 61 | batch_size = 32 62 | example_input = torch.randn(batch_size, d_hid, device=device) 63 | chunks = world_size 64 | 65 | pipe = pipeline(model, chunks, example_args=(example_input,)) 66 | 67 | if rank == 0: 68 | print(" pipe ".center(80, "*")) 69 | print(pipe) 70 | print(" submod0 ".center(80, "*")) 71 | print(pipe.split_gm.submod_0) 72 | 73 | # Initialize distributed environment 74 | dist.init_process_group(rank=rank, world_size=world_size) 75 | 76 | # Pipeline stage is our main pipeline runtime. It takes in the pipe object, 77 | # the rank of this process, and the device. 78 | stage = PipelineStage(pipe, rank, device) 79 | 80 | # Input data 81 | x = torch.randn(batch_size, d_hid, device=device) 82 | 83 | # Run the pipeline with input `x`. Divide the batch into n micro-batches 84 | # and run them in parallel on the pipeline 85 | if rank == 0: 86 | stage(x) 87 | elif rank == world_size - 1: 88 | output = stage() 89 | else: 90 | stage() 91 | 92 | if rank == world_size - 1: 93 | # Run the original code and get the output for comparison 94 | reference_output = model(x) 95 | # Compare numerics of pipeline and original model 96 | torch.testing.assert_close(output, reference_output) 97 | print(" Pipeline parallel model ran successfully! ".center(80, "*")) 98 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # USAGE: ./format.sh [--show-targets] [--check] [TARGETS] 4 | # When used with --show-targets, list all default targets and exits. 5 | # When used with --check, reports errors but changes nothing. 
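# For example: ./format.sh --check pippy
#   prints the unified diffs that black would apply to the Python files under
#   pippy/ without rewriting them; drop --check to apply the changes in place.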
6 | 7 | DEFAULT_TARGETS=() 8 | for f in $(git ls-files | grep '\.py$'); do 9 | case "$f" in 10 | 'pippy/_unflatten.py') 11 | # ignore 12 | ;; 13 | 14 | 'pippy/'*) 15 | DEFAULT_TARGETS+=( "$f" ) 16 | ;; 17 | 18 | 'examples/'*) 19 | # ignore 20 | ;; 21 | 22 | 'docs/'*) 23 | # ignore 24 | ;; 25 | 26 | 'test/'*fx*) 27 | # ignore 28 | ;; 29 | 30 | *) 31 | # include 32 | DEFAULT_TARGETS+=( "$f" ) 33 | ;; 34 | esac 35 | done 36 | 37 | function format() { 38 | local TARGET="$1" 39 | 40 | # TODO: enable autoflake and isort. 41 | # these are not currently enabeled because the existing 42 | # import structure has magic side-effects that need to 43 | # be cleaned up so that isort and autoflake don't break them. 44 | 45 | # | autoflake \ 46 | # --stdin-display-name "$TARGET" \ 47 | # --remove-all-unused-imports \ 48 | # - \ 49 | # | isort \ 50 | # --filename "$TARGET" \ 51 | # - \ 52 | 53 | cat "$TARGET" \ 54 | | black \ 55 | -q \ 56 | --stdin-filename "$TARGET" \ 57 | - 58 | 59 | return ${PIPESTATUS[-1]} 60 | } 61 | 62 | function format_check() { 63 | local TARGET="$1" 64 | local TFILE=$(mktemp) 65 | trap "rm $TFILE" EXIT 66 | 67 | format "$TARGET" > "$TFILE" 68 | 69 | diff -u "$TARGET" "$TFILE" 70 | 71 | return $? 72 | } 73 | 74 | function reformat_inplace() { 75 | local TARGET="$1" 76 | local TFILE=$(mktemp) 77 | trap "rm $TFILE" EXIT 78 | 79 | format "$TARGET" > "$TFILE" 80 | if (( $? )); then 81 | return $?; 82 | fi 83 | 84 | diff -q "$TARGET" "$TFILE" > /dev/null 85 | if (( $? )); then 86 | cat "$TFILE" > "$TARGET"; 87 | fi 88 | 89 | return 0 90 | } 91 | 92 | 93 | function main() { 94 | local CHECK 95 | local TARGETS 96 | 97 | CHECK=0 98 | TARGETS=() 99 | 100 | for x in "$@"; do 101 | case "$x" in 102 | '--show-targets') 103 | for f in "${DEFAULT_TARGETS[@]}"; do 104 | echo $f; 105 | done 106 | exit 0; 107 | ;; 108 | 109 | '--check') 110 | CHECK=1; 111 | ;; 112 | 113 | *) 114 | TARGETS+=( "$x" ) 115 | ;; 116 | esac 117 | done 118 | 119 | if (( ${#TARGETS[@]} == 0 )); then 120 | TARGETS=( ${DEFAULT_TARGETS[@]} ) 121 | fi 122 | 123 | PY_TARGETS=() 124 | for x in "${TARGETS[@]}"; do 125 | if [[ -d "$x" ]]; then 126 | PY_TARGETS+=( $(find "$x" -name '*.py' -or -name '*.pyi') ) 127 | 128 | elif [[ -f "$x" ]]; then 129 | case "$x" in 130 | *.py) 131 | PY_TARGETS+=( "$x" ); 132 | ;; 133 | esac 134 | fi 135 | done 136 | 137 | if (( $CHECK )); then 138 | local result 139 | result=0 140 | for x in "${PY_TARGETS[@]}"; do 141 | format_check "$x"; 142 | (( result|=$? )) 143 | done 144 | exit $result 145 | else 146 | local result 147 | result=0 148 | for x in "${PY_TARGETS[@]}"; do 149 | reformat_inplace "$x"; 150 | (( result|=$? )) 151 | done 152 | exit $result 153 | fi 154 | } 155 | 156 | main "$@" 157 | 158 | -------------------------------------------------------------------------------- /pippy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | from ._IR import ( 3 | annotate_split_points, 4 | ArgsChunkSpec, 5 | KwargsChunkSpec, 6 | Pipe, 7 | pipe_split, 8 | pipeline, 9 | SplitPoint, 10 | ) 11 | from ._PipelineStage import PipelineStage 12 | from .ManualPipelineStage import ManualPipelineStage 13 | from .ModelSplit import ( 14 | split_by_graph, 15 | split_into_equal_size, 16 | split_on_size_threshold, 17 | ) 18 | from .PipelineSchedule import ( 19 | Schedule1F1B, 20 | ScheduleGPipe, 21 | ScheduleInterleaved1F1B, 22 | ScheduleLoopedBFS, 23 | ) 24 | 25 | 26 | __all__ = [ 27 | "Pipe", 28 | "PipelineStage", 29 | "pipe_split", 30 | "SplitPoint", 31 | "annotate_split_points", 32 | "split_into_equal_size", 33 | "split_on_size_threshold", 34 | "split_by_graph", 35 | "pipeline", 36 | "Schedule1F1B", 37 | "ScheduleGPipe", 38 | "ScheduleInterleaved1F1B", 39 | "ScheduleLoopedBFS", 40 | "ManualPipelineStage", 41 | "ArgsChunkSpec", 42 | "KwargsChunkSpec", 43 | ] 44 | -------------------------------------------------------------------------------- /pippy/_backward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | from typing import List, Optional 3 | 4 | import torch 5 | 6 | from ._debug import map_debug_info 7 | 8 | 9 | def stage_backward( 10 | stage_output, 11 | output_grads, 12 | input_values, 13 | outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used 14 | ): 15 | """ 16 | This is a helper function to: 17 | 1. compute the gradients for the stage inputs, and 18 | 2. accumulate gradients for the stage module's parameters. 19 | 20 | Given the input value(s) and the corresponding gradient for the output 21 | value(s), compute and accumulate gradients for all parameter values (leaves 22 | in the autograd trace) as well as return a list of the gradients for the 23 | input values 24 | """ 25 | if outputs_with_grads_idxs is not None: 26 | # Deprecated, not used in runtime calls, only exists in compiler 27 | stage_output = [stage_output[i] for i in outputs_with_grads_idxs] 28 | output_grads = [output_grads[i] for i in outputs_with_grads_idxs] 29 | 30 | try: 31 | # stage_output may be a composite datatype like dict. 
Extract all individual 32 | # tensor values here 33 | stage_output_tensors = [] 34 | output_grad_tensors = [] 35 | 36 | def extract_tensors_with_grads(output_val, grad_val): 37 | if isinstance(output_val, torch.Tensor): 38 | if not output_val.requires_grad and output_val.grad_fn is None: 39 | return 40 | assert isinstance( 41 | grad_val, (torch.Tensor, type(None)) 42 | ), f"Expected Tensor or None gradient but got {type(grad_val)}" 43 | stage_output_tensors.append(output_val) 44 | output_grad_tensors.append(grad_val) 45 | elif isinstance(output_val, (tuple, list)): 46 | if grad_val is None: 47 | return 48 | assert isinstance( 49 | grad_val, (tuple, list) 50 | ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" 51 | assert len(output_val) == len(grad_val) 52 | for ov, gv in zip(output_val, grad_val): 53 | extract_tensors_with_grads(ov, gv) 54 | elif isinstance(output_val, dict): 55 | if grad_val is None: 56 | return 57 | assert isinstance(grad_val, dict) 58 | assert set(output_val.keys()) == set(grad_val.keys()) 59 | for k in output_val.keys(): 60 | extract_tensors_with_grads(output_val[k], grad_val[k]) 61 | else: 62 | # Output is a non-tensor type; just ignore it 63 | pass 64 | 65 | extract_tensors_with_grads(stage_output, output_grads) 66 | 67 | torch.autograd.backward( 68 | stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] 69 | ) 70 | 71 | # Extract gradients wrt the input values 72 | grad_inputs = [] 73 | for val in input_values: 74 | if isinstance(val, torch.Tensor): 75 | grad_inputs.append(val.grad) 76 | else: 77 | grad_inputs.append(None) 78 | 79 | # Alternative impl: `torch.autograd.grad`. 80 | # Note that `torch.autograd.grad` will not accumulate gradients into the 81 | # model's parameters. 82 | """ 83 | inputs_with_grad = [] 84 | for val in input_values: 85 | if isinstance(val, torch.Tensor) and val.requires_grad: 86 | inputs_with_grad.append(val) 87 | 88 | grad_inputs = torch.autograd.grad( 89 | stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] 90 | ) 91 | """ 92 | 93 | except Exception as e: 94 | exc_msg = f""" 95 | Failed to run stage backward: 96 | Stage output: {map_debug_info(stage_output)} 97 | Output gradient: {map_debug_info(output_grads)} 98 | Input: {map_debug_info(input_values)} 99 | """ 100 | raise RuntimeError(exc_msg) from e 101 | 102 | return grad_inputs 103 | 104 | 105 | # TODO: handling requires_grad=False dynamically. Can we analyze this during initial 106 | # IR emission? 107 | def _null_coalesce_accumulate(lhs, rhs): 108 | """ 109 | Coalesce two values, even if one of them is null, returning the non-null 110 | value. 111 | """ 112 | if lhs is None: 113 | return rhs 114 | elif rhs is None: 115 | return lhs 116 | else: 117 | return torch.add(lhs, rhs) 118 | -------------------------------------------------------------------------------- /pippy/_debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import logging 3 | import os 4 | 5 | import torch 6 | 7 | 8 | # PIPPY_VERBOSITY is an environment variable that controls the logging level. 
9 | # It can be set to one of the following: 10 | # - WARNING (default) 11 | # - INFO 12 | # - DEBUG 13 | PIPPY_VERBOSITY = os.getenv("PIPPY_VERBOSITY", "WARNING") 14 | if PIPPY_VERBOSITY not in ["WARNING", "INFO", "DEBUG"]: 15 | logging.warning(f"Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") 16 | PIPPY_VERBOSITY = "WARNING" 17 | 18 | logging.getLogger("pippy").setLevel(PIPPY_VERBOSITY) 19 | # It seems we need to print something to make the level setting effective 20 | # for child loggers. Doing it here. 21 | logging.warning(f"Setting PiPPy logging level to: {PIPPY_VERBOSITY}") 22 | 23 | 24 | def friendly_debug_info(v): 25 | """ 26 | Helper function to print out debug info in a friendly way. 27 | """ 28 | if isinstance(v, torch.Tensor): 29 | return f"Tensor({v.shape}, grad={v.requires_grad})" 30 | else: 31 | return str(v) 32 | 33 | 34 | def map_debug_info(a): 35 | """ 36 | Helper function to apply `friendly_debug_info` to items in `a`. 37 | `a` may be a list, tuple, or dict. 38 | """ 39 | return torch.fx.node.map_aggregate(a, friendly_debug_info) 40 | -------------------------------------------------------------------------------- /pippy/_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import logging 3 | from typing import Dict 4 | 5 | import torch 6 | from torch import fx 7 | from torch.export.unflatten import InterpreterModule 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def flatten_args_detach(args): 14 | """ 15 | Flatten the args into a list form and detach the tensors from computational graph. 16 | """ 17 | flat_detached_args = [] 18 | 19 | def extract_tensor_args(a): 20 | nonlocal flat_detached_args 21 | if isinstance(a, torch.Tensor): 22 | val = a.detach().requires_grad_(a.requires_grad) 23 | flat_detached_args.append(val) 24 | return val 25 | else: 26 | flat_detached_args.append(a) 27 | return a 28 | 29 | new_args = fx.node.map_aggregate( 30 | args, 31 | extract_tensor_args, 32 | ) 33 | 34 | return new_args, flat_detached_args 35 | 36 | 37 | def flatten_args(args): 38 | """ 39 | Flatten the args into a list form. 40 | """ 41 | flat_args = [] 42 | 43 | def extract_tensor_args(a): 44 | nonlocal flat_args 45 | flat_args.append(a) 46 | return a 47 | 48 | fx.node.map_aggregate( 49 | args, 50 | extract_tensor_args, 51 | ) 52 | 53 | return flat_args 54 | 55 | 56 | def modify_graph_op_device( 57 | gm: torch.fx.GraphModule, 58 | new_device: torch.device, 59 | ): 60 | """ 61 | Modify the device argument of all "call_function" nodes in the graph. This 62 | is useful for moving the graph to a different device. In particular for 63 | generator ops, like torch.ones. 
64 | """ 65 | modified = False 66 | for node in gm.graph.nodes: 67 | if node.op == "call_function": 68 | if "device" in node.kwargs and node.kwargs["device"] != new_device: 69 | logger.debug( 70 | f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" 71 | ) 72 | node.update_kwarg("device", new_device) 73 | modified = True 74 | elif node.op == "call_module": 75 | # Recursively modify "device" in submodules 76 | submod = gm.get_submodule(node.target) 77 | if isinstance(submod, torch.fx.GraphModule): 78 | modify_graph_op_device(submod, new_device) 79 | elif isinstance(submod, InterpreterModule): 80 | # If unflattening has been performed, we need to access its graph module by `.graph_module` 81 | modify_graph_op_device(submod.graph_module, new_device) 82 | else: 83 | logger.warning( 84 | f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" 85 | ) 86 | 87 | if modified: 88 | gm.recompile() 89 | 90 | 91 | class QualnameMapMixin: 92 | """ 93 | A mixin class that helps a `Pipe` object to remap its qualnames back to 94 | original qualnames. 95 | """ 96 | 97 | def __init__( 98 | self, 99 | splitter_qualname_map: Dict[str, str] = None, 100 | tracer_qualname_map: Dict[str, str] = None, 101 | ): 102 | self.new_to_old_qualname_mapping: Dict[str, str] = ( 103 | splitter_qualname_map or {} 104 | ) 105 | self.tracer_qualname_map = tracer_qualname_map 106 | 107 | def remap_qualname(self, qualname: str): 108 | # TODO: annoying 109 | if qualname.startswith("split_gm."): 110 | qualname = qualname[len("split_gm.") :] 111 | 112 | name_before_split = None 113 | if qualname in self.new_to_old_qualname_mapping: 114 | name_before_split = self.new_to_old_qualname_mapping[qualname] 115 | else: 116 | # The qualname map does not store recursive items, thus, 117 | # when passed a qualname with leaves, we need to perform longest prefix match 118 | # Split from the right, one each time 119 | split_names = qualname.rsplit(".", 1) 120 | leaf = split_names[-1] 121 | while len(split_names) > 1: 122 | prefix = split_names[0] 123 | if prefix in self.new_to_old_qualname_mapping: 124 | old_prefix = self.new_to_old_qualname_mapping[prefix] 125 | name_before_split = ".".join([old_prefix, leaf]) 126 | break 127 | split_names = prefix.rsplit(".", 1) 128 | leaf = ".".join([split_names[-1], leaf]) 129 | 130 | if name_before_split is None: 131 | raise RuntimeError(f"Could not find mapping for {qualname}") 132 | 133 | if self.tracer_qualname_map is not None: 134 | return self.tracer_qualname_map[name_before_split] 135 | else: 136 | return name_before_split 137 | -------------------------------------------------------------------------------- /pippy/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/PiPPy/1bcb2bfb2d6cc4ac2125c0edb37c35585bb9695f/pippy/utilities/__init__.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 80 3 | 4 | [tool.mypy] 5 | warn_unused_configs = true 6 | ignore_missing_imports = true 7 | warn_redundant_casts = true 8 | show_error_codes = true 9 | show_column_numbers = true 10 | check_untyped_defs = true 11 | follow_imports = "silent" 12 | warn_unused_ignores = false 13 | exclude = [ 14 | '^docs', 15 | ] 16 | 17 | [tool.ufmt] 18 | excludes = [ 19 | "pippy/_unflatten.py", 20 | ] 21 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 2.3.0.dev 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import distutils.command.clean 3 | import glob 4 | import os 5 | import shutil 6 | import subprocess 7 | from typing import Dict 8 | from setuptools import setup, find_packages 9 | 10 | 11 | # Package name 12 | package_name = "torchpippy" 13 | 14 | # Version information 15 | cwd = os.path.dirname(os.path.abspath(__file__)) 16 | version_txt = os.path.join(cwd, "version.txt") 17 | with open(version_txt, "r") as f: 18 | version = f.readline().strip() 19 | 20 | try: 21 | sha = ( 22 | subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd) 23 | .decode("ascii") 24 | .strip() 25 | ) 26 | except Exception: 27 | sha = "Unknown" 28 | 29 | if os.getenv("BUILD_VERSION"): 30 | version = os.getenv("BUILD_VERSION", version) 31 | elif os.getenv("VERSION_NO_GIT", "0") == "1": 32 | pass 33 | elif sha != "Unknown": 34 | version += "+" + sha[:7] 35 | 36 | 37 | def write_version_file(): 38 | version_path = os.path.join(cwd, "pippy", "version.py") 39 | with open(version_path, "w") as f: 40 | f.write("__version__ = '{}'\n".format(version)) 41 | f.write("git_version = {}\n".format(repr(sha))) 42 | 43 | 44 | # Package requirements 45 | requirements = [ 46 | # If the torch version has a ".dev" suffix, it would represent a nightly version of PyTorch. 47 | # It can be installed as a binary or from source. 48 | "torch>=2.3.0.dev", 49 | ] 50 | 51 | extras: Dict = {} 52 | 53 | 54 | long_description = """ 55 | The PiPPy project stands for Pipeline Parallelism for PyTorch. It consists of a 56 | compiler and runtime stack for automated parallelism and scaling of PyTorch 57 | models. PiPPy partitions the code of the model in a pipelined fashion and 58 | enables multiple micro-batches to execute different parts of the model code 59 | concurrently. For details, please visit PiPPy's [GitHub 60 | page](https://github.com/pytorch/PiPPy). 61 | """ 62 | 63 | 64 | class clean(distutils.command.clean.clean): # type: ignore 65 | def run(self): 66 | with open(".gitignore", "r") as f: 67 | ignores = f.read() 68 | for wildcard in filter(None, ignores.split("\n")): 69 | for filename in glob.glob(wildcard): 70 | try: 71 | os.remove(filename) 72 | except OSError: 73 | shutil.rmtree(filename, ignore_errors=True) 74 | 75 | # It's an old-style class in Python 2.7... 76 | distutils.command.clean.clean.run(self) 77 | 78 | 79 | if __name__ == "__main__": 80 | write_version_file() 81 | 82 | setup( 83 | # Metadata 84 | name=package_name, 85 | version=version, 86 | author="PiPPy Team", 87 | url="https://github.com/pytorch/PiPPy", 88 | description="Pipeline Parallelism for PyTorch", 89 | license="BSD", 90 | # Package info 91 | packages=find_packages(), 92 | install_requires=requirements, 93 | extras_require=extras, 94 | cmdclass={ 95 | "clean": clean, 96 | }, 97 | long_description=long_description, 98 | long_description_content_type="text/markdown", 99 | ) 100 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | -------------------------------------------------------------------------------- /test/multinode_trainer.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=test_pipeline_schedules 4 | 5 | #SBATCH --ntasks=2 6 | 7 | #SBATCH --nodes=2 8 | 9 | #SBATCH --gpus-per-task=8 10 | 11 | #SBATCH --cpus-per-task=96 12 | 13 | #SBATCH --partition=train 14 | 15 | 16 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 17 | nodes_array=($nodes) 18 | head_node=${nodes_array[0]} 19 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 20 | 21 | echo Node IP: $head_node_ip 22 | export LOGLEVEL=INFO 23 | # Enable for A100 24 | export FI_PROVIDER="efa" 25 | # Ensure that P2P is available 26 | # export NCCL_P2P_DISABLE=1 27 | export NCCL_IB_DISABLE=1 28 | 29 | # debugging flags (optional) 30 | export NCCL_DEBUG=WARN 31 | export PYTHONFAULTHANDLER=1 32 | # optional debug settings 33 | # export NCCL_DEBUG=INFO 34 | # NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV 35 | 36 | export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH 37 | export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH 38 | export CUDA_LAUNCH_BLOCKING=0 39 | 40 | # on your cluster you might need these: 41 | # set the network interface 42 | export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" 43 | export NCCL_BUFFSIZE=2097152 44 | #export TORCH_DIST_INIT_BARRIER=1 45 | export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 46 | 47 | dcgmi profile --pause 48 | # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below 49 | # to your specific node count, and update target launch file. 50 | srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./test_pipeline_schedule.py --schedules gpipe looped_bfs 51 | dcgmi profile --resume 52 | -------------------------------------------------------------------------------- /test/test_autosplit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import pippy 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from pippy import pipeline, PipelineStage, ScheduleGPipe, split_into_equal_size 11 | 12 | 13 | pippy.microbatch._debug_mask_minibatches = True 14 | 15 | d_hid = 512 16 | batch_size = 256 17 | 18 | torch.manual_seed(0) 19 | 20 | 21 | class ExampleCode(torch.nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 25 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 26 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 27 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 28 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 29 | self.register_buffer("buffer", torch.randn(batch_size + 100, d_hid)) 30 | 31 | def forward(self, x): 32 | x = torch.mm(x, self.mm_param0) 33 | x = torch.relu(x) 34 | x = torch.mm(x, self.mm_param1) + self.buffer[: x.shape[0]] 35 | x = self.lin1(x) 36 | x = torch.relu(x) 37 | x = torch.mm(x, self.mm_param2) 38 | x = self.lin2(x) 39 | x = torch.relu(x) 40 | return x 41 | 42 | 43 | def run_worker(args): 44 | mod = ExampleCode() 45 | mod.to(args.device) 46 | 47 | x = torch.randn(batch_size, d_hid, device=args.device) 48 | 49 | split_policy = split_into_equal_size(args.world_size) 50 | 51 | pipe = pipeline( 52 | mod, 53 | args.chunks, 54 | example_args=(x,), 55 | split_policy=split_policy, 56 | ) 57 | 58 | # Check returned number of stages 59 | assert ( 60 | pipe.num_stages == args.world_size 61 | ), f"Model is split into {pipe.num_stages} stages instead of {args.world_size}" 62 | print(f"Split test passed: got {pipe.num_stages} stages") 63 | 64 | stage = PipelineStage( 65 | pipe, 66 | args.rank, 67 | device=args.device, 68 | ) 69 | 70 | # Attach to a schedule 71 | schedule = ScheduleGPipe(stage, args.chunks) 72 | 73 | # Run 74 | if args.rank == 0: 75 | schedule.step(x) 76 | else: 77 | out = schedule.step() 78 | 79 | dist.barrier() 80 | print(f"Rank {args.rank} completes") 81 | 82 | # Last rank checks result 83 | if args.rank == args.world_size - 1: 84 | ref_out = mod(x) 85 | torch.testing.assert_close(out, ref_out) 86 | print( 87 | f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" 88 | ) 89 | 90 | 91 | def main(args=None): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument( 94 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 95 | ) 96 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 97 | parser.add_argument( 98 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 99 | ) 100 | parser.add_argument( 101 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 102 | ) 103 | parser.add_argument( 104 | "--cuda", type=int, default=int(torch.cuda.is_available()) 105 | ) 106 | parser.add_argument( 107 | "--chunks", 108 | type=int, 109 | default=4, 110 | ) 111 | args = parser.parse_args(args) 112 | 113 | if args.cuda: 114 | dev_id = args.rank % torch.cuda.device_count() 115 | args.device = torch.device(f"cuda:{dev_id}") 116 | else: 117 | args.device = torch.device("cpu") 118 | 119 | # Init process group 120 | backend = "nccl" if args.cuda else "gloo" 121 | dist.init_process_group( 122 | backend=backend, 123 | rank=args.rank, 124 | world_size=args.world_size, 125 | ) 126 | 127 | run_worker(args) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | 133 | 134 | class LocalTestAutoSplit(unittest.TestCase): 135 | def 
test_auto_split(self): 136 | import random 137 | 138 | port = random.randint(29500, 30000) 139 | args = [ 140 | "--master_port", 141 | str(port), 142 | ] 143 | main(args) 144 | -------------------------------------------------------------------------------- /test/test_bwd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | from pippy import ( 10 | pipe_split, 11 | pipeline, 12 | PipelineStage, 13 | Schedule1F1B, 14 | ScheduleGPipe, 15 | ) 16 | 17 | 18 | schedule_map = { 19 | "gpipe": ScheduleGPipe, 20 | "1f1b": Schedule1F1B, 21 | } 22 | 23 | d_hid = 512 24 | batch_size = 256 25 | 26 | torch.manual_seed(0) 27 | 28 | 29 | # Basic example 30 | class ExampleCode(torch.nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 34 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 35 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 36 | self.register_buffer("cval", torch.randn((d_hid,), requires_grad=False)) 37 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 38 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 39 | 40 | def forward(self, x): 41 | x = torch.mm(x, self.mm_param0) 42 | x = torch.relu(x) 43 | pipe_split() 44 | x = torch.mm(x, self.mm_param1) 45 | # try passing a value that doesn't require_grad across skip boundaries 46 | a_constant = self.cval.clone() 47 | x = self.lin1(x) 48 | pipe_split() 49 | x = torch.relu(x) + a_constant 50 | x = torch.mm(x, self.mm_param2) 51 | pipe_split() 52 | x = self.lin2(x) 53 | logits = torch.relu(x) 54 | return logits 55 | 56 | 57 | def run_worker(args): 58 | mod = ExampleCode() 59 | mod.to(args.device) 60 | 61 | x = torch.randn(batch_size, d_hid, device=args.device) 62 | target = torch.randn(batch_size, d_hid, device=args.device) 63 | loss_fn = torch.nn.MSELoss(reduction="sum") 64 | 65 | pipe = pipeline( 66 | mod, 67 | args.chunks, 68 | example_args=(x,), 69 | ) 70 | 71 | stage = PipelineStage( 72 | pipe, 73 | args.rank, 74 | device=args.device, 75 | ) 76 | 77 | # Attach to a schedule 78 | ScheduleClass = schedule_map[args.schedule] 79 | schedule = ScheduleClass(stage, args.chunks, loss_fn=loss_fn) 80 | 81 | # Run 82 | if args.rank == 0: 83 | schedule.step(x) 84 | elif args.rank == args.world_size - 1: 85 | losses = [] 86 | out = schedule.step(target=target, losses=losses) 87 | else: 88 | schedule.step() 89 | 90 | dist.barrier() 91 | print(f"Rank {args.rank} completes") 92 | 93 | # Last rank checks result 94 | if args.rank == args.world_size - 1: 95 | ref_out = mod(x) 96 | ref_loss = loss_fn(ref_out, target) 97 | pipe_loss = sum(losses) 98 | torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3) 99 | torch.testing.assert_close(pipe_loss, ref_loss) 100 | print( 101 | f"equivalence test passed pipe_loss={pipe_loss} ref_loss={ref_loss}" 102 | ) 103 | 104 | 105 | def main(args=None): 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument( 108 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 109 | ) 110 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 111 | parser.add_argument( 112 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 113 | ) 114 | parser.add_argument( 115 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 116 | ) 117 | parser.add_argument( 
118 | "--cuda", type=int, default=int(torch.cuda.is_available()) 119 | ) 120 | parser.add_argument( 121 | "--chunks", 122 | type=int, 123 | default=4, 124 | ) 125 | parser.add_argument( 126 | "--schedule", 127 | type=str, 128 | default="gpipe", 129 | choices=schedule_map.keys(), 130 | ) 131 | args = parser.parse_args(args) 132 | 133 | if args.cuda: 134 | dev_id = args.rank % torch.cuda.device_count() 135 | args.device = torch.device(f"cuda:{dev_id}") 136 | else: 137 | args.device = torch.device("cpu") 138 | 139 | # Init process group 140 | backend = "nccl" if args.cuda else "gloo" 141 | dist.init_process_group( 142 | backend=backend, 143 | rank=args.rank, 144 | world_size=args.world_size, 145 | ) 146 | 147 | run_worker(args) 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | 153 | 154 | class TestBwd(unittest.TestCase): 155 | def test_bwd(self): 156 | import random 157 | 158 | port = random.randint(29500, 30000) 159 | args = [ 160 | "--master_port", 161 | str(port), 162 | ] 163 | main(args) 164 | -------------------------------------------------------------------------------- /test/test_chunkspec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import unittest 3 | 4 | import torch 5 | 6 | from pippy import ArgsChunkSpec, KwargsChunkSpec, pipe_split, pipeline 7 | 8 | 9 | d_hid = 512 10 | batch_size = 256 11 | chunks = 4 12 | 13 | torch.manual_seed(0) 14 | 15 | 16 | class ExampleCode(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 20 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 21 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 22 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 23 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 24 | 25 | def forward(self, x, y, z=torch.zeros(batch_size, d_hid)): 26 | x = torch.mm(x, self.mm_param0) 27 | x = x + y 28 | x = torch.relu(x) 29 | x = x + z 30 | pipe_split() 31 | x = torch.mm(x, self.mm_param1) 32 | x = self.lin1(x) 33 | pipe_split() 34 | x = torch.relu(x) 35 | x = torch.mm(x, self.mm_param2) 36 | pipe_split() 37 | x = self.lin2(x) 38 | x = torch.relu(x) 39 | return x 40 | 41 | 42 | def main(args=None): 43 | mod = ExampleCode() 44 | 45 | x = torch.randn(batch_size, d_hid) 46 | y = torch.randn(batch_size, d_hid) 47 | z = torch.randn(batch_size, d_hid) 48 | 49 | with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}): 50 | pipe = pipeline( 51 | mod, 52 | chunks, 53 | example_args=(x, y), 54 | example_kwargs={"z": z}, 55 | ) 56 | 57 | assert pipe.num_stages == 4 58 | 59 | ref = mod(x, y, z) 60 | out = pipe(x, y, z)[0] 61 | torch.testing.assert_close(out, ref) 62 | print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | 68 | 69 | class TestChunkSpec(unittest.TestCase): 70 | def test_chunk_spec(self): 71 | main() 72 | -------------------------------------------------------------------------------- /test/test_cpu_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import pippy 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from pippy import pipe_split, pipeline, PipelineStage, ScheduleGPipe 11 | 12 | 13 | pippy.microbatch._debug_mask_minibatches = True 14 | 15 | d_hid = 512 16 | batch_size = 256 17 | 18 | torch.manual_seed(0) 19 | 20 | 21 | class ExampleCode(torch.nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 25 | self.lin = torch.nn.Linear(d_hid, d_hid) 26 | 27 | def forward(self, x): 28 | # Test change of tensor creation device after tracing 29 | a = torch.ones(batch_size, d_hid, device=x.device) 30 | x = x + a 31 | x = torch.mm(x, self.mm_param) 32 | x = torch.relu(x) 33 | pipe_split() 34 | # Test change of tensor creation device after tracing 35 | b = torch.zeros(batch_size, d_hid, device=x.device) 36 | x = self.lin(x) 37 | x = x + b 38 | x = torch.relu(x) 39 | return x 40 | 41 | 42 | def run_worker(args): 43 | # Create module and trace model in CPU 44 | mod = ExampleCode() 45 | 46 | xe = torch.randn(batch_size, d_hid) 47 | 48 | pipe = pipeline( 49 | mod, 50 | args.chunks, 51 | example_args=(xe,), 52 | ) 53 | 54 | # Create pipeline stages and move stage to GPU 55 | stage = PipelineStage( 56 | pipe, 57 | args.rank, 58 | device=args.device, 59 | ) 60 | 61 | # Attach to a schedule 62 | schedule = ScheduleGPipe( 63 | stage, 64 | args.chunks, 65 | ) 66 | 67 | # Create real input on real device 68 | x = torch.randn(batch_size, d_hid, device=args.device) 69 | 70 | # Run 71 | if args.rank == 0: 72 | schedule.step(x) 73 | else: 74 | out = schedule.step() 75 | 76 | dist.barrier() 77 | print(f"Rank {args.rank} completes") 78 | 79 | # Last rank checks result 80 | if args.rank == args.world_size - 1: 81 | mod.to(args.device) 82 | ref_out = mod(x) 83 | torch.testing.assert_close(out, ref_out) 84 | print( 85 | f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" 86 | ) 87 | 88 | 89 | def main(args=None): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 2)) 93 | ) 94 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 95 | parser.add_argument( 96 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 97 | ) 98 | parser.add_argument( 99 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 100 | ) 101 | parser.add_argument( 102 | "--cuda", type=int, default=int(torch.cuda.is_available()) 103 | ) 104 | parser.add_argument( 105 | "--chunks", 106 | type=int, 107 | default=4, 108 | ) 109 | args = parser.parse_args(args) 110 | 111 | if args.cuda: 112 | dev_id = args.rank % torch.cuda.device_count() 113 | args.device = torch.device(f"cuda:{dev_id}") 114 | else: 115 | args.device = torch.device("cpu") 116 | 117 | # Init process group 118 | backend = "nccl" if args.cuda else "gloo" 119 | dist.init_process_group( 120 | backend=backend, 121 | rank=args.rank, 122 | world_size=args.world_size, 123 | ) 124 | 125 | run_worker(args) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | 131 | 132 | class TestFwd(unittest.TestCase): 133 | def test_fwd(self): 134 | import random 135 | 136 | port = random.randint(29500, 30000) 137 | args = [ 138 | "--master_port", 139 | str(port), 140 | ] 141 | main(args) 142 | -------------------------------------------------------------------------------- /test/test_fwd.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import pippy 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from pippy import pipe_split, pipeline, PipelineStage, ScheduleGPipe 11 | 12 | 13 | pippy.microbatch._debug_mask_minibatches = True 14 | 15 | d_hid = 512 16 | batch_size = 256 17 | 18 | torch.manual_seed(0) 19 | 20 | 21 | class ExampleCode(torch.nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 25 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 26 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 27 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 28 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 29 | 30 | def forward(self, x, y=torch.zeros(batch_size, d_hid)): 31 | x = torch.mm(x, self.mm_param0) 32 | x = x + y 33 | x = torch.relu(x) 34 | pipe_split() 35 | x = torch.mm(x, self.mm_param1) 36 | x = self.lin1(x) 37 | pipe_split() 38 | x = torch.relu(x) 39 | x = torch.mm(x, self.mm_param2) 40 | pipe_split() 41 | x = self.lin2(x) 42 | x = torch.relu(x) 43 | return x 44 | 45 | 46 | def run_worker(args): 47 | mod = ExampleCode() 48 | mod.to(args.device) 49 | 50 | x = torch.randn(batch_size, d_hid, device=args.device) 51 | y = torch.randn(batch_size, d_hid, device=args.device) 52 | 53 | pipe = pipeline( 54 | mod, 55 | args.chunks, 56 | example_args=(x,), 57 | example_kwargs={"y": y}, 58 | ) 59 | 60 | stage = PipelineStage( 61 | pipe, 62 | args.rank, 63 | device=args.device, 64 | ) 65 | 66 | # Attach to a schedule 67 | schedule = ScheduleGPipe(stage, args.chunks) 68 | 69 | # Run 70 | if args.rank == 0: 71 | schedule.step(x, y=y) 72 | else: 73 | out = schedule.step() 74 | 75 | dist.barrier() 76 | print(f"Rank {args.rank} completes") 77 | 78 | # Last rank checks result 79 | if args.rank == args.world_size - 1: 80 | ref_out = mod(x, y=y) 81 | torch.testing.assert_close(out, ref_out) 82 | print( 83 | f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" 84 | ) 85 | 86 | # Test qualname mapping 87 | submod_keys = stage.submod.state_dict().keys() 88 | print(f"Rank {args.rank} state dict keys: {submod_keys}") 89 | # Confirm keys are consistent with original model 90 | old_keys = mod.state_dict().keys() 91 | assert all(k in old_keys for k in submod_keys) 92 | print(f"Qualname test passed") 93 | 94 | 95 | def main(args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument( 98 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 99 | ) 100 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 101 | parser.add_argument( 102 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 103 | ) 104 | parser.add_argument( 105 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 106 | ) 107 | parser.add_argument( 108 | "--cuda", type=int, default=int(torch.cuda.is_available()) 109 | ) 110 | parser.add_argument( 111 | "--chunks", 112 | type=int, 113 | default=4, 114 | ) 115 | args = parser.parse_args(args) 116 | 117 | if args.cuda: 118 | dev_id = args.rank % torch.cuda.device_count() 119 | args.device = torch.device(f"cuda:{dev_id}") 120 | else: 121 | args.device = torch.device("cpu") 122 | 123 | # Init process group 124 | backend = "nccl" if args.cuda else "gloo" 125 | dist.init_process_group( 126 | backend=backend, 127 | rank=args.rank, 
128 | world_size=args.world_size, 129 | ) 130 | 131 | run_worker(args) 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | 137 | 138 | class TestFwd(unittest.TestCase): 139 | def test_fwd(self): 140 | import random 141 | 142 | port = random.randint(29500, 30000) 143 | args = [ 144 | "--master_port", 145 | str(port), 146 | ] 147 | main(args) 148 | -------------------------------------------------------------------------------- /test/test_grad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import copy 4 | import os 5 | import unittest 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | from pippy import ( 11 | pipe_split, 12 | pipeline, 13 | PipelineStage, 14 | Schedule1F1B, 15 | ScheduleGPipe, 16 | ) 17 | 18 | 19 | schedule_map = { 20 | "gpipe": ScheduleGPipe, 21 | "1f1b": Schedule1F1B, 22 | } 23 | 24 | d_hid = 512 25 | batch_size = 256 26 | 27 | torch.manual_seed(0) 28 | 29 | 30 | # MLP example 31 | class MLPModule(torch.nn.Module): 32 | def __init__(self, d_hid): 33 | super(MLPModule, self).__init__() 34 | self.net1 = torch.nn.Linear(d_hid, d_hid) 35 | self.relu = torch.nn.ReLU() 36 | self.net2 = torch.nn.Linear(d_hid, d_hid) 37 | 38 | def forward(self, x): 39 | x = self.net1(x) 40 | x = self.relu(x) 41 | x = self.net2(x) 42 | return x 43 | 44 | 45 | class MultiMLP(torch.nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | self.mlp0 = MLPModule(d_hid) 49 | self.mlp1 = MLPModule(d_hid) 50 | self.mlp2 = MLPModule(d_hid) 51 | self.mlp3 = MLPModule(d_hid) 52 | 53 | def forward(self, x): 54 | x = self.mlp0(x) 55 | pipe_split() 56 | x = self.mlp1(x) 57 | pipe_split() 58 | x = self.mlp2(x) 59 | pipe_split() 60 | x = self.mlp3(x) 61 | return x 62 | 63 | 64 | def run_worker(args): 65 | mod = MultiMLP() 66 | mod.to(args.device) 67 | 68 | ref_mod = copy.deepcopy(mod) 69 | x = torch.randn(batch_size, d_hid, device=args.device) 70 | with torch.no_grad(): 71 | y = ref_mod(x) 72 | # Add a small perturbation 73 | target = y + torch.randn(batch_size, d_hid, device=args.device) 74 | 75 | loss_fn = torch.nn.MSELoss(reduction="sum") 76 | 77 | # Run reference 78 | for _ in range(2): 79 | ref_mod.zero_grad() 80 | ref_out = ref_mod(x) 81 | ref_loss = loss_fn(ref_out, target) 82 | ref_loss.backward() 83 | 84 | # Create a pipeline 85 | pipe = pipeline( 86 | mod, 87 | args.chunks, 88 | example_args=(x,), 89 | ) 90 | 91 | stage = PipelineStage( 92 | pipe, 93 | args.rank, 94 | device=args.device, 95 | ) 96 | 97 | # Attach to a schedule 98 | ScheduleClass = schedule_map[args.schedule] 99 | schedule = ScheduleClass(stage, args.chunks, loss_fn=loss_fn) 100 | 101 | # Run 102 | stage_module = pipe.get_stage_module(args.rank) 103 | for _ in range(2): 104 | # Zero gradients 105 | stage_module.zero_grad() 106 | if args.rank == 0: 107 | schedule.step(x) 108 | elif args.rank == args.world_size - 1: 109 | losses = [] 110 | out = schedule.step(target=target, losses=losses) 111 | else: 112 | schedule.step() 113 | 114 | dist.barrier() 115 | print(f"Rank {args.rank} completes") 116 | 117 | # Last rank checks result 118 | if args.rank == args.world_size - 1: 119 | # Check output 120 | torch.testing.assert_close(out, ref_out) 121 | print("Output test passed") 122 | # Check loss 123 | # Since the reduction used in the loss function above is "sum", we use 124 | # "sum" here to reduce microbatch losses into a single value too. 
125 | pipe_loss = sum(losses) 126 | torch.testing.assert_close(pipe_loss, ref_loss) 127 | print("Loss test passed") 128 | 129 | # Every rank checks gradients 130 | for name, p in stage_module.named_parameters(): 131 | ref_p = ref_mod.get_parameter(name) 132 | try: 133 | torch.testing.assert_close(p.grad, ref_p.grad) 134 | except AssertionError: 135 | print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") 136 | raise 137 | print(f"Rank {args.rank} Gradient test passed") 138 | 139 | 140 | def main(args=None): 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument( 143 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 144 | ) 145 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 146 | parser.add_argument( 147 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 148 | ) 149 | parser.add_argument( 150 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 151 | ) 152 | parser.add_argument( 153 | "--cuda", type=int, default=int(torch.cuda.is_available()) 154 | ) 155 | parser.add_argument( 156 | "--chunks", 157 | type=int, 158 | default=4, 159 | ) 160 | parser.add_argument( 161 | "--schedule", 162 | type=str, 163 | default="gpipe", 164 | choices=schedule_map.keys(), 165 | ) 166 | args = parser.parse_args(args) 167 | 168 | if args.cuda: 169 | dev_id = args.rank % torch.cuda.device_count() 170 | args.device = torch.device(f"cuda:{dev_id}") 171 | else: 172 | args.device = torch.device("cpu") 173 | 174 | # Init process group 175 | backend = "nccl" if args.cuda else "gloo" 176 | dist.init_process_group( 177 | backend=backend, 178 | rank=args.rank, 179 | world_size=args.world_size, 180 | ) 181 | 182 | run_worker(args) 183 | 184 | 185 | if __name__ == "__main__": 186 | main() 187 | 188 | 189 | class TestGrad(unittest.TestCase): 190 | def test_grad(self): 191 | import random 192 | 193 | port = random.randint(29500, 30000) 194 | args = [ 195 | "--master_port", 196 | str(port), 197 | ] 198 | main(args) 199 | -------------------------------------------------------------------------------- /test/test_interleave.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | from pippy import ( 10 | annotate_split_points, 11 | pipeline, 12 | PipelineStage, 13 | ScheduleInterleaved1F1B, 14 | ScheduleLoopedBFS, 15 | SplitPoint, 16 | ) 17 | 18 | # Using same key words as single-stage tests for convenience in CI. 
19 | schedule_map = { 20 | "gpipe": ScheduleLoopedBFS, # BFS is a generalization of gpipe 21 | "1f1b": ScheduleInterleaved1F1B, 22 | } 23 | 24 | d_hid = 16 25 | n_layers = 8 26 | batch_size = 16 27 | 28 | torch.manual_seed(0) 29 | 30 | 31 | class MLPModule(torch.nn.Module): 32 | def __init__(self, d_hid): 33 | super(MLPModule, self).__init__() 34 | self.net1 = torch.nn.Linear(d_hid, d_hid) 35 | self.relu = torch.nn.ReLU() 36 | self.net2 = torch.nn.Linear(d_hid, d_hid) 37 | 38 | def forward(self, x): 39 | x = self.net1(x) 40 | x = self.relu(x) 41 | x = self.net2(x) 42 | return x 43 | 44 | 45 | class TransformerLike(torch.nn.Module): 46 | def __init__(self) -> None: 47 | super().__init__() 48 | self.layers = torch.nn.Sequential( 49 | *[MLPModule(d_hid) for _ in range(n_layers)] 50 | ) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | return self.layers(x) 54 | 55 | 56 | def run_worker(args): 57 | model = TransformerLike().to(args.device) 58 | x = torch.randn(batch_size, d_hid, device=args.device) 59 | target = torch.randn(batch_size, d_hid, device=args.device) 60 | loss_fn = torch.nn.MSELoss(reduction="sum") 61 | 62 | # Two stages per rank 63 | num_stages = 2 * args.world_size 64 | 65 | # Split model into stages 66 | layers_per_stage = n_layers // num_stages 67 | for stage_idx in range(1, num_stages): 68 | annotate_split_points( 69 | model, 70 | {f"layers.{layers_per_stage * stage_idx}": SplitPoint.BEGINNING}, 71 | ) 72 | 73 | pipe = pipeline( 74 | model, 75 | args.chunks, 76 | (x,), 77 | ) 78 | assert pipe.num_stages == num_stages, f"{pipe.num_stages} != {num_stages}" 79 | 80 | # Collect my stages 81 | stages = [] 82 | for stage_idx in range(pipe.num_stages): 83 | if stage_idx % args.world_size == args.rank: 84 | stage = PipelineStage(pipe, stage_idx, device=args.device) 85 | stages.append(stage) 86 | 87 | # Attach to an interleaving schedule 88 | ScheduleClass = schedule_map[args.schedule] 89 | schedule = ScheduleClass(stages, args.chunks, loss_fn=loss_fn) 90 | 91 | # Run 92 | if args.rank == 0: 93 | schedule.step(x) 94 | elif args.rank == args.world_size - 1: 95 | losses = [] 96 | out = schedule.step(target=target, losses=losses) 97 | else: 98 | schedule.step() 99 | 100 | dist.barrier() 101 | print(f"Rank {args.rank} completes") 102 | 103 | # Last rank checks result 104 | if args.rank == args.world_size - 1: 105 | ref_out = model(x) 106 | ref_loss = loss_fn(ref_out, target) 107 | pipe_loss = sum(losses) 108 | torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-4) 109 | torch.testing.assert_close(pipe_loss, ref_loss) 110 | print( 111 | f"equivalence test passed pipe_loss={pipe_loss} ref_loss={ref_loss}" 112 | ) 113 | 114 | 115 | def main(args=None): 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument( 118 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 119 | ) 120 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 121 | parser.add_argument( 122 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 123 | ) 124 | parser.add_argument( 125 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 126 | ) 127 | parser.add_argument( 128 | "--cuda", type=int, default=int(torch.cuda.is_available()) 129 | ) 130 | parser.add_argument( 131 | "--chunks", 132 | type=int, 133 | default=4, 134 | ) 135 | parser.add_argument( 136 | "--schedule", 137 | type=str, 138 | default="1f1b", 139 | choices=schedule_map.keys(), 140 | ) 141 | args = parser.parse_args(args) 142 | 143 | if args.cuda: 
144 | dev_id = args.rank % torch.cuda.device_count() 145 | args.device = torch.device(f"cuda:{dev_id}") 146 | else: 147 | args.device = torch.device("cpu") 148 | 149 | # Init process group 150 | backend = "nccl" if args.cuda else "gloo" 151 | dist.init_process_group( 152 | backend=backend, 153 | rank=args.rank, 154 | world_size=args.world_size, 155 | ) 156 | 157 | run_worker(args) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | 163 | 164 | class TestInterleave(unittest.TestCase): 165 | def test_interleave(self): 166 | import random 167 | 168 | port = random.randint(29500, 30000) 169 | args = [ 170 | "--master_port", 171 | str(port), 172 | ] 173 | main(args) 174 | -------------------------------------------------------------------------------- /test/test_microbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import unittest 3 | 4 | import torch 5 | 6 | from pippy.microbatch import ( 7 | merge_chunks, 8 | split_args_kwargs_into_chunks, 9 | TensorChunkSpec, 10 | ) 11 | 12 | 13 | d_hid = 512 14 | 15 | 16 | def main(): 17 | x0 = torch.randn(128, d_hid) 18 | x1 = torch.randn(256, d_hid) 19 | x2 = torch.randn(512, d_hid) 20 | 21 | args = (x0, x1, x2) 22 | kwargs = {"x0": x0, "x1": x1, "x2": x2} 23 | 24 | # Default chunking: dim 0 25 | arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2) 26 | assert len(arg_chunks) == 2 27 | assert len(kwarg_chunks) == 2 28 | assert arg_chunks[0][0].shape == torch.Size([64, d_hid]) 29 | assert arg_chunks[1][0].shape == torch.Size([64, d_hid]) 30 | assert arg_chunks[0][1].shape == torch.Size([128, d_hid]) 31 | assert arg_chunks[0][2].shape == torch.Size([256, d_hid]) 32 | assert kwarg_chunks[0]["x0"].shape == torch.Size([64, d_hid]) 33 | assert kwarg_chunks[0]["x1"].shape == torch.Size([128, d_hid]) 34 | assert kwarg_chunks[1]["x2"].shape == torch.Size([256, d_hid]) 35 | 36 | # Merge chunks back together 37 | merged_args = merge_chunks( 38 | arg_chunks, 39 | (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)), 40 | ) 41 | torch.testing.assert_close(merged_args, args) 42 | 43 | merged_kwargs = merge_chunks( 44 | kwarg_chunks, 45 | { 46 | "x0": TensorChunkSpec(0), 47 | "x1": TensorChunkSpec(0), 48 | "x2": TensorChunkSpec(0), 49 | }, 50 | ) 51 | torch.testing.assert_close(merged_kwargs, kwargs) 52 | 53 | print("Microbatch test passed") 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | 59 | 60 | class TestMicrobatch(unittest.TestCase): 61 | def test_microbatch(self): 62 | main() 63 | -------------------------------------------------------------------------------- /test/test_optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | # Run this test with: 3 | # torchrun --nproc-per-node 4 test/test_optim.py 4 | 5 | import argparse 6 | import os 7 | import unittest 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.optim as optim 12 | 13 | from pippy import ( 14 | pipe_split, 15 | pipeline, 16 | PipelineStage, 17 | Schedule1F1B, 18 | ScheduleGPipe, 19 | ) 20 | 21 | 22 | schedule_map = { 23 | "gpipe": ScheduleGPipe, 24 | "1f1b": Schedule1F1B, 25 | } 26 | 27 | d_hid = 512 28 | batch_size = 256 29 | 30 | torch.manual_seed(0) 31 | 32 | 33 | # Basic example 34 | class ExampleCode(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 38 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 39 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 40 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 41 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 42 | 43 | def forward(self, x): 44 | x = torch.mm(x, self.mm_param0) 45 | x = torch.relu(x) 46 | pipe_split() 47 | x = torch.mm(x, self.mm_param1) 48 | x = self.lin1(x) 49 | pipe_split() 50 | x = torch.relu(x) 51 | x = torch.mm(x, self.mm_param2) 52 | pipe_split() 53 | x = self.lin2(x) 54 | x = torch.relu(x) 55 | return x 56 | 57 | 58 | def run_worker(args): 59 | mod = ExampleCode() 60 | mod.to(args.device) 61 | 62 | x = torch.randn(batch_size, d_hid, device=args.device) 63 | target = torch.randn(batch_size, d_hid, device=args.device) 64 | loss_fn = torch.nn.MSELoss(reduction="sum") 65 | 66 | pipe = pipeline( 67 | mod, 68 | args.chunks, 69 | example_args=(x,), 70 | ) 71 | 72 | stage = PipelineStage( 73 | pipe, 74 | args.rank, 75 | device=args.device, 76 | ) 77 | 78 | # Attach to a schedule 79 | ScheduleClass = schedule_map[args.schedule] 80 | schedule = ScheduleClass(stage, args.chunks, loss_fn=loss_fn) 81 | 82 | # Create an optimizer for stage submodule's parameters 83 | optimizer = optim.SGD(stage.submod.parameters(), lr=1e-3, momentum=0.9) 84 | 85 | for _ in range(2): 86 | # Zero gradients 87 | optimizer.zero_grad() 88 | 89 | # Run 90 | if args.rank == 0: 91 | schedule.step(x) 92 | elif args.rank == args.world_size - 1: 93 | losses = [] 94 | out = schedule.step(target=target, losses=losses) 95 | else: 96 | schedule.step() 97 | 98 | # Take an optimization step 99 | optimizer.step() 100 | 101 | dist.barrier() 102 | print(f"Rank {args.rank} completes") 103 | 104 | 105 | def main(args=None): 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument( 108 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 109 | ) 110 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 111 | parser.add_argument( 112 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 113 | ) 114 | parser.add_argument( 115 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 116 | ) 117 | parser.add_argument( 118 | "--cuda", type=int, default=int(torch.cuda.is_available()) 119 | ) 120 | parser.add_argument( 121 | "--chunks", 122 | type=int, 123 | default=4, 124 | ) 125 | parser.add_argument( 126 | "--schedule", 127 | type=str, 128 | default="gpipe", 129 | choices=schedule_map.keys(), 130 | ) 131 | args = parser.parse_args(args) 132 | 133 | if args.cuda: 134 | dev_id = args.rank % torch.cuda.device_count() 135 | args.device = torch.device(f"cuda:{dev_id}") 136 | else: 137 | args.device = torch.device("cpu") 138 | 139 | # Init process group 140 | backend = "nccl" if args.cuda else "gloo" 
141 | dist.init_process_group( 142 | backend=backend, 143 | rank=args.rank, 144 | world_size=args.world_size, 145 | ) 146 | 147 | run_worker(args) 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | 153 | 154 | class TestOptimTest(unittest.TestCase): 155 | def test_optim(self): 156 | import random 157 | 158 | port = random.randint(29500, 30000) 159 | args = [ 160 | "--master_port", 161 | str(port), 162 | ] 163 | main(args) 164 | -------------------------------------------------------------------------------- /test/test_pipe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import unittest 4 | 5 | import torch 6 | 7 | from pippy import pipe_split, pipeline 8 | 9 | 10 | d_hid = 512 11 | batch_size = 256 12 | 13 | torch.manual_seed(0) 14 | 15 | 16 | # Basic example 17 | class ExampleCode(torch.nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 21 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 22 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 23 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 24 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 25 | 26 | def forward(self, x, y): 27 | x = torch.mm(x, self.mm_param0) 28 | skip_connection = x 29 | x = x + y 30 | x = torch.relu(x) 31 | pipe_split() 32 | x = torch.mm(x, self.mm_param1) 33 | x = self.lin1(x) 34 | pipe_split() 35 | x = torch.relu(x) 36 | x = x + skip_connection 37 | x = torch.mm(x, self.mm_param2) 38 | pipe_split() 39 | x = self.lin2(x) 40 | x = torch.relu(x) 41 | return x 42 | 43 | 44 | # MLP example 45 | class MLPModule(torch.nn.Module): 46 | def __init__(self, d_hid): 47 | super(MLPModule, self).__init__() 48 | self.net1 = torch.nn.Linear(d_hid, d_hid) 49 | self.relu = torch.nn.ReLU() 50 | self.net2 = torch.nn.Linear(d_hid, d_hid) 51 | 52 | def forward(self, x): 53 | x = self.net1(x) 54 | x = self.relu(x) 55 | x = self.net2(x) 56 | return x 57 | 58 | 59 | class MultiMLP(torch.nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | self.mlp0 = MLPModule(d_hid) 63 | self.mlp1 = MLPModule(d_hid) 64 | self.mlp2 = MLPModule(d_hid) 65 | self.mlp3 = MLPModule(d_hid) 66 | 67 | def forward(self, x, y): 68 | x = self.mlp0(x) 69 | pipe_split() 70 | x = self.mlp1(x) 71 | pipe_split() 72 | x = self.mlp2(x) 73 | pipe_split() 74 | x = self.mlp3(x) 75 | return x - y 76 | 77 | 78 | def run_worker(args, model_class): 79 | mod = model_class() 80 | x = torch.randn(batch_size, d_hid) 81 | y = torch.randn(batch_size, d_hid) 82 | 83 | pipe = pipeline( 84 | mod, 85 | args.chunks, 86 | example_args=(x, y), 87 | ) 88 | 89 | assert pipe.num_stages == 4 90 | 91 | ref_out = mod(x, y) 92 | out = pipe(x, y)[0] 93 | torch.testing.assert_close(out, ref_out) 94 | print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}") 95 | 96 | # Check qualname 97 | # state_dict.keys include both parameters and persistent buffers 98 | old_names = set(mod.state_dict().keys()) 99 | new_names = set() 100 | for idx in range(pipe.num_stages): 101 | stage_mod = pipe.get_stage_module(idx) 102 | new_names.update(stage_mod.state_dict().keys()) 103 | 104 | assert ( 105 | old_names == new_names 106 | ), f""" 107 | old names {old_names} 108 | new names {new_names} 109 | """ 110 | print("Qualname check passed") 111 | 112 | 113 | def main(args=None): 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument( 116 | "--chunks", 117 
| type=int, 118 | default=4, 119 | ) 120 | args = parser.parse_args(args) 121 | 122 | for model_class in [ExampleCode, MultiMLP]: 123 | print("Testing ", model_class.__name__) 124 | run_worker(args, model_class) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | 130 | 131 | class TestPipe(unittest.TestCase): 132 | def test_pipe(self): 133 | main() 134 | -------------------------------------------------------------------------------- /test/test_pipe_bwd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import unittest 4 | 5 | import torch 6 | from pippy import pipe_split, pipeline 7 | 8 | 9 | d_hid = 512 10 | batch_size = 256 11 | 12 | torch.manual_seed(0) 13 | 14 | 15 | # Basic example 16 | class ExampleCode(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 20 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 21 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 22 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 23 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 24 | self.mse_loss = torch.nn.MSELoss(reduction="sum") 25 | 26 | def forward(self, x, y): 27 | x = torch.mm(x, self.mm_param0) 28 | skip_connection = x 29 | x = torch.relu(x) 30 | pipe_split() 31 | x = torch.mm(x, self.mm_param1) 32 | x = self.lin1(x) 33 | pipe_split() 34 | x = torch.relu(x) 35 | x = x + skip_connection 36 | x = torch.mm(x, self.mm_param2) 37 | pipe_split() 38 | x = self.lin2(x) 39 | logits = torch.relu(x) 40 | loss = self.mse_loss(x, y) 41 | return logits, loss 42 | 43 | 44 | # MLP example 45 | class MLPModule(torch.nn.Module): 46 | def __init__(self, d_hid): 47 | super(MLPModule, self).__init__() 48 | self.net1 = torch.nn.Linear(d_hid, d_hid) 49 | self.relu = torch.nn.ReLU() 50 | self.net2 = torch.nn.Linear(d_hid, d_hid) 51 | 52 | def forward(self, x): 53 | x = self.net1(x) 54 | x = self.relu(x) 55 | x = self.net2(x) 56 | return x 57 | 58 | 59 | class MultiMLP(torch.nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | self.mlp0 = MLPModule(d_hid) 63 | self.mlp1 = MLPModule(d_hid) 64 | self.mlp2 = MLPModule(d_hid) 65 | self.mlp3 = MLPModule(d_hid) 66 | self.mse_loss = torch.nn.MSELoss(reduction="sum") 67 | 68 | def forward(self, x, y): 69 | x = self.mlp0(x) 70 | pipe_split() 71 | x = self.mlp1(x) 72 | pipe_split() 73 | x = self.mlp2(x) 74 | pipe_split() 75 | x = self.mlp3(x) 76 | loss = self.mse_loss(x, y) 77 | return x, loss 78 | 79 | 80 | def run_worker(args, model_class): 81 | mod = model_class() 82 | x = torch.randn(batch_size, d_hid) 83 | y = torch.randn(batch_size, d_hid) 84 | 85 | pipe = pipeline( 86 | mod, 87 | args.chunks, 88 | example_args=(x, y), 89 | ) 90 | 91 | ref_out = mod(x, y) 92 | out = pipe(x, y) 93 | torch.testing.assert_close(out, ref_out) 94 | print(f"equivalence test passed loss={out[1]} ref_loss={ref_out[1]}") 95 | 96 | 97 | def main(args=None): 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument( 100 | "--chunks", 101 | type=int, 102 | default=4, 103 | ) 104 | args = parser.parse_args(args) 105 | 106 | for model_class in [ExampleCode, MultiMLP]: 107 | print("Testing ", model_class.__name__) 108 | run_worker(args, model_class) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | 114 | 115 | class TestPipeBwd(unittest.TestCase): 116 | def test_pipe_bwd(self): 117 | main() 118 | 
-------------------------------------------------------------------------------- /test/test_pipeline_stage.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | import unittest 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | 9 | from pippy import ( 10 | annotate_split_points, 11 | ManualPipelineStage, 12 | pipeline, 13 | PipelineStage, 14 | ScheduleGPipe, 15 | SplitPoint, 16 | ) 17 | 18 | # torch.testing._internal.common_distributed requires "expecttest" 19 | from torch.testing._internal.common_distributed import MultiProcessTestCase 20 | from torch.testing._internal.common_utils import ( 21 | FILE_SCHEMA, 22 | instantiate_parametrized_tests, 23 | parametrize, 24 | ) 25 | 26 | # Example models and helper utils 27 | ########################## 28 | 29 | 30 | class MLP(nn.Module): 31 | def __init__( 32 | self, 33 | dim: int, 34 | hidden_dim: int, 35 | out_dim: int, 36 | ): 37 | super().__init__() 38 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 39 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False) 40 | self.relu = nn.ReLU() 41 | 42 | def forward(self, x): 43 | x = self.w1(x) 44 | x = self.w2(x) 45 | x = self.relu(x) 46 | return x 47 | 48 | 49 | # Tests defined below 50 | ########################## 51 | 52 | 53 | # python -m unittest test_pipeline_stage.TestPipelineStage. 54 | # or 55 | # pytest test_pipeline_stage.py -vsk 56 | class TestPipelineStage(MultiProcessTestCase): 57 | @property 58 | def world_size(self) -> int: 59 | # covers first_stage, middle_stage, last_stage cases 60 | return 2 61 | 62 | @property 63 | def init_method(self) -> str: 64 | return f"{FILE_SCHEMA}{self.file_name}" 65 | 66 | def setUp(self): 67 | super().setUp() 68 | # starts world_size processes 69 | self._spawn_processes() 70 | 71 | def init_distributed(self, use_cuda): 72 | if use_cuda: 73 | torch.cuda.set_device(self.rank) 74 | dist.init_process_group( 75 | init_method=self.init_method, 76 | backend="nccl", 77 | rank=self.rank, 78 | world_size=self.world_size, 79 | ) 80 | else: 81 | dist.init_process_group( 82 | init_method=self.init_method, 83 | backend="gloo", 84 | rank=self.rank, 85 | world_size=self.world_size, 86 | ) 87 | 88 | @parametrize("pipeline_stage_type", ["manual", "tracing"]) 89 | @parametrize("use_cuda", [True, False]) 90 | def test_pipeline_stage(self, pipeline_stage_type, use_cuda): 91 | device = ( 92 | torch.device(f"cuda:{self.rank}") 93 | if use_cuda 94 | else torch.device("cpu") 95 | ) 96 | self.init_distributed(use_cuda=use_cuda) 97 | 98 | in_dim = hidden_dim = out_dim = 10 99 | model = MLP(dim=in_dim, hidden_dim=hidden_dim, out_dim=out_dim).to( 100 | device 101 | ) 102 | batch_size = 32 103 | example_input = torch.randn(batch_size, in_dim, device=device) 104 | chunks = 2 105 | 106 | if pipeline_stage_type == "tracing": 107 | annotate_split_points( 108 | model, 109 | { 110 | "w1": SplitPoint.END, 111 | }, 112 | ) 113 | pipe = pipeline(model, chunks, example_args=(example_input,)) 114 | stage = PipelineStage(pipe, self.rank, device) 115 | elif pipeline_stage_type == "manual": 116 | stage = ManualPipelineStage( 117 | model, 118 | self.rank, 119 | self.world_size, 120 | device, 121 | chunks, 122 | input_args=example_input.chunk(chunks)[0], 123 | ) 124 | else: 125 | raise ValueError( 126 | f"Unknown pipeline stage type {pipeline_stage_type}" 127 | ) 128 | 129 | # Define a loss function 130 | loss_fn = torch.nn.MSELoss(reduction="sum") 
131 | 132 | # Attach to a schedule 133 | schedule = ScheduleGPipe(stage, chunks, loss_fn=loss_fn) 134 | 135 | # Input data 136 | x = torch.randn(batch_size, in_dim, device=device) 137 | target = torch.randn(batch_size, out_dim, device=device) 138 | 139 | # Run the pipeline with input `x`. Divide the batch into 4 micro-batches 140 | # and run them in parallel on the pipeline 141 | if self.rank == 0: 142 | schedule.step(x) 143 | elif self.rank == self.world_size - 1: 144 | losses = [] 145 | output = schedule.step(target=target, losses=losses) 146 | else: 147 | schedule.step() 148 | 149 | 150 | instantiate_parametrized_tests(TestPipelineStage) 151 | 152 | if __name__ == "__main__": 153 | unittest.main() 154 | -------------------------------------------------------------------------------- /test/test_skip_conn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import argparse 3 | import os 4 | import unittest 5 | 6 | import pippy 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from pippy import pipe_split, pipeline, PipelineStage, ScheduleGPipe 11 | 12 | 13 | pippy.microbatch._debug_mask_minibatches = True 14 | 15 | d_hid = 512 16 | batch_size = 256 17 | 18 | torch.manual_seed(0) 19 | 20 | 21 | class ExampleCode(torch.nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 25 | self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 26 | self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 27 | self.lin1 = torch.nn.Linear(d_hid, d_hid) 28 | self.lin2 = torch.nn.Linear(d_hid, d_hid) 29 | 30 | def forward(self, x, y=torch.zeros(batch_size, d_hid)): 31 | x = torch.mm(x, self.mm_param0) 32 | x = x + y 33 | x = torch.relu(x) 34 | skip_conn = x 35 | pipe_split() 36 | x = torch.mm(x, self.mm_param1) 37 | x = self.lin1(x) 38 | pipe_split() 39 | x = torch.relu(x) 40 | x = torch.mm(x, self.mm_param2) 41 | pipe_split() 42 | x = x + skip_conn 43 | x = self.lin2(x) 44 | x = torch.relu(x) 45 | return x 46 | 47 | 48 | def run_worker(args): 49 | mod = ExampleCode() 50 | mod.to(args.device) 51 | 52 | x = torch.randn(batch_size, d_hid, device=args.device) 53 | y = torch.randn(batch_size, d_hid, device=args.device) 54 | 55 | pipe = pipeline( 56 | mod, 57 | args.chunks, 58 | example_args=(x,), 59 | example_kwargs={"y": y}, 60 | ) 61 | 62 | stage = PipelineStage( 63 | pipe, 64 | args.rank, 65 | device=args.device, 66 | ) 67 | 68 | # Attach to a schedule 69 | schedule = ScheduleGPipe(stage, args.chunks) 70 | 71 | # Run 72 | if args.rank == 0: 73 | schedule.step(x, y=y) 74 | else: 75 | out = schedule.step() 76 | 77 | dist.barrier() 78 | print(f"Rank {args.rank} completes") 79 | 80 | # Last rank checks result 81 | if args.rank == args.world_size - 1: 82 | ref_out = mod(x, y=y) 83 | torch.testing.assert_close(out, ref_out) 84 | print( 85 | f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" 86 | ) 87 | 88 | 89 | def main(args=None): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) 93 | ) 94 | parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) 95 | parser.add_argument( 96 | "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") 97 | ) 98 | parser.add_argument( 99 | "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") 100 | ) 101 | parser.add_argument( 102 | 
"--cuda", type=int, default=int(torch.cuda.is_available()) 103 | ) 104 | parser.add_argument( 105 | "--chunks", 106 | type=int, 107 | default=4, 108 | ) 109 | args = parser.parse_args(args) 110 | 111 | if args.cuda: 112 | dev_id = args.rank % torch.cuda.device_count() 113 | args.device = torch.device(f"cuda:{dev_id}") 114 | else: 115 | args.device = torch.device("cpu") 116 | 117 | # Init process group 118 | backend = "nccl" if args.cuda else "gloo" 119 | dist.init_process_group( 120 | backend=backend, 121 | rank=args.rank, 122 | world_size=args.world_size, 123 | ) 124 | 125 | run_worker(args) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | 131 | 132 | class TestSkipConn(unittest.TestCase): 133 | def test_skip_conn(self): 134 | import random 135 | 136 | port = random.randint(29500, 30000) 137 | args = [ 138 | "--master_port", 139 | str(port), 140 | ] 141 | main(args) 142 | -------------------------------------------------------------------------------- /test/test_stage_backward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | import copy 3 | import unittest 4 | 5 | import torch 6 | 7 | from pippy._backward import stage_backward 8 | 9 | 10 | d_hid = 512 11 | batch_size = 256 12 | 13 | 14 | # MLP as a stage module 15 | class MLPModule(torch.nn.Module): 16 | def __init__(self, d_hid): 17 | super(MLPModule, self).__init__() 18 | self.net1 = torch.nn.Linear(d_hid, d_hid) 19 | self.relu = torch.nn.ReLU() 20 | self.net2 = torch.nn.Linear(d_hid, d_hid) 21 | 22 | def forward(self, x): 23 | x = self.net1(x) 24 | x = self.relu(x) 25 | x = self.net2(x) 26 | return x 27 | 28 | 29 | def main(args=None): 30 | mod = MLPModule(d_hid) 31 | x = torch.randn(batch_size, d_hid) 32 | # As in a pipeline stage, the inputs to this stage requires gradients 33 | x.requires_grad_(True) 34 | target = torch.randn(batch_size, d_hid) 35 | loss_fn = torch.nn.MSELoss(reduction="sum") 36 | 37 | # Make a copy 38 | ref_mod = copy.deepcopy(mod) 39 | ref_x = x.detach().requires_grad_(x.requires_grad) 40 | ref_target = target.detach() 41 | 42 | # Forward and backward in stage manner 43 | out = mod(x) 44 | loss = loss_fn(out, target) 45 | grad_inputs = stage_backward( 46 | stage_output=loss, 47 | output_grads=None, 48 | input_values=(x,), 49 | ) 50 | 51 | # Run reference 52 | ref_out = ref_mod(ref_x) 53 | ref_loss = loss_fn(ref_out, ref_target) 54 | ref_loss.backward() 55 | 56 | torch.testing.assert_close(grad_inputs[0], ref_x.grad) 57 | 58 | # Every rank checks gradients 59 | for name, p in mod.named_parameters(): 60 | ref_p = ref_mod.get_parameter(name) 61 | try: 62 | torch.testing.assert_close(p.grad, ref_p.grad) 63 | except AssertionError: 64 | print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") 65 | raise 66 | 67 | print(f"Gradient test passed") 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | 73 | 74 | class TestStageBackward(unittest.TestCase): 75 | def test_stage_backward(self): 76 | main() 77 | -------------------------------------------------------------------------------- /test/test_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | import torch 3 | from pippy import annotate_split_points, Pipe, SplitPoint 4 | 5 | 6 | d_hid = 16 7 | n_layers = 8 8 | batch_size = 4 9 | 10 | 11 | class MLPModule(torch.nn.Module): 12 | def __init__(self, d_hid): 13 | super(MLPModule, self).__init__() 14 | self.net1 = torch.nn.Linear(d_hid, d_hid) 15 | self.relu = torch.nn.ReLU() 16 | self.net2 = torch.nn.Linear(d_hid, d_hid) 17 | 18 | def forward(self, x): 19 | x = self.net1(x) 20 | x = self.relu(x) 21 | x = self.net2(x) 22 | return x 23 | 24 | 25 | class TransformerLike(torch.nn.Module): 26 | def __init__(self) -> None: 27 | super().__init__() 28 | self.layers = torch.nn.Sequential( 29 | *[MLPModule(d_hid) for _ in range(n_layers)] 30 | ) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | return self.layers(x) 34 | 35 | 36 | transformer = TransformerLike() 37 | print("Original model:\n", transformer) 38 | x = torch.randn(batch_size, d_hid) 39 | 40 | # Split into 2 stages 41 | annotate_split_points( 42 | transformer, {f"layers.{n_layers // 2}": SplitPoint.BEGINNING} 43 | ) 44 | 45 | pipe = Pipe.from_tracing( 46 | transformer, 47 | 1, 48 | (x,), 49 | ) 50 | assert pipe.num_stages == 2 51 | 52 | 53 | def get_layers(module): 54 | layers = [name for name, _ in module.layers.named_children()] 55 | return layers 56 | 57 | 58 | # Collect all layers in pipe 59 | layers = [] 60 | for stage_idx in range(pipe.num_stages): 61 | stage_mod = pipe.get_stage_module(stage_idx) 62 | print(f"\nStage {stage_idx}: \n", stage_mod) 63 | layers += get_layers(stage_mod) 64 | 65 | # Check layer completeness 66 | orig_layers = get_layers(transformer) 67 | assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}" 68 | print(f"Layers matched! ", layers) 69 | 70 | # Check equivalence 71 | ref = transformer(x) 72 | out = pipe(x)[0] 73 | torch.testing.assert_close(out, ref) 74 | print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") 75 | -------------------------------------------------------------------------------- /test/test_unflatten.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | import torch 3 | from pippy import Pipe, pipe_split 4 | 5 | 6 | # Building block for model 7 | class Block(torch.nn.Module): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | self.conv = torch.nn.Conv2d( 11 | in_channels=16, out_channels=16, kernel_size=3, padding=1 12 | ) 13 | self.lin0 = torch.nn.Linear(256, 256) 14 | self.relu = torch.nn.ReLU() 15 | self.lin1 = torch.nn.Linear(256, 256) 16 | 17 | def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor: 18 | x = self.conv(x) 19 | x = self.lin0(x) 20 | pipe_split() 21 | x.add_(constant) 22 | x = self.lin1(x) 23 | return self.relu(x) 24 | 25 | 26 | # Full model 27 | class M(torch.nn.Module): 28 | def __init__(self) -> None: 29 | super().__init__() 30 | self.block0 = Block() 31 | self.block1 = Block() 32 | 33 | def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor: 34 | x = self.block0(x, constant=constant) 35 | pipe_split() 36 | x = self.block1(x, constant=constant) 37 | return x 38 | 39 | 40 | x = torch.randn(1, 16, 256, 256) 41 | constant = torch.ones(1, 16, 256, 256) 42 | 43 | mod = M() 44 | print("Original model:\n", mod) 45 | 46 | pipe = Pipe.from_tracing( 47 | mod, 48 | 1, 49 | (x,), 50 | {"constant": constant}, 51 | ) 52 | 53 | assert pipe.num_stages == 4 54 | orig_state_dict = mod.state_dict() 55 | 56 | # Check qualnames 57 | print("\nParameters of each stage:") 58 | for stage_idx in range(pipe.num_stages): 59 | print(f"\nStage {stage_idx}:") 60 | stage_mod = pipe.get_stage_module(stage_idx) 61 | for param_name, param in stage_mod.named_parameters(): 62 | assert ( 63 | param_name in orig_state_dict 64 | ), f"{param_name} not in original state dict" 65 | print(f"{param_name}: {param.size()}") 66 | 67 | # Check equivalence 68 | ref = mod(x, constant) 69 | out = pipe(x, constant)[0] 70 | torch.testing.assert_close(out, ref) 71 | print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") 72 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.2.0 2 | --------------------------------------------------------------------------------
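The distributed tests above all follow the same three-step pattern: trace and split the model with pipeline(...) (marking boundaries with pipe_split() or annotate_split_points), wrap this rank's stage in a PipelineStage, and drive it with a schedule such as ScheduleGPipe under torchrun with one process per stage. The sketch below distills that pattern into a minimal standalone script. It is only an illustration derived from the API calls exercised in the tests above (pippy 0.2.0 per version.txt); it is not a file from this repository, and the model and variable names in it are invented for the example.

# Minimal sketch (not part of the repo): trace -> stage -> schedule -> step,
# using the same pippy calls as the tests above.
# Assumed launch: torchrun --nproc-per-node 2 this_script.py
import os

import torch
import torch.distributed as dist
from pippy import pipe_split, pipeline, PipelineStage, ScheduleGPipe

d_hid = 512
batch_size = 256
chunks = 4


class TwoStageMLP(torch.nn.Module):
    # Hypothetical model: one pipe_split() -> two pipeline stages
    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.relu(self.lin0(x))
        pipe_split()  # stage boundary, as in the tests
        return torch.relu(self.lin1(x))


def run():
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 2))
    device = torch.device("cpu")

    # torchrun supplies MASTER_ADDR/MASTER_PORT, RANK and WORLD_SIZE
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    mod = TwoStageMLP()
    x = torch.randn(batch_size, d_hid, device=device)

    pipe = pipeline(mod, chunks, example_args=(x,))   # trace and split
    stage = PipelineStage(pipe, rank, device=device)  # this rank's stage
    schedule = ScheduleGPipe(stage, chunks)           # attach a schedule

    if rank == 0:
        schedule.step(x)       # first stage feeds the whole batch
    else:
        out = schedule.step()  # last stage returns the merged output
        print(f"Rank {rank} output shape: {out.shape}")

    dist.barrier()


if __name__ == "__main__":
    run()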