├── .github └── workflows │ └── tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── 3d-parallelism.png ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs └── index.md ├── examples ├── README.md └── hybrid_parallelism.py ├── mkdocs.yml ├── pipegoose ├── __init__.py ├── constants.py ├── core │ └── bucket │ │ ├── bucket.py │ │ ├── dist.py │ │ ├── exception.py │ │ ├── manager.py │ │ └── utils.py ├── distributed │ ├── __init__.py │ ├── _initializers │ │ ├── __init__.py │ │ ├── initialize_data.py │ │ ├── initialize_expert.py │ │ ├── initialize_pipeline.py │ │ ├── initialize_tensor.py │ │ └── initializer.py │ ├── _p2p.py │ ├── functional.py │ ├── parallel_context.py │ └── parallel_mode.py ├── nn │ ├── __init__.py │ ├── data_parallel │ │ ├── __init__.py │ │ └── data_parallel.py │ ├── expert_parallel │ │ ├── __init__.py │ │ ├── expert_context.py │ │ ├── expert_parallel.py │ │ ├── experts.py │ │ ├── layers.py │ │ ├── loss.py │ │ ├── parallel_mapping.py │ │ ├── routers.py │ │ └── utils.py │ ├── parallel.py │ ├── parallel_mapping.py │ ├── pipeline_parallel │ │ ├── _comm.py │ │ ├── _job │ │ │ ├── backward.py │ │ │ ├── callback.py │ │ │ ├── creator.py │ │ │ ├── forward.py │ │ │ ├── job.py │ │ │ ├── job_type.py │ │ │ └── register.py │ │ ├── _package.py │ │ ├── _utils.py │ │ ├── _worker.py │ │ ├── exception.py │ │ ├── microbatch.py │ │ ├── partitioner.py │ │ ├── pipeline.py │ │ ├── pipeline_context.py │ │ ├── pipeline_engine.py │ │ ├── pipeline_parallel.py │ │ ├── queue.py │ │ ├── scheduler.py │ │ ├── sync │ │ │ ├── callback.py │ │ │ ├── handshake.py │ │ │ └── progress_tracker.py │ │ └── task.py │ ├── tensor_parallel │ │ ├── _functional.py │ │ ├── _utils.py │ │ ├── embedding.py │ │ ├── layer_norm.py │ │ ├── linear.py │ │ ├── loss.py │ │ ├── parallel_mapping.py │ │ ├── parallelizer.py │ │ └── tensor_parallel.py │ └── utils.py ├── optim │ ├── __init__.py │ ├── base_optim.py │ └── zero │ │ ├── __init__.py │ │ ├── optim.py │ │ ├── sharding.py │ │ └── utils.py ├── partitioning │ └── profile.py ├── testing │ └── utils.py ├── trainer │ ├── callback.py │ ├── logger.py │ ├── state.py │ └── trainer.py └── utils │ └── memory.py ├── poetry.lock ├── pyproject.toml └── tests ├── convergence ├── run_ep.py ├── run_hybrid_parallel.py ├── test_dp.py ├── test_pp.py └── test_tp.py ├── core └── bucket │ ├── test_bucket.py │ ├── test_bucket_distributor.py │ ├── test_bucket_manager.py │ └── test_bucket_utils.py ├── distributed ├── _initializers │ ├── test_initialize_data_parallel_group.py │ ├── test_initialize_expert_parallel_group.py │ ├── test_initialize_pipeline_parallel_group.py │ ├── test_initialize_tensor_parallel_group.py │ └── utils.py ├── conftest.py ├── test_functional.py ├── test_p2p.py ├── test_parallel_context.py ├── test_parallel_mode.py └── test_rpc.py ├── nn ├── __init__.py ├── data_parallel │ ├── __init__.py │ └── test_data_parallel.py ├── expert_parallel │ ├── test_expert_context.py │ ├── test_expert_loss.py │ ├── test_expert_parallel.py │ ├── test_expert_parallel_mapping.py │ ├── test_expert_utils.py │ ├── test_experts.py │ ├── test_hybrid_expert_parallel.py │ ├── test_layers.py │ └── test_routers.py ├── pipeline_parallel │ ├── conftest.py │ ├── job │ │ ├── test_backward.py │ │ ├── test_callback.py │ │ ├── test_creator.py │ │ ├── test_forward.py │ │ ├── test_hybrid_job.py │ │ ├── test_job.py │ │ └── test_register.py │ ├── sync │ │ ├── test_handshake.py │ │ └── test_progress_tracker.py │ ├── test_comm.py │ ├── test_microbatch.py │ ├── test_package.py │ ├── test_partitioner.py │ ├── 
test_pipeline_context.py │ ├── test_pipeline_engine.py │ ├── test_pipeline_parallel.py │ ├── test_pp_utils.py │ ├── test_queue.py │ ├── test_scheduler.py │ └── test_worker.py ├── tensor_parallel │ ├── conftest.py │ ├── test_embedding.py │ ├── test_functional_.py │ ├── test_layer_norm.py │ ├── test_linear.py │ ├── test_loss.py │ ├── test_parallel_mapping.py │ ├── test_parallelizer.py │ ├── test_tensor_parallel.py │ └── test_utils.py └── test_utils.py ├── optim └── zero │ ├── test_optim.py │ ├── test_optim_utils.py │ └── test_sharding.py ├── partitioning └── test_profile.py └── test_hybrid.py /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | tests: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up Python 3.9 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.9 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -e . 25 | pip3 install torch torchvision torchaudio 26 | pip install pytest 27 | pip install pytest-cov 28 | 29 | - name: Print PyTorch and CUDA info 30 | run: | 31 | python -c "import torch; print(torch.__version__)" 32 | python -c "import torch; print(torch.cuda.is_available())" 33 | 34 | - name: Run tests 35 | run: pytest --color=yes --durations=0 --cov=pipegoose --verbose tests/ 36 | 37 | - name: Upload coverage reports to Codecov 38 | uses: codecov/codecov-action@v3 39 | env: 40 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .vscode/* 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/pyupgrade 3 | rev: v2.31.1 4 | hooks: 5 | - id: pyupgrade 6 | args: 7 | - --py37-plus 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | args: 13 | - --profile=black 14 | - --skip-glob=wandb/**/* 15 | - --thirdparty=wandb 16 | - repo: https://github.com/myint/autoflake 17 | rev: v1.4 18 | hooks: 19 | - id: autoflake 20 | args: 21 | - -r 22 | - --exclude=wandb 23 | - --in-place 24 | - --remove-unused-variables 25 | - --remove-all-unused-imports 26 | - repo: https://github.com/python/black 27 | rev: 22.3.0 28 | hooks: 29 | - id: black 30 | args: 31 | - --line-length=127 32 | - --exclude=wandb 33 | - repo: https://github.com/codespell-project/codespell 34 | rev: v2.1.0 35 | hooks: 36 | - id: codespell 37 | args: 38 | - --ignore-words-list=nd,reacher,thist,ths,magent,ba 39 | - --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb 40 | - repo: https://github.com/python-poetry/poetry 41 | rev: 1.3.2 42 | hooks: 43 | - id: poetry-export 44 | name: poetry-export requirements.txt 45 | args: ["--without-hashes", "-o", "requirements/requirements.txt"] 46 | stages: [manual] 47 | - id: poetry-export 48 | name: poetry-export requirements-atari.txt 49 | args: ["--without-hashes", "-o", "requirements/requirements-atari.txt", "-E", "atari"] 50 | stages: [manual] 51 | - id: poetry-export 52 | name: poetry-export requirements-mujoco.txt 53 | args: ["--without-hashes", "-o", "requirements/requirements-mujoco.txt", "-E", "mujoco"] 54 | stages: [manual] 55 | - id: poetry-export 56 | name: poetry-export requirements-dm_control.txt 57 | args: ["--without-hashes", "-o", "requirements/requirements-dm_control.txt", "-E", "dm_control"] 58 | stages: [manual] 59 | - id: poetry-export 60 | name: poetry-export requirements-mujoco_py.txt 61 | args: ["--without-hashes", "-o", "requirements/requirements-mujoco_py.txt", "-E", "mujoco_py"] 62 | stages: [manual] 63 | - id: poetry-export 64 | name: poetry-export requirements-procgen.txt 65 | args: ["--without-hashes", "-o", "requirements/requirements-procgen.txt", "-E", "procgen"] 66 | stages: [manual] 67 | - id: poetry-export 68 | name: poetry-export requirements-envpool.txt 69 | args: ["--without-hashes", "-o", "requirements/requirements-envpool.txt", "-E", "envpool"] 70 | stages: [manual] 71 | - id: poetry-export 72 | name: poetry-export requirements-pettingzoo.txt 73 | args: ["--without-hashes", "-o", "requirements/requirements-pettingzoo.txt", "-E", "pettingzoo"] 74 | stages: [manual] 75 | - id: poetry-export 76 | name: poetry-export requirements-jax.txt 77 | args: ["--without-hashes", "-o", "requirements/requirements-jax.txt", "-E", "jax"] 78 | stages: [manual] 79 | - id: poetry-export 80 | name: poetry-export requirements-optuna.txt 81 | args: ["--without-hashes", "-o", "requirements/requirements-optuna.txt", "-E", "optuna"] 82 | stages: [manual] 83 | - id: poetry-export 84 | name: poetry-export requirements-docs.txt 85 | args: ["--without-hashes", "-o", "requirements/requirements-docs.txt", "-E", "docs"] 86 | stages: [manual] 87 | - id: poetry-export 88 | name: poetry-export requirements-cloud.txt 89 | args: ["--without-hashes", "-o", "requirements/requirements-cloud.txt", "-E", "cloud"] 90 | 
stages: [manual] 91 | -------------------------------------------------------------------------------- /3d-parallelism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/3d-parallelism.png -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | We're building an end-to-end multi-modal MoE that works with 3D parallelism, and pre-training it in a decentralized way as proposed in the paper [DiLoCo](https://arxiv.org/abs/2311.08105). 2 | 3 | If you want to contribute, please check the following links: 4 | 5 | - High priority tasks [[link]](https://github.com/xrsrke/pipegoose/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22+label%3A%22High+Priority%22) 6 | - Beginner tasks [[link]](https://github.com/xrsrke/pipegoose/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22+label%3A%22good+first+issue%22) 7 | - All tasks that need help (including beginner and high priority tasks) [[link]](https://github.com/xrsrke/pipegoose/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 XλRI-U5 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to MkDocs 2 | 3 | For full documentation visit [mkdocs.org](https://www.mkdocs.org). 4 | 5 | ## Commands 6 | 7 | * `mkdocs new [dir-name]` - Create a new project. 8 | * `mkdocs serve` - Start the live-reloading docs server. 9 | * `mkdocs build` - Build the documentation site. 10 | * `mkdocs -h` - Print help message and exit. 11 | 12 | ## Project layout 13 | 14 | mkdocs.yml # The configuration file. 15 | docs/ 16 | index.md # The documentation homepage. 17 | ... # Other markdown pages, images and other files.
18 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ### Hybrid tensor parallelism and data parallelism training 2 | 3 | Hybrid 3D parallelism for 🤗 `transformers` will be available in the upcoming weeks (it's basically done, but it doesn't support 🤗 `transformers` yet) 4 | 5 | **You must have at least 4 GPUs to run 2D parallelism.**. `nproc-per-node` is equal to `tensor_parallel_size` * `pipeline_parallel_size` * `data_parallel_size`. 6 | 7 | ```bash 8 | torchrun --standalone --nnodes=1 --nproc-per-node=4 hybrid_parallelism.py 9 | ``` 10 | -------------------------------------------------------------------------------- /examples/hybrid_parallelism.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torch.optim import SGD 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.distributed import DistributedSampler 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from pipegoose.distributed import ParallelContext, ParallelMode 8 | from pipegoose.nn import DataParallel, TensorParallel 9 | 10 | if __name__ == "__main__": 11 | DATA_PARALLEL_SIZE = 2 12 | TENSOR_PARALLEL_SIZE = 2 13 | PIPELINE_PARALLEL_SIZE = 1 14 | BATCH_SIZE = 4 15 | 16 | parallel_context = ParallelContext.from_torch( 17 | data_parallel_size=DATA_PARALLEL_SIZE, 18 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 19 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 20 | ) 21 | rank = parallel_context.get_global_rank() 22 | 23 | dataset = load_dataset("imdb", split="train[:100]") 24 | dataset = dataset.map(lambda x: {"text": x["text"][:30]}) # for demonstration purposes, you can remove this line 25 | 26 | dp_rank = parallel_context.get_local_rank(ParallelMode.DATA) 27 | sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=69) 28 | dataloader = DataLoader(dataset, batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, shuffle=False, sampler=sampler) 29 | 30 | model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") 31 | tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") 32 | tokenizer.pad_token = tokenizer.eos_token 33 | 34 | model = TensorParallel(model, parallel_context).parallelize() 35 | model = DataParallel(model, parallel_context).parallelize() 36 | optim = SGD(model.parameters(), lr=1e-3) 37 | model.to("cuda") 38 | device = next(model.parameters()).device 39 | 40 | print(f"rank={rank}, moved to device: {device}") 41 | 42 | for epoch in range(100): 43 | sampler.set_epoch(epoch) 44 | 45 | for batch in dataloader: 46 | inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=1024, return_tensors="pt") 47 | inputs = {name: tensor.to(device) for name, tensor in inputs.items()} 48 | labels = inputs["input_ids"] 49 | 50 | outputs = model(**inputs, labels=labels) 51 | 52 | optim.zero_grad() 53 | outputs.loss.backward() 54 | optim.step() 55 | 56 | print(f"rank={rank}, loss={outputs.loss}") 57 | 58 | model.cpu() 59 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: PipeGoose - Megatron-LM 3D parallelism for 🤗 `transformers` model 2 | repo_url: https://github.com/xrsrke/pipegoose 3 | nav: 4 | - Home: index.md 5 | theme: 6 | name: material 7 | features: 8 | # - 
navigation.instant 9 | - navigation.tracking 10 | # - navigation.tabs 11 | # - navigation.tabs.sticky 12 | - navigation.sections 13 | - navigation.expand 14 | - navigation.top 15 | - search.suggest 16 | - search.highlight 17 | palette: 18 | - media: "(prefers-color-scheme: dark)" 19 | scheme: slate 20 | primary: teal 21 | accent: light black 22 | toggle: 23 | icon: material/lightbulb 24 | name: Switch to light mode 25 | - media: "(prefers-color-scheme: light)" 26 | scheme: default 27 | primary: teal 28 | accent: deep orange 29 | toggle: 30 | icon: material/lightbulb-outline 31 | name: Switch to dark mode 32 | plugins: 33 | - mkdocstrings 34 | watch: 35 | - . # reload docs for any file changes 36 | -------------------------------------------------------------------------------- /pipegoose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/pipegoose/__init__.py -------------------------------------------------------------------------------- /pipegoose/constants.py: -------------------------------------------------------------------------------- 1 | SEED = 69 2 | 3 | 4 | CHECKPOINT_WEIGHTS_NAME = "pytorch_model_tp_{}_pp_{}.bin" 5 | CHECKPOINT_PATH_NAME = "./" 6 | 7 | # NOTE: no single bucket size is optimal for all models 8 | BUCKET_SIZE_MB = 25 9 | 10 | 11 | # ================================================== 12 | # Distributed Communication 13 | # ================================================== 14 | 15 | # RPC global worker's name 16 | WORKER_NAME = "RPC_GLOBAL_WORKER_{}" 17 | 18 | 19 | # ================================================== 20 | # Pipeline Parallelism 21 | # ================================================== 22 | 23 | 24 | # NOTE: the minimum number of cocurrent worker threads that execute jobs 25 | # in the background of pipeline parallelism 26 | PIPELINE_MIN_WORKERS = 1 27 | PIPELINE_MAX_WORKERS = 1 28 | 29 | JOB_KEY_LENGTH = 15 30 | -------------------------------------------------------------------------------- /pipegoose/core/bucket/bucket.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pipegoose.core.bucket.exception import BucketClosedError, BucketFullError 4 | 5 | 6 | class Bucket: 7 | """Store tensors in a contiguous memory space.""" 8 | 9 | def __init__(self, size: int, dtype: torch.dtype): 10 | """Create a bucket that stores tensors in a contiguous memory space. 11 | 12 | Args: 13 | size (int): the number of elements in the bucket 14 | dtype (torch.dtype): the data type of an element in the bucket 15 | """ 16 | assert size > 0, "Bucket size must be greater than 0." 17 | assert isinstance(dtype, torch.dtype), "Data type must be a torch.dtype." 18 | 19 | self.size = size 20 | self.dtype = dtype 21 | 22 | self._buffer = torch.zeros(size, dtype=dtype) 23 | self._offset = 0 24 | self._is_closed = False 25 | self._num_tensors = 0 26 | 27 | @property 28 | def is_full(self) -> bool: 29 | """Whether the bucket is full.""" 30 | return self._buffer.storage().size() == self._offset 31 | 32 | @property 33 | def available_size(self) -> int: 34 | """The number of elements that can be added to the bucket.""" 35 | return self._buffer.storage().size() - self._offset 36 | 37 | def add_tensor(self, tensor: torch.Tensor) -> torch.Tensor: 38 | """Add a tensor to the bucket.""" 39 | assert isinstance(tensor, torch.Tensor), "Input must be a tensor." 
40 | assert tensor.dtype == self._buffer.dtype, "Input tensor must have the same dtype as the bucket." 41 | 42 | if self.is_closed is True: 43 | raise BucketClosedError("Bucket is closed.") 44 | 45 | if self.is_full is True: 46 | raise BucketFullError("Bucket is full.") 47 | 48 | numel = tensor.numel() 49 | if numel > self.available_size: 50 | raise BucketFullError("Bucket does not have enough space.") 51 | 52 | self._buffer[self._offset : self._offset + numel].copy_(tensor.flatten()) 53 | # NOTE: set the tensor's storage to its corresponding storage portion in the bucket 54 | tensor.data = self._buffer[self._offset : self._offset + numel].view_as(tensor) 55 | self._offset += numel 56 | self._num_tensors += 1 57 | 58 | return tensor 59 | 60 | @property 61 | def is_closed(self) -> bool: 62 | """Whether the bucket is closed.""" 63 | return self._is_closed 64 | 65 | @property 66 | def is_free(self) -> bool: 67 | """Whether the bucket is free.""" 68 | return self.storage().size() == 0 69 | 70 | def storage(self) -> torch.Storage: 71 | """Return the bucket's storage.""" 72 | return self._buffer.storage() 73 | 74 | def close(self): 75 | """Close the bucket, and not allow any more tensors to be added to it.""" 76 | assert self.is_closed is False, "Bucket is already closed." 77 | self._is_closed = True 78 | 79 | def clear(self): 80 | """Clear all data in the bucket.""" 81 | assert self._offset != 0, "Bucket is empty, so no need to free memory." 82 | self._offset = 0 83 | self._num_tensors = 0 84 | self._buffer.zero_() 85 | 86 | def __len__(self) -> int: 87 | """Return the number of tensors in the bucket.""" 88 | return self._num_tensors 89 | -------------------------------------------------------------------------------- /pipegoose/core/bucket/dist.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Union 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from pipegoose.core.bucket.bucket import Bucket 7 | from pipegoose.core.bucket.utils import mb_size_to_num_elements 8 | from pipegoose.distributed.functional import all_reduce 9 | from pipegoose.distributed.parallel_context import ParallelContext 10 | from pipegoose.distributed.parallel_mode import ParallelMode 11 | 12 | OPERATOR_MAPPING = {dist.all_reduce: all_reduce} 13 | 14 | DistOperator = Union[ 15 | dist.broadcast, 16 | dist.all_reduce, 17 | dist.reduce, 18 | dist.all_gather, 19 | dist.gather, 20 | dist.scatter, 21 | dist.reduce_scatter, 22 | dist.all_to_all, 23 | ] 24 | 25 | 26 | class BucketDistributor: 27 | """ 28 | Perform an asynchronous, distributed operation on a bucket, 29 | filling it until full before executing the operation. 30 | 31 | NOTE: Inspired from the design of FairScale's ReduceScatterBucketer 32 | https://github.com/facebookresearch/fairscale/blob/164cc0f3170b4a3951dd84dda29c3e1504ac4d6e/fairscale/internal/reduce_scatter_bucketer.py#L74 33 | """ 34 | 35 | # DIST_OPERATOR = [dist.broadcast, dist.all_reduce, dist.reduce, dist.all_gather, dist.gather, dist.scatter, dist.reduce_scatter, dist.all_to_all] 36 | 37 | def __init__(self, op: DistOperator, bucket_size_mb: Union[int, float], parallel_context: ParallelContext = None): 38 | assert op in OPERATOR_MAPPING, f"Operation must be one of {OPERATOR_MAPPING}." 39 | assert bucket_size_mb > 0, "Bucket size must be greater than 0." 
40 | 41 | self.op = op 42 | self.bucket_size_mb = bucket_size_mb 43 | # NOTE: the number of elements in the bucket 44 | self.bucket_size = mb_size_to_num_elements(bucket_size_mb, torch.float32) 45 | self.parallel_context = parallel_context 46 | self.buckets: Dict[Tuple[torch.dtype, ParallelMode], Bucket] = {} 47 | 48 | @torch.no_grad() 49 | def execute(self, tensor: torch.Tensor, parallel_mode: ParallelMode): 50 | # NOTE: execute the operation if the tensor is larger than the bucket size 51 | if tensor.numel() > self.bucket_size: 52 | OPERATOR_MAPPING[self.op](tensor, parallel_context=self.parallel_context, parallel_mode=parallel_mode) 53 | return 54 | 55 | # TODO: execute and flush the bucket if the tensor is larger than the available space, 56 | # then refill the bucket with the tensor 57 | key = (tensor.dtype, parallel_mode) 58 | if key not in self.buckets: 59 | # NOTE: Bucket takes the number of elements and a dtype 60 | self.buckets[key] = Bucket(self.bucket_size, tensor.dtype) 61 | 62 | bucket = self.buckets[key] 63 | 64 | bucket.add_tensor(tensor) 65 | 66 | def _create_bucket(self): 67 | pass 68 | -------------------------------------------------------------------------------- /pipegoose/core/bucket/exception.py: -------------------------------------------------------------------------------- 1 | class BucketFullError(Exception): 2 | """Exception raised when a bucket is full and a new item is added.""" 3 | 4 | 5 | class BucketClosedError(Exception): 6 | """Exception raised when a bucket is closed and a new item is added.""" 7 | -------------------------------------------------------------------------------- /pipegoose/core/bucket/manager.py: -------------------------------------------------------------------------------- 1 | class BucketManager: 2 | pass 3 | -------------------------------------------------------------------------------- /pipegoose/core/bucket/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_memory_address_of_tensor_storage(): 5 | pass 6 | 7 | 8 | def mb_size_to_num_elements(mb: float, dtype: torch.dtype) -> int: 9 | """Convert a size in megabytes to a number of elements in a tensor dtype.""" 10 | INFO_CLASSES = { 11 | torch.float16: torch.finfo, 12 | torch.float32: torch.finfo, 13 | torch.float64: torch.finfo, 14 | torch.complex64: torch.finfo, 15 | torch.complex128: torch.finfo, 16 | torch.uint8: torch.iinfo, 17 | torch.int8: torch.iinfo, 18 | torch.int16: torch.iinfo, 19 | torch.int32: torch.iinfo, 20 | torch.int64: torch.iinfo, 21 | } 22 | 23 | if dtype not in INFO_CLASSES: 24 | raise ValueError(f"Unsupported dtype: {dtype}.") 25 | 26 | bytes_per_dtype = INFO_CLASSES[dtype](dtype).bits // 8 27 | bytes_per_mb = 1024 * 1024 28 | total_bytes = mb * bytes_per_mb 29 | return int(total_bytes // bytes_per_dtype) 30 | -------------------------------------------------------------------------------- /pipegoose/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from pipegoose.distributed.parallel_context import ParallelContext 2 | from pipegoose.distributed.parallel_mode import ParallelMode 3 | -------------------------------------------------------------------------------- /pipegoose/distributed/_initializers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/pipegoose/distributed/_initializers/__init__.py
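A minimal usage sketch of the bucket utilities above. This snippet is illustrative only (it is not part of the repository) and assumes just the `Bucket` and `mb_size_to_num_elements` APIs defined in `pipegoose/core/bucket/` plus `BUCKET_SIZE_MB` from `pipegoose/constants.py`; it shows how a gradient-sized tensor is flattened into a bucket's contiguous buffer.

```python
import torch

from pipegoose.constants import BUCKET_SIZE_MB
from pipegoose.core.bucket.bucket import Bucket
from pipegoose.core.bucket.utils import mb_size_to_num_elements

# a bucket that can hold BUCKET_SIZE_MB (25 MB) worth of float32 elements
num_elements = mb_size_to_num_elements(BUCKET_SIZE_MB, torch.float32)
bucket = Bucket(size=num_elements, dtype=torch.float32)

# add_tensor() copies the tensor into the bucket's contiguous buffer and
# re-points the tensor's .data at the corresponding slice of that buffer
grad = torch.randn(1024, 1024)
grad = bucket.add_tensor(grad)

print(len(bucket), bucket.available_size, bucket.is_full)  # 1, remaining elements, False

bucket.close()  # no more tensors can be added; add_tensor() would now raise BucketClosedError
bucket.clear()  # zero the buffer and reset the offset for reuse
```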
-------------------------------------------------------------------------------- /pipegoose/distributed/_initializers/initialize_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HPC-AI Technology Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modified by pipegoose's contributors. 16 | 17 | import torch.distributed as dist 18 | 19 | from pipegoose.distributed._initializers.initializer import ( 20 | ProcessGroupInitializer, 21 | ProcessGroupResult, 22 | ) 23 | from pipegoose.distributed.parallel_mode import ParallelMode 24 | 25 | 26 | class DataParallelGroupInitializer(ProcessGroupInitializer): 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | self.num_pipeline_parallel_groups = self.world_size // self.pipeline_parallel_size 30 | 31 | def init_dist_group(self) -> ProcessGroupResult: 32 | local_rank = None 33 | process_group = None 34 | local_world_size = None 35 | ranks_in_group = None 36 | parallel_mode = ParallelMode.DATA 37 | 38 | for i in range(self.pipeline_parallel_size): 39 | start_rank = i * self.num_pipeline_parallel_groups 40 | end_rank = (i + 1) * self.num_pipeline_parallel_groups 41 | 42 | for j in range(self.tensor_parallel_size): 43 | ranks = list(range(start_rank + j, end_rank, self.tensor_parallel_size)) 44 | # NOTE: dist.new_group() must be called collectively by all the processes 45 | # that would be part of the group, which means every process in the group 46 | # needs to call this function. If only a subset of the processes call new_group(), 47 | # it will hang because it's waiting for the rest of the processes to join. 48 | group = dist.new_group(ranks=ranks) 49 | 50 | if self.rank in ranks: 51 | local_rank = ranks.index(self.rank) 52 | local_world_size = len(ranks) 53 | ranks_in_group = ranks 54 | process_group = group 55 | 56 | return { 57 | "local_rank": local_rank, 58 | "process_group": process_group, 59 | "local_world_size": local_world_size, 60 | "ranks_in_group": ranks_in_group, 61 | "parallel_mode": parallel_mode, 62 | } 63 | -------------------------------------------------------------------------------- /pipegoose/distributed/_initializers/initialize_expert.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | from pipegoose.distributed._initializers.initializer import ( 4 | ProcessGroupInitializer, 5 | ProcessGroupResult, 6 | ) 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | 9 | 10 | class ExpertDataParallelGroupInitializer(ProcessGroupInitializer): 11 | """ 12 | Initialize the process group for data parallelism in expert parallelism. 13 | 14 | Pipeline MoE: A Flexible MoE Implementation with Pipeline Parallelism" by Xin Chen et al 15 | https://arxiv.org/abs/2304.11414 16 | 17 | NOTE: This looks similar to TensorParallelGroupInitializer, because it aligns with the paper. 
18 | """ 19 | 20 | def init_dist_group(self) -> ProcessGroupResult: 21 | num_tensor_parallel_groups = self.world_size // self.tensor_parallel_size 22 | local_rank = None 23 | process_group = None 24 | local_world_size = None 25 | ranks_in_group = None 26 | parallel_mode = ParallelMode.EXPERT_DATA 27 | 28 | for i in range(num_tensor_parallel_groups): 29 | ranks = list(range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size)) 30 | group = dist.new_group(ranks=ranks) 31 | 32 | if self.rank in ranks: 33 | local_rank = ranks.index(self.rank) 34 | local_world_size = len(ranks) 35 | ranks_in_group = ranks 36 | process_group = group 37 | 38 | return { 39 | "local_rank": local_rank, 40 | "local_world_size": local_world_size, 41 | "ranks_in_group": ranks_in_group, 42 | "process_group": process_group, 43 | "parallel_mode": parallel_mode, 44 | } 45 | -------------------------------------------------------------------------------- /pipegoose/distributed/_initializers/initialize_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HPC-AI Technology Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modified by pipegoose's contributors. 16 | 17 | import torch.distributed as dist 18 | 19 | from pipegoose.distributed._initializers.initializer import ( 20 | ProcessGroupInitializer, 21 | ProcessGroupResult, 22 | ) 23 | from pipegoose.distributed.parallel_mode import ParallelMode 24 | 25 | 26 | class PipelineParallelGroupInitializer(ProcessGroupInitializer): 27 | def init_dist_group(self) -> ProcessGroupResult: 28 | num_pipeline_parallel_groups = self.world_size // self.pipeline_parallel_size 29 | local_rank = None 30 | local_world_size = None 31 | ranks_in_group = None 32 | process_group = None 33 | parallel_mode = ParallelMode.PIPELINE 34 | 35 | for i in range(num_pipeline_parallel_groups): 36 | ranks = list(range(i, self.world_size, num_pipeline_parallel_groups)) 37 | 38 | # NOTE: dist.new_group() must be called collectively by all the processes 39 | # that would be part of the group, which means every process in the group 40 | # needs to call this function. If only a subset of the processes call new_group(), 41 | # it will hang because it's waiting for the rest of the processes to join. 42 | group = dist.new_group(ranks=ranks) 43 | 44 | if self.rank in ranks: 45 | local_rank = ranks.index(self.rank) 46 | local_world_size = len(ranks) 47 | ranks_in_group = ranks 48 | process_group = group 49 | 50 | return { 51 | "local_rank": local_rank, 52 | "process_group": process_group, 53 | "local_world_size": local_world_size, 54 | "ranks_in_group": ranks_in_group, 55 | "parallel_mode": parallel_mode, 56 | } 57 | -------------------------------------------------------------------------------- /pipegoose/distributed/_initializers/initialize_tensor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HPC-AI Technology Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modified by pipegoose's contributors. 16 | 17 | import torch.distributed as dist 18 | 19 | from pipegoose.distributed._initializers.initializer import ( 20 | ProcessGroupInitializer, 21 | ProcessGroupResult, 22 | ) 23 | from pipegoose.distributed.parallel_mode import ParallelMode 24 | 25 | 26 | class TensorParallelGroupInitializer(ProcessGroupInitializer): 27 | def init_dist_group(self) -> ProcessGroupResult: 28 | num_tensor_parallel_groups = self.world_size // self.tensor_parallel_size 29 | local_rank = None 30 | process_group = None 31 | local_world_size = None 32 | ranks_in_group = None 33 | parallel_mode = ParallelMode.TENSOR 34 | 35 | for i in range(num_tensor_parallel_groups): 36 | ranks = list(range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size)) 37 | 38 | # NOTE: dist.new_group() must be called collectively by all the processes 39 | # that would be part of the group, which means every process in the group 40 | # needs to call this function. If only a subset of the processes call new_group(), 41 | # it will hang because it's waiting for the rest of the processes to join. 42 | group = dist.new_group(ranks=ranks) 43 | 44 | if self.rank in ranks: 45 | local_rank = ranks.index(self.rank) 46 | local_world_size = len(ranks) 47 | ranks_in_group = ranks 48 | process_group = group 49 | 50 | return { 51 | "local_rank": local_rank, 52 | "local_world_size": local_world_size, 53 | "ranks_in_group": ranks_in_group, 54 | "process_group": process_group, 55 | "parallel_mode": parallel_mode, 56 | } 57 | -------------------------------------------------------------------------------- /pipegoose/distributed/_initializers/initializer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HPC-AI Technology Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modified by pipegoose's contributors. 
16 | 17 | from abc import ABC, abstractclassmethod 18 | from typing import TypedDict 19 | 20 | from torch.distributed import ProcessGroup 21 | 22 | from pipegoose.distributed.parallel_mode import ParallelMode 23 | 24 | 25 | class ProcessGroupResult(TypedDict): 26 | local_rank: int 27 | local_world_size: int 28 | process_group: ProcessGroup 29 | parallel_mode: ParallelMode 30 | 31 | 32 | class ProcessGroupInitializer(ABC): 33 | def __init__( 34 | self, rank: int, world_size: int, tensor_parallel_size: int, pipeline_parallel_size: int, data_parallel_size: int 35 | ): 36 | self.rank = rank 37 | self.world_size = world_size 38 | self.tensor_parallel_size = tensor_parallel_size 39 | self.pipeline_parallel_size = pipeline_parallel_size 40 | self.data_parallel_size = data_parallel_size 41 | 42 | @abstractclassmethod 43 | def init_dist_group(self) -> ProcessGroupResult: 44 | raise NotImplementedError 45 | -------------------------------------------------------------------------------- /pipegoose/distributed/_p2p.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from pipegoose.distributed.parallel_context import ParallelContext 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | 9 | ID_TO_DTYPE = [ 10 | torch.bfloat16, 11 | torch.float16, 12 | torch.float32, 13 | torch.float64, 14 | torch.uint8, 15 | torch.int8, 16 | torch.int16, 17 | torch.int32, 18 | torch.int64, 19 | torch.bool, 20 | ] 21 | 22 | DTYPE_TO_ID = {dtype: idx for idx, dtype in enumerate(ID_TO_DTYPE)} 23 | 24 | 25 | class _P2P: 26 | """ 27 | P2P Communication 28 | 29 | NOTE: Inspired from OSLO's P2P communication design 30 | https://github.com/EleutherAI/oslo/blob/d7c4e32e766a99cc9d56533bc090570360dc8b2a/oslo/torch/distributed/nn/_p2p.py#L62 31 | """ 32 | 33 | def __init__(self): 34 | self._INSTRUCTIONS = { 35 | torch.Tensor: {"send": self._send_tensor, "recv": self._recv_tensor}, 36 | } 37 | 38 | def _send_metadata( 39 | self, 40 | data: torch.Tensor, 41 | dst: int, 42 | parallel_context: ParallelContext, 43 | parallel_mode: ParallelMode, 44 | ): 45 | assert isinstance(data, torch.Tensor), "data must be a torch.Tensor" 46 | 47 | group = parallel_context.get_group(parallel_mode) 48 | 49 | dtype = torch.tensor(DTYPE_TO_ID[data.dtype]) 50 | dist.send(dtype, dst=dst, group=group) 51 | 52 | requires_grad = torch.tensor(1 if data.requires_grad else 0) 53 | dist.send(requires_grad, dst=dst, group=group) 54 | 55 | shape = torch.tensor(list(data.shape)) 56 | dist.send(shape, dst=dst, group=group) 57 | 58 | def _recv_metadata( 59 | self, 60 | src: int, 61 | parallel_context: ParallelContext, 62 | parallel_mode: ParallelMode, 63 | ): 64 | group = parallel_context.get_group(parallel_mode) 65 | 66 | dtype = torch.tensor(0) 67 | dist.recv(dtype, src=src, group=group) 68 | dtype = ID_TO_DTYPE[dtype.item()] 69 | 70 | requires_grad = torch.tensor(0) 71 | dist.recv(requires_grad, src=src, group=group) 72 | requires_grad = True if requires_grad == 1 else False 73 | 74 | shape = torch.tensor(0) 75 | dist.recv(shape, src=src, group=group) 76 | if isinstance(shape.tolist(), int): 77 | shape = (shape.item(),) 78 | else: 79 | shape = tuple(shape.tolist()) 80 | 81 | return dtype, requires_grad, shape 82 | 83 | def _send_tensor( 84 | self, 85 | data: torch.Tensor, 86 | dst: int, 87 | parallel_context: ParallelContext, 88 | parallel_mode: ParallelMode, 89 | ): 90 | assert isinstance(data, torch.Tensor), "data must be 
a torch.Tensor" 91 | 92 | self._send_metadata(data, dst, parallel_context, parallel_mode) 93 | 94 | group = parallel_context.get_group(parallel_mode) 95 | dist.send(data, dst=dst, group=group) 96 | 97 | def _recv_tensor(self, src: int, parallel_context: ParallelContext, parallel_mode: ParallelMode) -> torch.Tensor: 98 | group = parallel_context.get_group(parallel_mode) 99 | 100 | dtype, requires_grad, shape = self._recv_metadata(src, parallel_context, parallel_mode) 101 | 102 | data = torch.zeros(size=shape, dtype=dtype, requires_grad=requires_grad) 103 | dist.recv(data, src=src, group=group) 104 | 105 | return data 106 | 107 | def send(self, data: Any, dst: int, parallel_context: ParallelContext, parallel_mode: ParallelMode): 108 | _type = type(data) 109 | assert _type in self._INSTRUCTIONS, f"Type {_type} is not supported" 110 | 111 | self._INSTRUCTIONS[_type]["send"](data, dst, parallel_context, parallel_mode) 112 | 113 | def recv(self, src: int, parallel_context: ParallelContext, parallel_mode: ParallelMode) -> torch.Tensor: 114 | # TODO: Add support for other types 115 | return self._INSTRUCTIONS[torch.Tensor]["recv"](src, parallel_context, parallel_mode) 116 | -------------------------------------------------------------------------------- /pipegoose/distributed/parallel_mode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ParallelMode(Enum): 5 | GLOBAL = "global" 6 | 7 | TENSOR = "tensor" 8 | PIPELINE = "pipeline" 9 | DATA = "data" 10 | 11 | # NOTE: for expert data parallelism 12 | EXPERT_DATA = "expert" 13 | -------------------------------------------------------------------------------- /pipegoose/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.data_parallel.data_parallel import DataParallel 2 | from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel 3 | from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel 4 | from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel 5 | -------------------------------------------------------------------------------- /pipegoose/nn/data_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/pipegoose/nn/data_parallel/__init__.py -------------------------------------------------------------------------------- /pipegoose/nn/data_parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch import nn 6 | 7 | from pipegoose.distributed.functional import all_reduce 8 | from pipegoose.distributed.parallel_context import ParallelContext 9 | from pipegoose.distributed.parallel_mode import ParallelMode 10 | from pipegoose.nn.parallel import Parallel 11 | 12 | 13 | class DataParallel(Parallel): 14 | def __init__(self, module: nn.Module, parallel_context: ParallelContext): 15 | self.module = module 16 | self.parallel_context = parallel_context 17 | 18 | @torch.no_grad() 19 | def parallelize(self) -> nn.Module: 20 | module = self.module 21 | 22 | if self.parallel_context.data_parallel_size > 1: 23 | self._register_grad_avg_hook(module) 24 | self._save_metadata(module, self.parallel_context) 25 | 26 | return module 27 | 28 | def _register_grad_avg_hook(self, module: 
nn.Module): 29 | for p in module.parameters(): 30 | if p.requires_grad is True: 31 | is_expert = getattr(p, "is_expert", False) 32 | p.register_hook(partial(self._average_grad, is_expert=is_expert)) 33 | 34 | def _average_grad(self, grad: torch.Tensor, is_expert: bool): 35 | # NOTE: (grad1 + grad2 + ... + gradn) / n = grad1/n + grad2/n + ... + gradn/n 36 | grad.div_(self.parallel_context.data_parallel_size) 37 | 38 | all_reduce( 39 | grad, 40 | op=dist.ReduceOp.SUM, 41 | parallel_context=self.parallel_context, 42 | parallel_mode=ParallelMode.EXPERT_DATA if is_expert else ParallelMode.DATA, 43 | ) 44 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel 2 | from pipegoose.nn.expert_parallel.loss import ExpertLoss 3 | from pipegoose.nn.expert_parallel.routers import Top1Router, Top2Router, SwitchNoisePolicy 4 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/expert_context.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | from torchtyping import TensorType 5 | 6 | 7 | class ExpertContext: 8 | _instance = None 9 | 10 | def __init__(self): 11 | self.aux_loss = [] 12 | self.z_loss = [] 13 | 14 | def push_aux_loss(self, aux_loss: TensorType): 15 | self.aux_loss.append(aux_loss) 16 | 17 | def pop_all_aux_loss(self) -> list[TensorType]: 18 | aux_loss, self.aux_loss = self.aux_loss, [] 19 | return aux_loss 20 | 21 | def push_z_loss(self, z_loss: TensorType): 22 | self.z_loss.append(z_loss) 23 | 24 | def pop_all_z_loss(self) -> list[TensorType]: 25 | z_loss, self.z_loss = self.z_loss, [] 26 | return z_loss 27 | 28 | @classmethod 29 | def get_instance(cls) -> ExpertContext: 30 | if not cls._instance: 31 | cls._instance = ExpertContext() 32 | return cls._instance 33 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/expert_parallel.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Callable, List, Optional, Union 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pipegoose.distributed.parallel_context import ParallelContext 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.nn.expert_parallel.layers import ExpertLayer 10 | from pipegoose.nn.parallel import Parallel 11 | 12 | 13 | class ExpertParallel(Parallel): 14 | """ 15 | Turn a model into an Mixture of Experts model. 16 | 17 | NOTE: The architecture is based on "Pipeline MoE: A Flexible MoE Implementation with Pipeline Parallelism" by Xin Chen et al. 18 | https://arxiv.org/abs/2304.11414 19 | """ 20 | 21 | def __init__( 22 | self, 23 | module: nn.Module, 24 | num_experts: int, 25 | expert: Optional[nn.Module] = None, 26 | mapping: Optional[List[int]] = None, 27 | router: Callable = None, 28 | # noise_poligy: Union[str, Callable], 29 | enable_tensor_parallelism: bool = False, 30 | parallel_context: ParallelContext = None, 31 | ): 32 | tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR) 33 | assert parallel_context is not None, "parallel_context must be provided" 34 | assert num_experts % tensor_parallel_size == 0, "The number of experts must be divisible by the tensor parallel size." 
35 | num_layers = module.config.num_hidden_layers 36 | if mapping is None: 37 | # NOTE: default mapping is to parallelize all MLP layers 38 | mapping = list(range(num_layers)) 39 | 40 | assert all( 41 | 0 <= i < num_layers for i in mapping 42 | ), f"There is a layer index that is out of range. Expected range: [0, {num_layers}-1]" 43 | 44 | self.module = module 45 | self.num_experts = num_experts 46 | self.expert = expert 47 | self.mapping = mapping 48 | self.router = router 49 | # self.noise_policy = noise_poligy 50 | self.enable_tensor_parallelism = enable_tensor_parallelism 51 | self.parallel_context = parallel_context 52 | 53 | @torch.no_grad() 54 | def parallelize(self) -> nn.Module: 55 | # TODO: make it generalize 56 | def _is_mlp(name) -> Union[bool, Optional[int]]: 57 | pattern = re.compile(r"^transformer\.h\.(\d+)\.mlp$") 58 | match = pattern.match(name) 59 | if match: 60 | layer_idx = int(match.group(1)) 61 | return True, layer_idx 62 | else: 63 | return False, None 64 | 65 | for name, module in self.module.named_modules(): 66 | is_mlp, layer_idx = _is_mlp(name) 67 | if is_mlp: 68 | if layer_idx in self.mapping: 69 | expert_layer = ExpertLayer( 70 | self.num_experts, 71 | module if self.expert is None else self.expert, 72 | self.router, 73 | self.enable_tensor_parallelism, 74 | self.parallel_context, 75 | ) 76 | # TODO: make it generalize 77 | getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer 78 | 79 | return self.module 80 | 81 | @torch.no_grad() 82 | def deparallelize(self): 83 | pass 84 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/experts.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from einops import rearrange 7 | from torch import nn 8 | from torchtyping import TensorType 9 | 10 | from pipegoose.distributed.parallel_context import ParallelContext 11 | from pipegoose.distributed.parallel_mode import ParallelMode 12 | from pipegoose.nn.tensor_parallel._functional import all_reduce 13 | 14 | 15 | class Experts(nn.Module): 16 | """A collection of experts in an expert layer.""" 17 | 18 | def __init__( 19 | self, 20 | num_local_experts: int, 21 | expert: nn.Module, 22 | enable_tensor_parallel: bool, 23 | parallel_context: ParallelContext, 24 | ): 25 | super().__init__() 26 | self.enable_tensor_parallel = enable_tensor_parallel 27 | self.parallel_context = parallel_context 28 | 29 | expert = expert() if not isinstance(expert, nn.Module) else expert 30 | self.num_local_experts = num_local_experts 31 | self.experts = nn.ModuleList([deepcopy(expert) for _ in range(num_local_experts)]) 32 | self._set_expert_attr(self.experts) 33 | 34 | def _set_expert_attr(self, experts: nn.ModuleList): 35 | # NOTE: for filtering out the expert parameters later on 36 | # in data parallelism 37 | for expert in experts: 38 | for p in expert.parameters(): 39 | setattr(p, "is_expert", True) 40 | 41 | def forward( 42 | self, 43 | inputs: TensorType["batch_size", "seq_len", "d_model"], 44 | dispatch_order: TensorType["batch_size * seq_len"], 45 | *args, 46 | **kwargs, 47 | ) -> TensorType["batch_size", "seq_len", "d_model"]: 48 | outputs = torch.zeros_like(inputs) 49 | 50 | for expert_idx, expert in enumerate(self.experts): 51 | dispatched_inputs, indices = self._get_dispatch_inputs(inputs, dispatch_order, expert_idx) 52 | if dispatched_inputs.numel() 
== 0: 53 | # NOTE: if there are no tokens to dispatch to the expert, skip the expert 54 | continue 55 | 56 | if len(args) > 1: 57 | # NOTE: In some transformers models, it also passes last 58 | # hidden states or other arguments to the MLP expert. 59 | # how do we detect this and pass the corresponding arguments to the expert? 60 | # For example, hidden_states.shape = (batch_size, seq_len, hidden_size), 61 | # but we need to dispatch the hidden_states to the corresponding expert 62 | 63 | # NOTE: args[0] is the input embeddings 64 | # args[1] is the hidden_states, so we pass the input embeddings along 65 | # with the hidden_states to the expert 66 | selected_embeddings = rearrange(args[1], "batch_size seq_len d_dim -> (batch_size seq_len) d_dim")[indices] 67 | # selected_embeddings = rearrange(selected_embeddings, "(batch_size seq_len) d_dim -> batch_size seq_len d_dim", batch_size=inputs.shape[0]) 68 | 69 | expert_output = expert(dispatched_inputs, selected_embeddings, **kwargs) 70 | else: 71 | expert_output = expert(dispatched_inputs) 72 | 73 | outputs.view(-1, outputs.size(-1))[indices] = expert_output 74 | 75 | all_reduce( 76 | outputs, 77 | op=dist.ReduceOp.SUM, 78 | parallel_context=self.parallel_context, 79 | parallel_mode=ParallelMode.TENSOR, 80 | ) 81 | 82 | return outputs 83 | 84 | @torch.no_grad() 85 | def _get_dispatch_inputs( 86 | self, 87 | inputs: TensorType["batch_size", "seq_len", "d_model"], 88 | dispatch_order: TensorType["batch_size * seq_len"], 89 | expert_idx: int, 90 | ) -> Tuple[TensorType["batch_size * seq_len", "d_model"], TensorType["batch_size * seq_len"]]: 91 | """Dispatch embeddings to the corresponding expert.""" 92 | 93 | def get_global_expert_idx(expert_idx: int) -> int: 94 | rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR) 95 | global_expert_idx = rank * self.num_local_experts + expert_idx 96 | return global_expert_idx 97 | 98 | global_expert_idx = get_global_expert_idx(expert_idx) 99 | token_indices = (dispatch_order == global_expert_idx).nonzero(as_tuple=True)[0] 100 | inputs = rearrange(inputs, "b s d -> (b s) d") 101 | dispatched_inputs = inputs[token_indices] 102 | return dispatched_inputs, token_indices 103 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchtyping import TensorType 3 | 4 | from pipegoose.distributed.parallel_context import ParallelContext 5 | from pipegoose.nn.expert_parallel.experts import Experts 6 | from pipegoose.nn.expert_parallel.routers import Router 7 | from pipegoose.nn.expert_parallel.utils import get_num_local_experts 8 | from pipegoose.nn.expert_parallel.expert_context import ExpertContext 9 | 10 | 11 | class ExpertLayer(nn.Module): 12 | """ 13 | An expert layer. 
14 | 15 | NOTE: Switch Transformer: https://arxiv.org/abs/2101.03961 16 | """ 17 | 18 | def __init__( 19 | self, 20 | num_experts: int, 21 | expert: nn.Module, 22 | router: Router, 23 | enable_tensor_parallel: bool, 24 | parallel_context: ParallelContext 25 | ): 26 | super().__init__() 27 | self.router = router 28 | if enable_tensor_parallel is True: 29 | self.num_local_experts = num_experts 30 | else: 31 | self.num_local_experts = get_num_local_experts(num_experts, parallel_context) 32 | 33 | self._experts = Experts(self.num_local_experts, expert, enable_tensor_parallel, parallel_context) 34 | self.parallel_context = parallel_context 35 | 36 | @property 37 | def experts(self) -> nn.ModuleList: 38 | return self._experts.experts 39 | 40 | def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_model"]: 41 | # TODO: use torch.fx to extract the inputs from args, and kwargs 42 | inputs = args[0] 43 | router_output = self.router(inputs) 44 | expert_context = ExpertContext.get_instance() 45 | expert_context.push_aux_loss(router_output.aux_loss) 46 | expert_context.push_z_loss(router_output.z_loss) 47 | outputs = self._experts(inputs, router_output.dispatching_order, *args, **kwargs) 48 | return outputs 49 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | from torchtyping import TensorType 4 | 5 | from pipegoose.nn.expert_parallel.expert_context import ExpertContext 6 | 7 | 8 | class ExpertLoss: 9 | def __init__(self, loss_func: Callable, aux_weight: float = 0.01, z_weight: float = 0.1): 10 | self.loss_func = loss_func 11 | self.aux_weight = aux_weight 12 | self.z_weight = z_weight 13 | 14 | @property 15 | def aux_loss(self) -> List[float]: 16 | expert_context = ExpertContext.get_instance() 17 | return expert_context.aux_loss 18 | 19 | @property 20 | def z_loss(self) -> List[float]: 21 | expert_context = ExpertContext.get_instance() 22 | return expert_context.z_loss 23 | 24 | def __call__(self, *args, **kwargs) -> TensorType: 25 | loss = self.loss_func(*args, **kwargs) 26 | expert_context = ExpertContext.get_instance() 27 | loss += self.aux_weight * sum(expert_context.pop_all_aux_loss()) 28 | loss += self.z_weight * sum(expert_context.pop_all_z_loss()) 29 | return loss 30 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/parallel_mapping.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.parallel_mapping import ParallelInfo, ParallelMapping 2 | 3 | 4 | class MLP(ParallelInfo): 5 | pass 6 | 7 | 8 | class ExpertParallelMapping(ParallelMapping): 9 | __MAPPING__ = { 10 | "bloom-560m": [MLP("mlp")], 11 | } 12 | 13 | @staticmethod 14 | def is_mlp(module_name: str) -> bool: 15 | item = ExpertParallelMapping._search(module_name) 16 | if item is None: 17 | return False 18 | return isinstance(item, MLP) 19 | -------------------------------------------------------------------------------- /pipegoose/nn/expert_parallel/utils.py: -------------------------------------------------------------------------------- 1 | from pipegoose.distributed.parallel_context import ParallelContext 2 | from pipegoose.distributed.parallel_mode import ParallelMode 3 | 4 | 5 | def get_num_local_experts(num_experts: int, parallel_context: ParallelContext) -> int: 6 | """Return the number of local 
experts per device.""" 7 | tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR) 8 | return num_experts // tensor_parallel_size 9 | -------------------------------------------------------------------------------- /pipegoose/nn/parallel.py: -------------------------------------------------------------------------------- 1 | from abc import abstractclassmethod 2 | from dataclasses import dataclass 3 | from functools import partial 4 | from typing import cast 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from pipegoose.distributed.parallel_context import ParallelContext 10 | from pipegoose.distributed.parallel_mode import ParallelMode 11 | 12 | 13 | @dataclass 14 | class ParallelMetadata: 15 | device: int = None 16 | local_device: int = None 17 | 18 | 19 | class Parallel: 20 | """A base class for a parallelized module.""" 21 | 22 | @abstractclassmethod 23 | def parallelize(self): 24 | """Parallelize the module.""" 25 | raise NotImplementedError 26 | 27 | @abstractclassmethod 28 | def deparallelize(self): 29 | """Deparallelize the module.""" 30 | raise NotImplementedError 31 | 32 | def _save_metadata(self, module: nn.Module, parallel_context: ParallelContext): 33 | def _get_device(parallel_context: ParallelContext) -> int: 34 | rank = parallel_context.get_global_rank() 35 | tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR) 36 | pp_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) 37 | dp_rank = parallel_context.get_local_rank(ParallelMode.DATA) 38 | 39 | ranks = ( 40 | (ParallelMode.GLOBAL, rank), 41 | (ParallelMode.TENSOR, tp_rank), 42 | (ParallelMode.PIPELINE, pp_rank), 43 | (ParallelMode.DATA, dp_rank), 44 | ) 45 | device = parallel_context.ranks2device(ranks) 46 | local_device = device % parallel_context.get_world_size(ParallelMode.GLOBAL) 47 | return device, local_device 48 | 49 | device, local_device = _get_device(parallel_context) 50 | parallel_metadata = ParallelMetadata( 51 | device=device, 52 | local_device=local_device, 53 | ) 54 | setattr(module, "parallel_metadata", parallel_metadata) 55 | setattr(module, "to", partial(_to_device, module)) 56 | setattr(module, "cuda", partial(_to_cuda, module)) 57 | 58 | 59 | def _to_device(self, device: str): 60 | """Move a parallelized module to accelerators.""" 61 | SUPPORTED_DEVICES = ["cuda", "gpu"] 62 | 63 | def is_specific_device(device): 64 | import re 65 | 66 | pattern = r"^cuda:[0-9]+$" 67 | if re.match(pattern, device): 68 | return True 69 | return False 70 | 71 | parallel_metadata = cast(ParallelMetadata, getattr(self, "parallel_metadata", None)) 72 | 73 | assert parallel_metadata is not None, "Module is not parallelized yet" 74 | assert device in SUPPORTED_DEVICES, f"Device must be one of {SUPPORTED_DEVICES}, got {device}" 75 | assert not is_specific_device( 76 | device 77 | ), f'Moving to a specific device {device} is not supported. pipegoose will handle device assignment automatically. 
Please use "cuda" instead' 78 | 79 | if torch.cuda.device_count() == 0: 80 | raise RuntimeError("There are no GPUs available.") 81 | 82 | local_device = parallel_metadata.local_device 83 | for p in self.parameters(): 84 | p.data = p.to(f"cuda:{local_device}") 85 | if p.grad is not None: 86 | p.grad.data = p.grad.to(f"cuda:{local_device}") 87 | 88 | for b in self.buffers(): 89 | b.data = b.to(f"cuda:{local_device}") 90 | 91 | 92 | def _to_cuda(self): 93 | self.to("cuda") 94 | -------------------------------------------------------------------------------- /pipegoose/nn/parallel_mapping.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, cast 2 | 3 | 4 | class ParallelInfo: 5 | def __init__(self, module_name: Tuple[str], **kwargs): 6 | self.module_name = module_name 7 | self.kwargs = kwargs 8 | 9 | 10 | class ParallelMapping: 11 | @staticmethod 12 | def _search(module_name: str) -> Optional[ParallelInfo]: 13 | """ 14 | Search for module_name in mappings. 15 | """ 16 | module_name = ParallelMapping._extract_module_name(module_name) 17 | for child_class in ParallelMapping.__subclasses__(): 18 | if hasattr(child_class, "__MAPPING__"): 19 | for items in child_class.__MAPPING__.values(): 20 | for item in items: 21 | item = cast(ParallelInfo, item) 22 | if any(module_name in mapping_name for mapping_name in item.module_name): 23 | return item 24 | # NOTE: only search the first subclass of the current instance 25 | break 26 | 27 | return None 28 | 29 | @staticmethod 30 | def _extract_module_name(module_name: str) -> str: 31 | if "." in module_name: 32 | # NOTE: transformer.h.0.self_attention.dense -> self_attention.dense 33 | SEPARATOR = "." 34 | sections = module_name.split(SEPARATOR) 35 | return SEPARATOR.join(sections[-2:]) 36 | 37 | return module_name 38 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/_comm.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | from typing import Any 3 | 4 | import torch.distributed.rpc as rpc 5 | 6 | from pipegoose.distributed.parallel_context import ParallelContext 7 | from pipegoose.nn.pipeline_parallel._package import Package 8 | 9 | RECV_QUEUE = Queue() 10 | 11 | 12 | def _send_data(data: Any, src: int, dst: int, parallel_context: ParallelContext): 13 | dst_worker_name = parallel_context.get_worker_name(dst) 14 | rpc.rpc_sync(to=dst_worker_name, func=_recv_package, args=(data, src, dst)) 15 | 16 | 17 | def send_package(package: Package, parallel_context: ParallelContext): 18 | """Send a package to another pipeline stage based on the metadata of the package.""" 19 | 20 | assert isinstance(package, Package) 21 | 22 | rank = parallel_context.get_global_rank() 23 | 24 | if package.metadata.src == rank: 25 | dst = package.metadata.dst 26 | _send_data(package, src=rank, dst=dst, parallel_context=parallel_context) 27 | 28 | 29 | def _recv_package(package: Package, src: int, dst: int): 30 | """ 31 | Receive a package from another pipeline stage. 32 | 33 | NOTE: only be triggered by send_package. 
34 | """ 35 | # TODO: add configurable destination queue 36 | assert isinstance(package, Package) 37 | 38 | package.metadata.microbatch_idx 39 | package.metadata.partition_idx 40 | 41 | RECV_QUEUE.put(package) 42 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/_job/callback.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from enum import Enum 3 | 4 | 5 | class CallbackEvent(Enum): 6 | """Enum for callback events.""" 7 | 8 | AFTER_CREATE = "after_create" 9 | 10 | BEFORE_COMPUTE = "before_compute" 11 | AFTER_COMPUTE = "after_compute" 12 | 13 | 14 | class Callback(ABC): 15 | """Callback for a job.""" 16 | 17 | order = 0 18 | 19 | @property 20 | def name(self) -> str: 21 | return self.__name__ 22 | 23 | def after_create(self): 24 | pass 25 | 26 | def before_compute(self): 27 | pass 28 | 29 | def after_compute(self): 30 | pass 31 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/_job/job.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from enum import Enum, auto 3 | from typing import Callable, List, NewType, Optional 4 | 5 | from pipegoose.nn.pipeline_parallel._job.callback import Callback, CallbackEvent 6 | from pipegoose.nn.pipeline_parallel._package import Package 7 | 8 | 9 | class JobStatus(Enum): 10 | # NOTE: wait for a worker pick up this job and execute it 11 | PENDING = auto() # just created and putted into job queue 12 | EXECUTING = auto() 13 | EXECUTED = auto() # executed but not sent the output to another pipeline stage 14 | DONE = auto() # executed and sent the output to another pipeline stage 15 | FAILED = auto() # failed to execute 16 | 17 | 18 | # NOTE: Both forward job and backward pass of 19 | # the same model partition has the same key 20 | PartitionKey = NewType("PartitionKey", str) 21 | 22 | 23 | class Job(ABC): 24 | """A job that will be executed by a worker.""" 25 | 26 | def __init__(self, function: Callable, input: Package, cbs: List[Callback] = []): 27 | self.function = function 28 | self.input = input 29 | self.cbs = [] 30 | 31 | self._status = JobStatus.PENDING 32 | self._output = None 33 | 34 | def generate_random_string(length=15): 35 | import random 36 | import string 37 | 38 | characters = string.ascii_letters + string.digits 39 | return "".join(random.choice(characters) for i in range(length)) 40 | 41 | self._key = generate_random_string() 42 | 43 | self.add_cbs(cbs) 44 | self._run_callback(CallbackEvent.AFTER_CREATE) 45 | 46 | @property 47 | def status(self) -> JobStatus: 48 | return self._status 49 | 50 | @property 51 | def key(self) -> str: 52 | return self._key 53 | 54 | @property 55 | def output(self) -> Optional[Package]: 56 | return self._output 57 | 58 | @output.setter 59 | def output(self, value: Optional[Package]): 60 | self._output = value 61 | 62 | def compute(self) -> Optional[Package]: 63 | try: 64 | self._run_callback(CallbackEvent.BEFORE_COMPUTE) 65 | 66 | # TODO: refactor make other callbacks to be able to access the output of a job 67 | self._output = self.run_compute() 68 | 69 | # TODO: turn the update of job status into a callback 70 | self._status = JobStatus.EXECUTED 71 | 72 | self._run_callback(CallbackEvent.AFTER_COMPUTE) 73 | 74 | return self.output 75 | except Exception as e: 76 | raise e 77 | 78 | def add_cbs(self, cbs: List[Callback]): 79 | """Add a list of 
callbacks to this job.""" 80 | for cb in cbs: 81 | self.add_cb(cb) 82 | 83 | def remove_cbs(self, cbs: List[Callback]): 84 | for cb in cbs: 85 | self.remove_cb(cb) 86 | 87 | def add_cb(self, cb: Callback): 88 | """Add a callback to this job.""" 89 | if isinstance(cb, type): 90 | cb = cb() 91 | 92 | assert isinstance(cb, Callback), f"cb must be an instance of Callback, got {type(cb)}" 93 | 94 | # NOTE: lets the callback access the job attributes 95 | cb.job = self 96 | self.cbs.append(cb) 97 | 98 | def remove_cb(self, cb: Callback): 99 | """Remove a callback from this job.""" 100 | # NOTE: if cb is a class 101 | if isinstance(cb, type): 102 | cbs = [x for x in self.cbs if isinstance(x, cb)] 103 | self.remove_cbs(cbs) 104 | else: 105 | if cb in self.cbs: 106 | self.cbs.remove(cb) 107 | 108 | def _run_callback(self, event_name: CallbackEvent): 109 | assert isinstance( 110 | event_name, CallbackEvent 111 | ), f"event_name must be an instance of CallbackEvent, got {type(event_name)}" 112 | 113 | sorted_cbs = sorted(self.cbs, key=lambda x: x.order) 114 | # NOTE: get the value of an enum member 115 | event_name = event_name.value 116 | 117 | for cb in sorted_cbs: 118 | event_method = getattr(cb, event_name, None) 119 | if event_method is not None: 120 | event_method() 121 | 122 | @abstractmethod 123 | def run_compute(self): 124 | """The actual computation of this job.""" 125 | raise NotImplementedError("not implemented") 126 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/_job/job_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | 4 | class JobType(Enum): 5 | FORWARD = auto() 6 | BACKWARD = auto() 7 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/_job/register.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | 3 | from pipegoose.nn.pipeline_parallel._job.job import Job 4 | 5 | 6 | class _JobRegister: 7 | def __init__(self, queue: Queue): 8 | self.queue = queue 9 | 10 | def registry(self, job: Job): 11 | assert isinstance(job, Job), f"job must be an instance of Job, got {type(job)}" 12 | self.queue.put(job) 13 | 14 | 15 | def add_job_to_queue(job: Job, queue: Queue): 16 | job_register = _JobRegister(queue) 17 | job_register.registry(job) 18 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/_package.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from pipegoose.nn.pipeline_parallel._job.job_type import JobType 6 | 7 | 8 | @dataclass 9 | class TrainingMetadata: 10 | is_training: bool 11 | is_grad_enabled: bool 12 | 13 | 14 | @dataclass 15 | class Metadata: 16 | """Metadata for the output of a job.""" 17 | 18 | # pipeline 19 | # the index of the microbatch and partition that return this package 20 | microbatch_idx: int 21 | partition_idx: int 22 | 23 | job_type: JobType 24 | 25 | training: TrainingMetadata 26 | 27 | # global rank 28 | src: int 29 | dst: int 30 | 31 | 32 | class Package: 33 | """A data package that will be sent from one pipeline stage to another.""" 34 | 35 | def __init__(self, data: torch.Tensor, metadata: Metadata): 36 | self.data = data 37 | self.metadata = metadata 38 | 
--------------------------------------------------------------------------------
/pipegoose/nn/pipeline_parallel/_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | from pipegoose.distributed.parallel_context import ParallelContext
4 | from pipegoose.distributed.parallel_mode import ParallelMode
5 | 
6 | 
7 | def sleep(timeout: float = 0.05):
8 |     time.sleep(timeout)
9 | 
10 | 
11 | def get_partition_idx(parallel_context: ParallelContext) -> int:
12 |     rank = parallel_context.get_global_rank()
13 |     ranks_in_group = parallel_context.get_ranks_in_group(ParallelMode.PIPELINE)
14 |     # pipeline_stage_idx = rank // n_ranks_per_group
15 |     # return pipeline_stage_idx
16 |     return ranks_in_group.index(rank)
17 | 
18 | 
19 | def is_last_stage(parallel_context: ParallelContext) -> bool:
20 |     partition_idx = get_partition_idx(parallel_context)
21 |     n_stages = parallel_context.pipeline_parallel_size
22 |     return partition_idx == (n_stages - 1)
23 | 
--------------------------------------------------------------------------------
/pipegoose/nn/pipeline_parallel/exception.py:
--------------------------------------------------------------------------------
1 | class PipelineGradientFlowError(Exception):
2 |     """The gradients can't flow to leaf tensors"""
3 | 
4 | 
5 | class PipelineNoSavedActivationError(Exception):
6 |     """Can't find saved activations to do backpropagation"""
7 | 
8 | 
9 | class PipelineNoSavedInput(Exception):
10 |     """Can't find the input activations to return the gradients"""
11 | 
12 | 
13 | class PipelineInputNotRequiresGrad(Exception):
14 |     """The input of the pipeline stage must require grad in order for gradients to flow back"""
15 | 
--------------------------------------------------------------------------------
/pipegoose/nn/pipeline_parallel/microbatch.py:
--------------------------------------------------------------------------------
1 | from typing import List, TypedDict
2 | 
3 | import torch
4 | 
5 | 
6 | class ModelInputs(TypedDict):
7 |     input_ids: torch.Tensor
8 |     attention_mask: torch.Tensor
9 | 
10 | 
11 | def split(inputs: ModelInputs, n_microbatches: int) -> List[ModelInputs]:
12 |     assert n_microbatches > 0, f"n_microbatches must be greater than 0, got {n_microbatches}"
13 |     assert "input_ids" in inputs, f"inputs must have 'input_ids' key, got {inputs.keys()}"
14 |     assert "attention_mask" in inputs, f"inputs must have 'attention_mask' key, got {inputs.keys()}"
15 |     assert (
16 |         inputs["input_ids"].size(0) % n_microbatches == 0
17 |     ), f"The batch size must be divisible by n_microbatches, got {inputs['input_ids'].size(0)} and {n_microbatches}"
18 | 
19 |     # NOTE: use torch.chunk so the batch is split into n_microbatches chunks,
20 |     # not into chunks of size n_microbatches
21 |     input_ids_microbatches = torch.chunk(inputs["input_ids"], n_microbatches)
22 |     attention_mask_microbatches = torch.chunk(inputs["attention_mask"], n_microbatches)
23 | 
24 |     microbatches = []
25 |     for input_ids, attention_mask in zip(input_ids_microbatches, attention_mask_microbatches):
26 |         microbatches.append(ModelInputs(input_ids=input_ids, attention_mask=attention_mask))
27 | 
28 |     return microbatches
29 | 
--------------------------------------------------------------------------------
/pipegoose/nn/pipeline_parallel/pipeline.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | from pipegoose.constants import PIPELINE_MAX_WORKERS, PIPELINE_MIN_WORKERS
4 | from pipegoose.distributed.parallel_context import ParallelContext
5 | from pipegoose.nn.pipeline_parallel.scheduler import SchedulerType, get_scheduler
6 | 
7 | 
8
| class _PipelineEngine: 9 | """Turn a 🤗 transformers model into a pipeline parallel model.""" 10 | 11 | def __init__( 12 | self, 13 | module: nn.Module, 14 | num_concurrent: int = PIPELINE_MIN_WORKERS, 15 | max_concurrent: int = PIPELINE_MAX_WORKERS, 16 | scheduler: SchedulerType = SchedulerType.GPIPE, 17 | parallel_context: ParallelContext = None, 18 | ): 19 | assert num_concurrent <= max_concurrent, "num_concurrent must be less than or equal to max_concurrent" 20 | assert parallel_context is not None, "parallel_context must be provided" 21 | 22 | assert isinstance( 23 | parallel_context, ParallelContext 24 | ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}" 25 | assert isinstance(module, nn.Module), f"module must be an instance of nn.Module, got {type(module)}" 26 | assert isinstance(num_concurrent, int), f"num_concurrent must be an instance of int, got {type(num_concurrent)}" 27 | assert isinstance(max_concurrent, int), f"max_concurrent must be an instance of int, got {type(max_concurrent)}" 28 | 29 | self.module = module 30 | self.num_concurrent = num_concurrent 31 | self.max_concurrent = max_concurrent 32 | self.scheduler = get_scheduler(scheduler) 33 | self.parallel_context = parallel_context 34 | 35 | def parallelize(self) -> nn.Module: 36 | # TODO: wrap the model with a pipeline parallel model 37 | pass 38 | 39 | def forward(self, batches): 40 | partitions = None 41 | 42 | len(batches) 43 | len(partitions) 44 | 45 | with spawn_workers(): 46 | pass 47 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/pipeline_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from pipegoose.distributed.parallel_context import ParallelContext 5 | from pipegoose.nn.parallel import Parallel 6 | from pipegoose.nn.pipeline_parallel._utils import get_partition_idx 7 | from pipegoose.nn.pipeline_parallel._worker import WorkerManager 8 | from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner 9 | from pipegoose.nn.pipeline_parallel.pipeline_engine import PipelineEngine 10 | from pipegoose.nn.pipeline_parallel.scheduler import GPipeScheduler 11 | 12 | 13 | class PipelineParallel(Parallel): 14 | """Automatically parallelize a module using pipeline parallelism.""" 15 | 16 | def __init__( 17 | self, 18 | module: nn.Module, 19 | num_microbatches: int, 20 | parallel_context: ParallelContext, 21 | ): 22 | self.module = module 23 | self.num_microbatches = num_microbatches 24 | self.parallel_context = parallel_context 25 | 26 | @torch.no_grad() 27 | def parallelize(self) -> nn.Module: 28 | if self.parallel_context.pipeline_parallel_size > 1: 29 | partition_idx = get_partition_idx(self.parallel_context) 30 | partitions = UniformPartitioner(self.module, self.parallel_context).split(["input_ids"]) 31 | module = partitions[partition_idx] 32 | 33 | n_partitions = self.parallel_context.pipeline_parallel_size 34 | scheduler = GPipeScheduler(self.num_microbatches, n_partitions) 35 | worker_manager = WorkerManager() 36 | 37 | pipeline_engine = PipelineEngine( 38 | module=module, 39 | scheduler=scheduler, 40 | worker_manager=worker_manager, 41 | parallel_context=self.parallel_context, 42 | ) 43 | 44 | module.forward = pipeline_engine.run 45 | 46 | self._save_metadata(module, self.parallel_context) 47 | 48 | return module 49 | else: 50 | return self.module 51 | 
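PipelineParallel above is the user-facing entry point: it splits the model with UniformPartitioner, keeps only the partition owned by the current rank, and rebinds that partition's forward to PipelineEngine.run. A rough usage sketch, assuming a ParallelContext has already been initialized elsewhere with pipeline_parallel_size > 1 (its construction is not part of this file) and with bloom-560m standing in for any supported 🤗 transformers model:

    import torch
    from transformers import AutoModelForCausalLM

    from pipegoose.distributed.parallel_context import ParallelContext
    from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel


    def build_pipeline_stage(parallel_context: ParallelContext, num_microbatches: int = 4) -> torch.nn.Module:
        # Every rank in the pipeline group runs the same code; parallelize()
        # returns only the partition this rank is responsible for, with its
        # forward() rebound to PipelineEngine.run, which walks the GPipe schedule.
        model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
        stage = PipelineParallel(
            model,
            num_microbatches=num_microbatches,
            parallel_context=parallel_context,
        ).parallelize()
        return stage
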
-------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/queue.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from queue import Queue 3 | from typing import Any, Dict, NewType, Tuple 4 | 5 | import torch 6 | 7 | from pipegoose.nn.pipeline_parallel.exception import ( 8 | PipelineNoSavedActivationError, 9 | PipelineNoSavedInput, 10 | ) 11 | 12 | ActivationKey = NewType("ActivationKey", Tuple[int, int]) 13 | 14 | # NOTE: the activations that received from earlier stages 15 | _INPUT_ACTIVATIONS: Dict[ActivationKey, torch.Tensor] = {} 16 | 17 | # NOTE: save activations from forward job for backward job 18 | _SAVED_ACTIVATIONS: Dict[ActivationKey, torch.Tensor] = {} 19 | 20 | _SAVED_SCHEDULED_ACTIVATIONS: Dict[ActivationKey, torch.Tensor] = {} 21 | 22 | _SAVED_GRAD_LOSS: Dict[ActivationKey, torch.Tensor] = {} 23 | 24 | _SAVED_METADATA_of_GRAD_LOSS: Dict[ActivationKey, Any] = {} 25 | 26 | 27 | @dataclass 28 | class JobQueue: 29 | """A queue for storing jobs.""" 30 | 31 | PENDING_JOBS = Queue() 32 | SELECTED_JOBS = Queue() 33 | FINISHED_JOBS = Queue() 34 | 35 | 36 | class SavedActivation: 37 | """A class for saving activations from forward job for backward job.""" 38 | 39 | @staticmethod 40 | def is_saved(microbatch_idx: int, partition_idx: int) -> bool: 41 | key = SavedActivation.get_key(microbatch_idx, partition_idx) 42 | return key in _SAVED_ACTIVATIONS 43 | 44 | @staticmethod 45 | def get_key(microbatch_idx: int, partition_idx: int) -> ActivationKey: 46 | return (microbatch_idx, partition_idx) 47 | 48 | @staticmethod 49 | def get_saved_activations(key: ActivationKey) -> torch.Tensor: 50 | """Get the saved activations for a given key for backward job.""" 51 | # NOTE: because a partition can have multiple microbatches, 52 | return _SAVED_ACTIVATIONS.pop(key) 53 | 54 | def save_activations(key: ActivationKey, data: torch.Tensor, is_by_schedule: bool = False): 55 | """Save forward job's activations for backward job.""" 56 | _SAVED_ACTIVATIONS[key] = data 57 | 58 | 59 | class InputActivations: 60 | """A class for saving activations from forward job for backward job.""" 61 | 62 | @staticmethod 63 | def get_key(microbatch_idx: int, partition_idx: int) -> ActivationKey: 64 | return (microbatch_idx, partition_idx) 65 | 66 | @staticmethod 67 | def is_saved(microbatch_idx: int, partition_idx: int) -> bool: 68 | key = InputActivations.get_key(microbatch_idx, partition_idx) 69 | return key in _INPUT_ACTIVATIONS 70 | 71 | @staticmethod 72 | def get_saved_activations(key: ActivationKey) -> torch.Tensor: 73 | """Get the saved activations for a given key for backward job.""" 74 | # NOTE: because a partition can have multiple microbatches, 75 | input = _INPUT_ACTIVATIONS[key] 76 | 77 | # return input.requires_grad_(True) 78 | # TODO: add support regular non-transformers model 79 | if isinstance(input, torch.Tensor): 80 | return input.requires_grad_(True) 81 | else: 82 | return input 83 | 84 | def save_activations(key: ActivationKey, data: torch.Tensor): 85 | """Save forward job's activations for backward job.""" 86 | _INPUT_ACTIVATIONS[key] = data 87 | 88 | 89 | def save_input_activations(input: torch.Tensor, microbatch_idx: int, partition_idx: int): 90 | # input.requires_grad = True 91 | key = InputActivations.get_key(microbatch_idx, partition_idx) 92 | InputActivations.save_activations(key, input) 93 | 94 | 95 | def get_input_activations(microbatch_idx: int, partition_idx: 
int) -> torch.Tensor: 96 | key = InputActivations.get_key(microbatch_idx, partition_idx) 97 | try: 98 | return InputActivations.get_saved_activations(key) 99 | except KeyError: 100 | raise PipelineNoSavedInput( 101 | f"Can't find the input activations to return the gradients for \ 102 | microbatch_idx={microbatch_idx}, partition_idx={partition_idx}" 103 | ) 104 | 105 | 106 | def save_output_activations(output: torch.Tensor, microbatch_idx: int, partition_idx: int): 107 | key = SavedActivation.get_key(microbatch_idx, partition_idx) 108 | SavedActivation.save_activations(key, output) 109 | 110 | 111 | def get_output_activations(microbatch_idx: int, partition_idx: int, is_pipeline: bool = False) -> torch.Tensor: 112 | key = SavedActivation.get_key(microbatch_idx, partition_idx) 113 | 114 | try: 115 | output = _SAVED_ACTIVATIONS[key] 116 | if is_pipeline is True: 117 | return output.requires_grad_(True) 118 | else: 119 | return output.detach().requires_grad_(True) 120 | except KeyError: 121 | raise PipelineNoSavedActivationError( 122 | f"Can't find saved activations to do backpropogation for \ 123 | microbatch_idx={microbatch_idx}, partition_idx={partition_idx}" 124 | ) 125 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractclassmethod 2 | from enum import Enum, auto 3 | from typing import List 4 | 5 | from pipegoose.nn.pipeline_parallel._job.job_type import JobType 6 | from pipegoose.nn.pipeline_parallel.task import Task 7 | 8 | 9 | class SchedulerType(Enum): 10 | GPIPE = auto() 11 | 12 | 13 | class BaseScheduler(ABC): 14 | @abstractclassmethod 15 | def get_schedules(self): 16 | """Return the schedule for the whole training run.""" 17 | raise NotImplementedError 18 | 19 | @abstractclassmethod 20 | def get_forward_schedules(self): 21 | """Return the forward schedule for the whole training run.""" 22 | raise NotImplementedError 23 | 24 | @abstractclassmethod 25 | def get_backward_schedules(self): 26 | """Return the backward schedule for the whole training run.""" 27 | raise NotImplementedError 28 | 29 | @abstractclassmethod 30 | def total_clock_cycles(self): 31 | """Return the total number of clock cycles.""" 32 | raise NotImplementedError 33 | 34 | 35 | class GPipeScheduler(BaseScheduler): 36 | """ 37 | torchgpipe: On-the-fly Pipeline Parallelism for Training Giant Models 38 | https://arxiv.org/abs/2004.09910 39 | 40 | Section 3.2.1: Forward Dependency: Deterministic Clock-cycle 41 | """ 42 | 43 | def __init__(self, n_microbatches: int, n_partitions: int): 44 | assert ( 45 | n_microbatches > 0 46 | ), "The number of microbatches must be \ 47 | greater than 0" 48 | assert ( 49 | n_partitions > 0 50 | ), "The number of partitions must be \ 51 | greater than 0" 52 | 53 | self.n_microbatches = n_microbatches 54 | self.n_partitions = n_partitions 55 | 56 | def get_schedules(self) -> List[List[Task]]: 57 | forward_schedules = self.get_forward_schedules() 58 | backward_schedules = self.get_backward_schedules() 59 | 60 | # NOTE: combine forward and backward schedule into a full schedule 61 | forward_schedules.extend(backward_schedules) 62 | 63 | return forward_schedules 64 | 65 | def get_forward_schedules(self) -> List[List[Task]]: 66 | schedules = [] 67 | n_clock_cycles = self.n_partitions + self.n_microbatches - 1 68 | for clock_idx in range(n_clock_cycles): 69 | start_partrition = max(clock_idx + 1 - 
self.n_microbatches, 0) 70 | end_partition = min(clock_idx + 1, self.n_partitions) 71 | 72 | tasks = [] 73 | for partition_idx in range(start_partrition, end_partition): 74 | microbatch_idx = clock_idx - partition_idx 75 | task = Task(JobType.FORWARD, microbatch_idx, partition_idx) 76 | tasks.append(task) 77 | 78 | schedules.append(tasks) 79 | return schedules 80 | 81 | def get_backward_schedules(self) -> List[List[Task]]: 82 | from copy import deepcopy 83 | 84 | forward_schedules = self.get_forward_schedules() 85 | n_clock_cycles = len(forward_schedules) 86 | backward_schedules = deepcopy(forward_schedules) 87 | backward_schedules.reverse() 88 | 89 | for clock_idx in range(n_clock_cycles): 90 | for task in backward_schedules[clock_idx]: 91 | task.job_type = JobType.BACKWARD 92 | 93 | return backward_schedules 94 | 95 | @property 96 | def total_clock_cycles(self) -> int: 97 | return len(self.get_schedules()) 98 | 99 | @property 100 | def total_forward_clock_cycles(self) -> int: 101 | """Return the total number of clock cycles required to run the forward pass.""" 102 | return len(self.get_forward_schedules()) 103 | 104 | @property 105 | def total_backward_clock_cycles(self) -> int: 106 | """Return the total number of clock cycles required to run the forward pass.""" 107 | return len(self.get_backward_schedules()) 108 | 109 | 110 | def get_scheduler(scheduler_type: SchedulerType) -> BaseScheduler: 111 | scheduler_type_to_scheduler = { 112 | SchedulerType.GPIPE: GPipeScheduler, 113 | } 114 | 115 | return scheduler_type_to_scheduler[scheduler_type] 116 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/sync/callback.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class Callback: 5 | order = 0 6 | 7 | def after_new_clock_cycle(self, progress: Dict, clock_idx: int): 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/sync/progress_tracker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from pipegoose.nn.pipeline_parallel.pipeline_context import PipelineContext 4 | 5 | 6 | def get_progresses_from_pipeline_context(pipeline_context: PipelineContext) -> Dict: 7 | schedules = pipeline_context.schedules 8 | progresses = { 9 | i: {(item.microbatch_idx, item.partition_idx): False for item in sublist} for i, sublist in enumerate(schedules) 10 | } 11 | return progresses 12 | -------------------------------------------------------------------------------- /pipegoose/nn/pipeline_parallel/task.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from pipegoose.nn.pipeline_parallel._job.job_type import JobType 4 | 5 | 6 | @dataclass 7 | class Task: 8 | job_type: JobType 9 | microbatch_idx: int 10 | partition_idx: int 11 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/_functional.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired from OSLO: https://github.com/EleutherAI/oslo/blob/d7c4e32e766a99cc9d56533bc090570360dc8b2a/oslo/torch/nn/parallel/tensor_parallel/_1d/_ops.py#L17 3 | """ 4 | 5 | from typing import Any, Tuple 6 | 7 | import torch 8 | from torch.autograd import Function 9 | 10 | from 
pipegoose.distributed.functional import all_gather, all_reduce, scatter 11 | from pipegoose.distributed.parallel_context import ParallelContext 12 | from pipegoose.distributed.parallel_mode import ParallelMode 13 | 14 | 15 | class _Broadcast(Function): 16 | @staticmethod 17 | def forward(ctx, tensor: torch.Tensor, parallel_context: ParallelContext) -> torch.Tensor: 18 | ctx.parallel_context = parallel_context 19 | 20 | return tensor 21 | 22 | @staticmethod 23 | def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 24 | parallel_context = ctx.parallel_context 25 | 26 | all_reduce(grad, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR) 27 | 28 | return (grad, None, None) 29 | 30 | 31 | class _Gather(Function): 32 | @staticmethod 33 | def forward(ctx: Any, input: torch.Tensor, dim: int, parallel_context: ParallelContext) -> torch.Tensor: 34 | ctx.dim = dim 35 | ctx.parallel_context = parallel_context 36 | 37 | return all_gather(input, dim=dim, async_op=False, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR) 38 | 39 | @staticmethod 40 | def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 41 | dim = ctx.dim 42 | parallel_context = ctx.parallel_context 43 | 44 | return ( 45 | scatter(grad, dim=dim, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR), 46 | None, 47 | None, 48 | ) 49 | 50 | 51 | class _Scatter(Function): 52 | @staticmethod 53 | def forward(ctx: Any, input: torch.Tensor, dim: int, parallel_context: ParallelContext) -> torch.Tensor: 54 | ctx.dim = dim 55 | ctx.parallel_context = parallel_context 56 | return scatter(input, dim=dim, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR) 57 | 58 | @staticmethod 59 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 60 | dim = ctx.dim 61 | parallel_context = ctx.parallel_context 62 | 63 | return ( 64 | all_gather( 65 | grad_output, dim=dim, async_op=False, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR 66 | ), 67 | None, 68 | None, 69 | ) 70 | 71 | 72 | class _Reduce(Function): 73 | @staticmethod 74 | def forward(ctx: Any, input: torch.Tensor, parallel_context: ParallelContext) -> torch.Tensor: 75 | return all_reduce(input, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR) 76 | 77 | @staticmethod 78 | def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None]: 79 | return (grad, None) 80 | 81 | 82 | def broadcast_to_tensor_group(input: torch.Tensor, parallel_context: ParallelContext): 83 | return _Broadcast.apply(input, parallel_context) 84 | 85 | 86 | def gather_to_tensor_group(input: torch.Tensor, dim: int, parallel_context: ParallelContext): 87 | return _Gather.apply(input, dim, parallel_context) 88 | 89 | 90 | def scatter_to_tensor_group(input: torch.Tensor, dim: int, parallel_context: ParallelContext): 91 | return _Scatter.apply(input, dim, parallel_context) 92 | 93 | 94 | def reduce_to_tensor_group(input: torch.Tensor, parallel_context: ParallelContext): 95 | return _Reduce.apply(input, parallel_context) 96 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | 4 | class VocabUtility: 5 | @staticmethod 6 | def get_vocab_range_idx_from_partition_size(partition_size: int, rank: int) -> Tuple[int, int]: 7 | start_idx = rank * 
partition_size 8 | end_idx = start_idx + partition_size 9 | return start_idx, end_idx 10 | 11 | @staticmethod 12 | def get_vocab_range_from_global_vocab_size(world_size, rank, vocab_size): 13 | partition_size = vocab_size // world_size 14 | return VocabUtility.get_vocab_range_idx_from_partition_size(partition_size, rank) 15 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from pipegoose.distributed.parallel_context import ParallelContext 6 | from pipegoose.distributed.parallel_mode import ParallelMode 7 | from pipegoose.nn.tensor_parallel._functional import reduce_to_tensor_group 8 | from pipegoose.nn.tensor_parallel._utils import VocabUtility 9 | 10 | 11 | class ParallelEmbedding(nn.Module): 12 | def __init__(self, num_embeddings: int, embedding_dim: int, parallel_context: ParallelContext): 13 | super().__init__() 14 | world_size = parallel_context.get_world_size(ParallelMode.TENSOR) 15 | 16 | assert num_embeddings % world_size == 0, "num_embeddings must be divisible by world_size" 17 | 18 | num_embeddings_per_partition = num_embeddings // world_size 19 | self.parallel_context = parallel_context 20 | self.weight = nn.Parameter(torch.randn(num_embeddings_per_partition, embedding_dim)) 21 | self.vocab_start_idx, self.vocab_end_idx = VocabUtility.get_vocab_range_idx_from_partition_size( 22 | num_embeddings_per_partition, rank=parallel_context.get_local_rank(ParallelMode.TENSOR) 23 | ) 24 | self.world_size = world_size 25 | 26 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 27 | if self.world_size > 1: 28 | input_mask = (inputs < self.vocab_start_idx) | (inputs >= self.vocab_end_idx) 29 | # NOTE: align global embedding indices to local embedding indices 30 | masked_input = inputs.clone() - self.vocab_start_idx 31 | masked_input[input_mask] = 0 32 | else: 33 | masked_input = inputs 34 | 35 | parallel_output = F.embedding(masked_input, self.weight) 36 | 37 | if self.world_size > 1: 38 | parallel_output[input_mask, :] = 0.0 39 | 40 | output = reduce_to_tensor_group(parallel_output, parallel_context=self.parallel_context) 41 | 42 | return output 43 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from pipegoose.distributed.parallel_context import ParallelContext 6 | 7 | 8 | class LayerNorm(nn.Module): 9 | def __init__(self, normalized_shape: int, eps: float = 1e-5, bias: bool = True, parallel_context: ParallelContext = None): 10 | super().__init__() 11 | assert parallel_context is not None, "parallel_context must be provided" 12 | 13 | self.normalized_shape = normalized_shape 14 | self.eps = eps 15 | self.parallel_context = parallel_context 16 | 17 | self.weight = nn.Parameter(torch.ones(self.normalized_shape)) 18 | if bias: 19 | self.bias = nn.Parameter(torch.zeros(self.normalized_shape)) 20 | else: 21 | self.register_parameter("bias", None) 22 | 23 | def forward(self, input: torch.Tensor) -> torch.Tensor: 24 | normalized_shape = (self.normalized_shape,) if isinstance(self.normalized_shape, int) else self.normalized_shape 25 | return F.layer_norm(input, normalized_shape, self.weight, self.bias, 
self.eps) 26 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/linear.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from pipegoose.distributed.parallel_context import ParallelContext 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.nn.tensor_parallel._functional import ( 10 | broadcast_to_tensor_group, 11 | gather_to_tensor_group, 12 | reduce_to_tensor_group, 13 | scatter_to_tensor_group, 14 | ) 15 | 16 | 17 | class ColumnParallelLinear(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | out_features: int, 22 | bias: bool = True, 23 | gather_output: bool = False, 24 | parallel_context: Optional[ParallelContext] = None, 25 | ): 26 | super().__init__() 27 | out_per_partition = self._get_output_per_partition(out_features, parallel_context) 28 | 29 | self.gather_output = gather_output 30 | self.parallel_context = parallel_context 31 | self.weight = nn.Parameter(torch.randn(out_per_partition, in_features)) 32 | 33 | if bias is True: 34 | self.bias = nn.Parameter(torch.randn(out_per_partition)) 35 | 36 | def _get_output_per_partition(self, out_features: int, parallel_context: ParallelContext) -> int: 37 | local_world_size = parallel_context.get_world_size(ParallelMode.TENSOR) 38 | return out_features // local_world_size 39 | 40 | def forward(self, input: torch.Tensor) -> torch.Tensor: 41 | input_parallel = broadcast_to_tensor_group(input, self.parallel_context) 42 | outputs = F.linear(input_parallel, self.weight) 43 | 44 | if self.bias is not None: 45 | outputs = outputs + self.bias 46 | 47 | if self.gather_output: 48 | outputs = gather_to_tensor_group(outputs, dim=-1, parallel_context=self.parallel_context) 49 | 50 | return outputs 51 | 52 | 53 | class RowParallelLinear(nn.Module): 54 | def __init__( 55 | self, 56 | in_features: int, 57 | out_features: int, 58 | bias: bool = True, 59 | parallel_context: Optional[ParallelContext] = None, 60 | ) -> None: 61 | super().__init__() 62 | in_per_partition = self._get_input_per_partition(in_features, parallel_context) 63 | 64 | self.parallel_context = parallel_context 65 | self.weight = nn.Parameter(torch.randn(out_features, in_per_partition)) 66 | 67 | if bias is True: 68 | self.bias = nn.Parameter(torch.randn(out_features)) 69 | 70 | def _get_input_per_partition(self, in_features: int, parallel_context: ParallelContext) -> int: 71 | local_world_size = parallel_context.get_world_size(ParallelMode.TENSOR) 72 | return in_features // local_world_size 73 | 74 | def forward(self, input: torch.Tensor) -> torch.Tensor: 75 | input_parallel = scatter_to_tensor_group(input, dim=-1, parallel_context=self.parallel_context) 76 | output_parallel = F.linear(input_parallel, self.weight) 77 | outputs = reduce_to_tensor_group(output_parallel, parallel_context=self.parallel_context) 78 | 79 | if self.bias is not None: 80 | outputs = outputs + self.bias 81 | 82 | return outputs 83 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/parallel_mapping.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.parallel_mapping import ParallelInfo, ParallelMapping 2 | 3 | 4 | class Column(ParallelInfo): 5 | pass 6 | 7 | 8 | class Row(ParallelInfo): 9 | pass 10 | 11 | 12 | class 
LMHead(ParallelInfo): 13 | pass 14 | 15 | 16 | class TensorParallelMapping(ParallelMapping): 17 | """ 18 | NOTE: Inspired from OSLO's Parallel Mapping 19 | https://github.com/EleutherAI/oslo/blob/d7c4e32e766a99cc9d56533bc090570360dc8b2a/oslo/torch/nn/parallel/tensor_parallel/mapping.py#L43 20 | """ 21 | 22 | # TODO: make this extendable 23 | # so user can define their own mapping 24 | __MAPPING__ = { 25 | "albert-base-v2": [Column(("query", "key", "value")), Row("attention.dense")], 26 | "bloom-560m": [ 27 | Column(("mlp.dense_h_to_4h", "self_attention.query_key_value")), 28 | Row(("mlp.dense_4h_to_h", "self_attention.dense")), 29 | LMHead(("lm_head",)), 30 | ], 31 | } 32 | 33 | @staticmethod 34 | def is_column_parallel(module_name: str) -> bool: 35 | item = TensorParallelMapping._search(module_name) 36 | if item is None: 37 | return False 38 | return isinstance(item, Column) 39 | 40 | @staticmethod 41 | def is_row_parallel(module_name: str) -> bool: 42 | item = TensorParallelMapping._search(module_name) 43 | if item is None: 44 | return False 45 | return isinstance(item, Row) 46 | 47 | @staticmethod 48 | def is_lm_head(module_name: str) -> bool: 49 | item = TensorParallelMapping._search(module_name) 50 | if item is None: 51 | return False 52 | return isinstance(item, LMHead) 53 | -------------------------------------------------------------------------------- /pipegoose/nn/tensor_parallel/tensor_parallel.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from pipegoose.distributed.parallel_context import ParallelContext 7 | from pipegoose.nn.expert_parallel.layers import ExpertLayer 8 | from pipegoose.nn.parallel import Parallel 9 | from pipegoose.nn.tensor_parallel.parallelizer import ( 10 | EmbeddingParallelizer, 11 | LayerNormParallelizer, 12 | LinearParallelizer, 13 | LMHeadParallelizer, 14 | ModuleParallelizer, 15 | ) 16 | 17 | 18 | class TensorParallel(Parallel): 19 | """Turn a 🤗 transformers model into a tensor parallel model.""" 20 | 21 | PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer] 22 | 23 | def __init__(self, module: nn.Module, parallel_context: ParallelContext): 24 | self.module = module 25 | self.parallel_context = parallel_context 26 | 27 | @torch.no_grad() 28 | def parallelize(self) -> nn.Module: 29 | module = self.module 30 | 31 | if self.parallel_context.tensor_parallel_size > 1: 32 | # NOTE: because module.named_modules returns a leaf more than once, 33 | # this could potentially lead to the weight of a module being split 34 | # multiple times. 
so we filter out and retain the non-repetitive modules (leaf modules) 35 | leaf_modules = self._get_leaf_modules(module) 36 | for module_name, leaf_module in leaf_modules: 37 | parallelizer = self._find_parallelizer(module_name, leaf_module) 38 | if parallelizer is not None: 39 | parallelizer(module_name, leaf_module, module, self.parallel_context).parallelize() 40 | 41 | self._save_metadata(module, self.parallel_context) 42 | 43 | return module 44 | 45 | def _get_leaf_modules(self, model: nn.Module) -> List[Tuple[str, nn.Module]]: 46 | """Return non-expert leaf modules.""" 47 | leaf_modules = [] 48 | expert_names = [] 49 | 50 | def is_child_of_expert(module_name): 51 | # NOTE: suppose an mlp expert has name "transformer.h.0.mlp" 52 | # then its children will have names like "transformer.h.0.mlp.{child_name}" 53 | # so we can check if a module is a child of an expert by checking if its name 54 | # starts with "transformer.h.0.mlp" 55 | for expert_name in expert_names: 56 | if module_name.startswith(expert_name): 57 | return True 58 | return False 59 | 60 | for module_name, module in model.named_modules(): 61 | if isinstance(module, ExpertLayer): 62 | expert_names.append(module_name) 63 | continue 64 | 65 | # NOTE: skip leaf modules that belong to ExpertLayer 66 | if is_child_of_expert(module_name) or list(module.children()): 67 | continue 68 | 69 | leaf_modules.append((module_name, module)) 70 | 71 | return leaf_modules 72 | 73 | def _find_parallelizer(self, module_name: str, module: nn.Module) -> Optional[ModuleParallelizer]: 74 | for parallelizer in self.PARALLELIZERS: 75 | if parallelizer.is_parallelizable(module_name, module): 76 | return parallelizer 77 | return None 78 | 79 | @torch.no_grad() 80 | def deparallelize(self) -> nn.Module: 81 | for module_name, module in self.module.named_modules(): 82 | self.PARALLELIZERS[module].deparallelize(module_name, module, self.parallel_context) 83 | -------------------------------------------------------------------------------- /pipegoose/nn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from pipegoose.constants import CHECKPOINT_PATH_NAME, CHECKPOINT_WEIGHTS_NAME 7 | from pipegoose.distributed.parallel_context import ParallelContext 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | 10 | 11 | def from_pretrained(module: nn.Module, ckp_path: str, parallel_context: ParallelContext): 12 | """Load the weights of a pretrained parallelized model.""" 13 | tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR) 14 | pp_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) 15 | 16 | ckp_name = CHECKPOINT_WEIGHTS_NAME.format(tp_rank, pp_rank) 17 | ckp_path = os.path.join(ckp_path, ckp_name) 18 | 19 | if os.path.exists(ckp_path): 20 | state_dict = torch.load(ckp_path) 21 | module.load_state_dict(state_dict) 22 | else: 23 | raise ValueError(f"ckp_path {ckp_path} does not exist") 24 | 25 | 26 | def save_pretrained( 27 | module: nn.Module, 28 | ckp_name: str = CHECKPOINT_WEIGHTS_NAME, 29 | ckp_path: str = CHECKPOINT_PATH_NAME, 30 | parallel_context: ParallelContext = None, 31 | ): 32 | """ 33 | Save the weights of a pretrained parallelized model. 34 | 35 | NOTE: Assume that the model is already parallelized and discarded 36 | the weights of parts that a node is not responsible for. 
37 | """ 38 | assert isinstance( 39 | parallel_context, ParallelContext 40 | ), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}" 41 | 42 | tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR) 43 | pp_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) 44 | ckp_name = ckp_name.format(tp_rank, pp_rank) 45 | 46 | if os.path.isdir(ckp_path): 47 | state_dict = module.state_dict() 48 | torch.save(state_dict, os.path.join(ckp_path, ckp_name)) 49 | else: 50 | raise ValueError(f"ckp_path {ckp_path} does not exist") 51 | -------------------------------------------------------------------------------- /pipegoose/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from pipegoose.optim.zero.optim import DistributedOptimizer 2 | -------------------------------------------------------------------------------- /pipegoose/optim/base_optim.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractclassmethod 2 | 3 | 4 | class BaseDistributedOptimizer(ABC): 5 | """A base class for distributed optimizer.""" 6 | 7 | @abstractclassmethod 8 | def defaults(self): 9 | raise NotImplementedError("defaults is not implemented") 10 | 11 | @abstractclassmethod 12 | def param_groups(self): 13 | raise NotImplementedError("param_groups is not implemented") 14 | 15 | @abstractclassmethod 16 | def add_param_group(self): 17 | raise NotImplementedError("add_param_group is not implemented") 18 | 19 | @abstractclassmethod 20 | def load_state_dict(self): 21 | raise NotImplementedError("load_state_dict is not implemented") 22 | 23 | @abstractclassmethod 24 | def state_dict(self): 25 | raise NotImplementedError("state_dict is not implemented") 26 | 27 | @abstractclassmethod 28 | def step(self): 29 | raise NotImplementedError("step is not implemented") 30 | 31 | @abstractclassmethod 32 | def zero_grad(self): 33 | raise NotImplementedError("zero_grad is not implemented") 34 | -------------------------------------------------------------------------------- /pipegoose/optim/zero/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/pipegoose/optim/zero/__init__.py -------------------------------------------------------------------------------- /pipegoose/optim/zero/optim.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | 3 | from pipegoose.distributed.functional import broadcast 4 | from pipegoose.distributed.parallel_context import ParallelContext 5 | from pipegoose.distributed.parallel_mode import ParallelMode 6 | from pipegoose.optim.base_optim import BaseDistributedOptimizer 7 | from pipegoose.optim.zero.sharding import OptimizerStateSharding 8 | from pipegoose.optim.zero.utils import ( 9 | copy_flatten_tensor_to_unflatten_tensors, 10 | flatten_a_list_tensor, 11 | ) 12 | 13 | 14 | class DistributedOptimizer(BaseDistributedOptimizer): 15 | """ZeRO-1 optimizer that works natively in 3D parallelism.""" 16 | 17 | def __init__(self, optim: Optimizer, parallel_context: ParallelContext): 18 | self.optim = optim 19 | self.parallel_context = parallel_context 20 | 21 | self._setup_local_optim() 22 | 23 | def _setup_local_optim(self): 24 | """Setup local optimizer.""" 25 | sharded_param_groups = OptimizerStateSharding( 26 | self.optim.param_groups, 
self.parallel_context, ParallelMode.DATA 27 | ).shard() 28 | ranks_in_group = self.parallel_context.get_ranks_in_group(ParallelMode.DATA) 29 | self._rank_to_param_groups = {rank: params for rank, params in zip(ranks_in_group, sharded_param_groups)} 30 | 31 | dp_local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) 32 | dp_global_rank = self.parallel_context.get_global_rank_from_local_rank(dp_local_rank, ParallelMode.DATA) 33 | self.optim.param_groups = self._rank_to_param_groups[dp_global_rank] 34 | 35 | @property 36 | def defaults(self): 37 | """Return the default hyperparameters.""" 38 | return self.optim.defaults 39 | 40 | @property 41 | def param_groups(self): 42 | """Return the parameter groups.""" 43 | return self.optim.param_groups 44 | 45 | def add_param_group(self, *args, **kwargs): 46 | """Add a new parameter group to the optimizer.""" 47 | self.optim.add_param_group(*args, **kwargs) 48 | 49 | def load_state_dict(self, *args, **kwargs): 50 | """Load the optimizer state.""" 51 | self.optim.load_state_dict(*args, **kwargs) 52 | 53 | def state_dict(self, *args, **kwargs): 54 | """Return the state of the optimizer""" 55 | return self.optim.state_dict(*args, **kwargs) 56 | 57 | def step(self, *args, **kwargs): 58 | # NOTE: each rank updates its subset of parameters using the local optimizer 59 | self.optim.step(*args, **kwargs) 60 | 61 | # NOTE: each model replicas broadcast the updated parameters to other model replicas 62 | for rank, param_groups in self._rank_to_param_groups.items(): 63 | for param_group in param_groups: 64 | flatten_params = flatten_a_list_tensor(param_group["params"]) 65 | broadcast(flatten_params, src=rank, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA) 66 | copy_flatten_tensor_to_unflatten_tensors(flatten_params, param_group["params"]) 67 | 68 | def zero_grad(self): 69 | """Zero out gradients.""" 70 | # NOTE: we zero out the gradients of the all parameters 71 | for param_groups in self._rank_to_param_groups.values(): 72 | for param_group in param_groups: 73 | for param in param_group["params"]: 74 | if param.grad is not None: 75 | param.grad = None 76 | -------------------------------------------------------------------------------- /pipegoose/optim/zero/sharding.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, List 3 | 4 | import torch 5 | 6 | from pipegoose.distributed.parallel_context import ParallelContext 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | 9 | 10 | class OptimizerStateSharding: 11 | """ 12 | Shard optimizer parameters across parallelism dimension. 13 | 14 | NOTE: Only shard the parameters in each param groups and keep the number of param groups the same. 
15 | """ 16 | 17 | def __init__( 18 | self, param_groups: List[Dict[str, torch.Tensor]], parallel_context: ParallelContext, parallel_mode: ParallelMode 19 | ): 20 | self.param_groups = param_groups 21 | self.parallel_context = parallel_context 22 | self.parallel_mode = parallel_mode 23 | 24 | def shard(self) -> List[Dict[str, torch.Tensor]]: 25 | """ 26 | Credit: https://github.com/facebookresearch/fairscale/blob/164cc0f3170b4a3951dd84dda29c3e1504ac4d6e/fairscale/optim/oss.py#L173 27 | """ 28 | world_size = self.parallel_context.get_world_size(self.parallel_mode) 29 | partition_parameters = [[] for _ in range(world_size)] 30 | sizes = [0 for _ in range(world_size)] 31 | 32 | for param_group in self.param_groups: 33 | param_lists = [[] for _ in range(world_size)] 34 | 35 | for param in param_group["params"]: 36 | # TODO: fix if the numel of more than one ranks are equal 37 | next_rank = sizes.index(min(sizes)) 38 | param_lists[next_rank].append(param) 39 | sizes[next_rank] += param.numel() 40 | 41 | for rank, params in enumerate(param_lists): 42 | partitioned_param_group = copy.copy(param_group) 43 | partitioned_param_group["params"] = params 44 | partition_parameters[rank].append(partitioned_param_group) 45 | 46 | return partition_parameters 47 | -------------------------------------------------------------------------------- /pipegoose/optim/zero/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 5 | 6 | 7 | def delete_tensor_from_memory(tensor): 8 | """ 9 | Delete a tensor from memory 10 | 11 | Args: 12 | tensor (torch.Tensor): the tensor to be deleted 13 | """ 14 | del tensor 15 | torch.cuda.empty_cache() 16 | 17 | 18 | def flatten_a_list_tensor(list: List[torch.Tensor]) -> torch.Tensor: 19 | """Flatten a list of tensors into a single tensor.""" 20 | return _flatten_dense_tensors(list) 21 | 22 | 23 | def copy_flatten_tensor_to_unflatten_tensors(flat: torch.Tensor, tensors: List[torch.Tensor]): 24 | """Copied the data in a flatten tensor to its original unflatten tensors.""" 25 | for tensor, flat_data in zip(tensors, _unflatten_dense_tensors(flat, tensors)): 26 | tensor.data.copy_(flat_data) 27 | -------------------------------------------------------------------------------- /pipegoose/partitioning/profile.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractclassmethod 2 | from typing import List 3 | 4 | import torch 5 | from torch import nn 6 | from transformers import AutoModel 7 | 8 | 9 | class ProfileStrategy(ABC): 10 | def __init__(self, module: AutoModel, device: torch.device): 11 | self.module = module 12 | self.device = device 13 | 14 | @abstractclassmethod 15 | def profile(self): 16 | raise NotImplementedError("Not implemented.") 17 | 18 | 19 | class ProfileByMemory(ProfileStrategy): 20 | """Profiles CUDA memory usage by layer.""" 21 | 22 | def profile(self, input: torch.Tensor) -> List[int]: 23 | sizes = [] 24 | input = input.to(self.device) 25 | output = input 26 | 27 | for _, layer in self.module.named_children(): 28 | layer.to(self.device) 29 | layer.train() 30 | 31 | # calculate the memory occupied by the layer's output 32 | memory_before = torch.cuda.memory_allocated(device=self.device) 33 | output = layer(output) 34 | memory_after = torch.cuda.memory_allocated(device=self.device) 35 | occupied_memory = memory_after - memory_before 36 | 37 | # 
now calculate the memory occupied by the layer's parameters 38 | param_memory = self._compute_param_memory(layer) 39 | total_memory = occupied_memory + param_memory 40 | 41 | sizes.append(total_memory) 42 | return sizes 43 | 44 | def _compute_param_memory(self, module: nn.Module) -> int: 45 | total_size = 0 46 | for p in module.parameters(): 47 | total_size += p.storage().size() * p.storage().element_size() 48 | 49 | return total_size 50 | -------------------------------------------------------------------------------- /pipegoose/trainer/callback.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Callback: 5 | # NOTE: add more events 6 | # NOTE: READING 7 | # + Pytorch lightning's Callback 8 | 9 | def on_fit_start(self, trainer: "pipegoose.Trainer", pl_module: nn.Module) -> None: 10 | """Called when fit begins.""" 11 | 12 | def on_fit_end(self, trainer: "pipegoose.Trainer", pl_module: nn.Module) -> None: 13 | """Called when fit ends.""" 14 | -------------------------------------------------------------------------------- /pipegoose/trainer/logger.py: -------------------------------------------------------------------------------- 1 | from pipegoose.distributed import ParallelContext 2 | 3 | 4 | class DistributedLogger: 5 | LEVELS = ["warning", ...] 6 | 7 | def __init__(self, parallel_context: ParallelContext): 8 | pass 9 | 10 | def set_level(self): 11 | pass 12 | 13 | def log(self): 14 | pass 15 | -------------------------------------------------------------------------------- /pipegoose/trainer/state.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | 4 | 5 | class TrainerStatus(Enum): 6 | INITIALIZING = "initializing" 7 | RUNNING = "running" 8 | FINISHED = "finished" 9 | 10 | 11 | class TrainerStage(Enum): 12 | TRAINING = "train" 13 | VALIDATING = "validate" 14 | TESTING = "test" 15 | PREDICTING = "predict" 16 | 17 | 18 | @dataclass 19 | class TrainerState: 20 | status: TrainerStatus = TrainerStatus.INITIALIZING 21 | stage: TrainerStage = TrainerStage.TRAINING 22 | -------------------------------------------------------------------------------- /pipegoose/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from torch import nn 4 | from torch.optim import Optimizer 5 | from torch.utils.data import DataLoader 6 | 7 | from pipegoose.distributed.parallel_context import ParallelContext 8 | from pipegoose.trainer.callback import Callback 9 | from pipegoose.trainer.logger import DistributedLogger 10 | from pipegoose.trainer.state import TrainerState 11 | 12 | 13 | class Trainer: 14 | def __init__( 15 | self, 16 | module: nn.Module, 17 | train_loader: DataLoader, 18 | eval_loader: DataLoader, 19 | optim: Optimizer, 20 | num_epochs: int, 21 | callbacks: List[Callback] = [], 22 | loggers: List[DistributedLogger] = [], 23 | parallel_context: ParallelContext = None, 24 | ): 25 | # NOTE: based on the data_parallel_size, tensor_parallel_size, and pipeline_parallel_size 26 | # in the parallel_context, we apply the corresponding parallelization to the module.
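# A sketch of what that dispatch could look like, assuming the TensorParallel, PipelineParallel,
# and DataParallel wrappers under pipegoose.nn each expose a .parallelize() method (illustrative
# only; the actual wiring is not implemented in this stub):
#   if parallel_context.tensor_parallel_size > 1:
#       module = TensorParallel(module, parallel_context).parallelize()
#   if parallel_context.pipeline_parallel_size > 1:
#       module = PipelineParallel(module, parallel_context).parallelize()
#   if parallel_context.data_parallel_size > 1:
#       module = DataParallel(module, parallel_context).parallelize()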
27 | self.state = TrainerState() 28 | 29 | def fit(self): 30 | # NOTE: both train and validation 31 | pass 32 | 33 | def train(self): 34 | # NOTE: only train 35 | pass 36 | -------------------------------------------------------------------------------- /pipegoose/utils/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_tensor_storage_mem_loc(tensor: torch.Tensor) -> int: 5 | """Return the memory location of the tensor storage.""" 6 | return tensor.storage().data_ptr() 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pipegoose" 3 | version = "0.2.0" 4 | description = "" 5 | authors = ["xrsrke "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | torch = "^2.0.1" 12 | transformers = "^4.30.2" 13 | datasets = "^2.13.0" 14 | torchtyping = "^0.1.4" 15 | mkdocs = "^1.4.3" 16 | mkdocstrings = "^0.22.0" 17 | mkdocs-material = "^9.1.16" 18 | pytest-order = "^1.1.0" 19 | einops = "^0.6.1" 20 | 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | pytest = "^7.3.2" 24 | pre-commit = "^3.3.3" 25 | wandb = "^0.15.12" 26 | 27 | [build-system] 28 | requires = ["poetry-core"] 29 | build-backend = "poetry.core.masonry.api" 30 | -------------------------------------------------------------------------------- /tests/convergence/test_dp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/convergence/test_dp.py -------------------------------------------------------------------------------- /tests/convergence/test_pp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/convergence/test_pp.py -------------------------------------------------------------------------------- /tests/convergence/test_tp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/convergence/test_tp.py -------------------------------------------------------------------------------- /tests/core/bucket/test_bucket.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from pipegoose.core.bucket.bucket import Bucket 5 | from pipegoose.core.bucket.exception import BucketClosedError, BucketFullError 6 | 7 | 8 | def test_add_a_tensor_to_bucket(): 9 | BUCKET_SIZE = 1024 10 | DTYPE = torch.float32 11 | 12 | tensor = torch.randn(2, 4, dtype=DTYPE) 13 | TENSOR_STORAGE = tensor.storage() 14 | 15 | bucket = Bucket(BUCKET_SIZE, DTYPE) 16 | 17 | assert bucket.size == BUCKET_SIZE 18 | assert bucket.dtype == DTYPE 19 | assert bucket.available_size == BUCKET_SIZE 20 | assert len(bucket) == 0 21 | assert bucket.is_full is False 22 | 23 | new_tensor = bucket.add_tensor(tensor) 24 | 25 | assert isinstance(new_tensor, torch.Tensor) 26 | assert torch.equal(new_tensor, tensor) 27 | assert bucket.available_size == BUCKET_SIZE - new_tensor.numel() 28 | assert len(bucket) == 1 29 | # NOTE: the new tensor should be stored in the same storage as the bucket 30 | assert new_tensor.storage().data_ptr() == 
bucket.storage().data_ptr() 31 | # NOTE: the new tensor should have a different storage from the original tensor 32 | # since it's stored in the bucket 33 | assert new_tensor.storage().data_ptr() != TENSOR_STORAGE.data_ptr() 34 | 35 | 36 | def test_add_tensor_that_larger_than_bucket_size(): 37 | BUCKET_SIZE = 1024 38 | DTYPE = torch.float32 39 | tensor = torch.randn(2, BUCKET_SIZE, dtype=DTYPE) 40 | 41 | bucket = Bucket(BUCKET_SIZE, DTYPE) 42 | 43 | with pytest.raises(Exception): 44 | bucket.add_tensor(tensor) 45 | 46 | 47 | def test_add_tensor_that_larger_than_available_space(): 48 | BUCKET_SIZE = 1024 49 | DTYPE = torch.float32 50 | tensor = torch.randn(BUCKET_SIZE - 1) 51 | redundant_tensor = torch.randn(BUCKET_SIZE, dtype=DTYPE) 52 | 53 | bucket = Bucket(BUCKET_SIZE, DTYPE) 54 | 55 | bucket.add_tensor(tensor) 56 | 57 | with pytest.raises(BucketFullError): 58 | bucket.add_tensor(redundant_tensor) 59 | 60 | 61 | def test_add_a_tensor_to_a_closed_bucket(): 62 | BUCKET_SIZE = 1024 63 | DTYPE = torch.float32 64 | tensor = torch.randn(100) 65 | 66 | bucket = Bucket(BUCKET_SIZE, DTYPE) 67 | assert bucket.is_closed is False 68 | 69 | bucket.close() 70 | 71 | with pytest.raises(BucketClosedError): 72 | bucket.add_tensor(tensor) 73 | 74 | assert bucket.is_closed is True 75 | 76 | 77 | def test_add_a_tensor_with_different_dtype_to_a_bucket(): 78 | BUCKET_SIZE = 1024 79 | DTYPE = torch.float32 80 | tensor = torch.randn(10, dtype=torch.float16) 81 | 82 | bucket = Bucket(BUCKET_SIZE, DTYPE) 83 | 84 | with pytest.raises(Exception): 85 | bucket.add_tensor(tensor) 86 | 87 | 88 | def test_flush_all_tensors_in_bucket(): 89 | BUCKET_SIZE = 1024 90 | DTYPE = torch.float32 91 | x1 = torch.randn(10, dtype=DTYPE) 92 | x2 = torch.randn(20, dtype=DTYPE) 93 | 94 | bucket = Bucket(BUCKET_SIZE, DTYPE) 95 | bucket.add_tensor(x1) 96 | bucket.add_tensor(x2) 97 | bucket.clear() 98 | 99 | assert bucket.available_size == BUCKET_SIZE 100 | assert len(bucket) == 0 101 | # NOTE: how to test whether the bucket storage is deleted? 
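# One way to check this (a sketch; it assumes Bucket.clear() re-allocates or frees its internal
# buffer, which the Bucket API does not explicitly promise):
#   old_ptr = bucket.storage().data_ptr()   # captured before calling bucket.clear()
#   bucket.clear()
#   assert bucket.storage().data_ptr() != old_ptr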
102 | # assert get_tensor_storage_mem_loc(x1) != bucket.storage().data_ptr() 103 | # assert get_tensor_storage_mem_loc(x2) != bucket.storage().data_ptr() 104 | 105 | 106 | def test_delete_bucket_memory_storage(): 107 | pass 108 | -------------------------------------------------------------------------------- /tests/core/bucket/test_bucket_distributor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.distributed as dist 4 | 5 | from pipegoose.core.bucket.dist import BucketDistributor 6 | from pipegoose.core.bucket.utils import mb_size_to_num_elements 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | from pipegoose.testing.utils import init_parallel_context, spawn 9 | 10 | 11 | def run_execute_a_tensor_that_larger_than_bucket_size( 12 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 13 | ): 14 | # NOTE: append a tensor that is larger than the bucket size 15 | # and then execute the operation immediately 16 | PARALLEL_MODE = ParallelMode.DATA 17 | DTYPE = torch.float32 18 | BUCKET_SIZE_MB = 0.001 19 | NUM_ELEMENTS_IN_BUCKET = mb_size_to_num_elements(BUCKET_SIZE_MB, DTYPE) 20 | EXPECTED_OUTPUT = torch.arange(NUM_ELEMENTS_IN_BUCKET * 2).sum() * data_parallel_size 21 | 22 | parallel_context = init_parallel_context( 23 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 24 | ) 25 | 26 | tensor = torch.arange(2 * NUM_ELEMENTS_IN_BUCKET, dtype=DTYPE) 27 | bucket_distributor = BucketDistributor(dist.all_reduce, BUCKET_SIZE_MB, parallel_context) 28 | bucket_distributor.execute(tensor, PARALLEL_MODE) 29 | 30 | output = tensor.sum() 31 | assert torch.equal(output, EXPECTED_OUTPUT) 32 | 33 | 34 | @pytest.mark.parametrize("data_parallel_size", [1, 2]) 35 | def test_execute_a_tensor_that_larger_than_bucket_size(data_parallel_size): 36 | TENSOR_PARALLEL_SIZE = 1 37 | PIPELINE_PARALLEL_SIZE = 1 38 | WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * data_parallel_size 39 | 40 | spawn( 41 | run_execute_a_tensor_that_larger_than_bucket_size, 42 | world_size=WORLD_SIZE, 43 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 44 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 45 | data_parallel_size=data_parallel_size, 46 | ) 47 | -------------------------------------------------------------------------------- /tests/core/bucket/test_bucket_manager.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/core/bucket/test_bucket_manager.py -------------------------------------------------------------------------------- /tests/core/bucket/test_bucket_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from pipegoose.core.bucket.utils import mb_size_to_num_elements 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "mb_size, dtype, expected_num_elements", 9 | [ 10 | (10, torch.int8, 10485760), 11 | (20, torch.float32, 5242880), 12 | (40, torch.float16, 20971520), 13 | ], 14 | ) 15 | def test_mb_size_to_num_elements(mb_size, dtype, expected_num_elements): 16 | num_elements = mb_size_to_num_elements(mb_size, dtype) 17 | assert num_elements == expected_num_elements 18 | -------------------------------------------------------------------------------- /tests/distributed/_initializers/test_initialize_data_parallel_group.py:
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch.distributed as dist 3 | from utils import map_rank_to_group 4 | 5 | from pipegoose.distributed._initializers.initialize_data import ( 6 | DataParallelGroupInitializer, 7 | ) 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.testing.utils import spawn 10 | 11 | GROUPS_IN_WORLD_SIZE_1 = [0] 12 | GROUPS_IN_WORLD_SIZE_8 = [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]] 13 | 14 | 15 | def init_tensor_parallel_group( 16 | rank, world_size, host, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups 17 | ): 18 | init_method = f"tcp://{host}:{port}" 19 | expected_ranks = map_rank_to_group(rank, groups) 20 | 21 | dist.init_process_group( 22 | rank=rank, 23 | world_size=world_size, 24 | backend="gloo", 25 | init_method=init_method, 26 | ) 27 | 28 | result = DataParallelGroupInitializer( 29 | rank, 30 | world_size, 31 | tensor_parallel_size=tensor_parallel_size, 32 | pipeline_parallel_size=pipeline_parallel_size, 33 | data_parallel_size=data_parallel_size, 34 | ).init_dist_group() 35 | 36 | assert 0 <= result["local_rank"] < result["local_world_size"] 37 | assert result["local_rank"] < tensor_parallel_size 38 | 39 | assert result["local_world_size"] == tensor_parallel_size 40 | 41 | assert isinstance(result["process_group"], dist.ProcessGroup) 42 | 43 | assert result["ranks_in_group"] == expected_ranks 44 | assert dist.get_process_group_ranks(result["process_group"]) == expected_ranks 45 | assert result["parallel_mode"] == ParallelMode.DATA 46 | 47 | dist.barrier() 48 | dist.destroy_process_group(result["process_group"]) 49 | dist.barrier() 50 | dist.destroy_process_group() 51 | 52 | 53 | @pytest.mark.parametrize( 54 | "world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups", 55 | [(1, 1, 1, 1, GROUPS_IN_WORLD_SIZE_1), (8, 2, 2, 2, GROUPS_IN_WORLD_SIZE_8)], 56 | ) 57 | def test_init_tensor_parallel_group(world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups): 58 | spawn( 59 | init_tensor_parallel_group, 60 | world_size=world_size, 61 | host="localhost", 62 | tensor_parallel_size=tensor_parallel_size, 63 | pipeline_parallel_size=pipeline_parallel_size, 64 | data_parallel_size=data_parallel_size, 65 | groups=groups, 66 | ) 67 | -------------------------------------------------------------------------------- /tests/distributed/_initializers/test_initialize_expert_parallel_group.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch.distributed as dist 3 | from utils import map_rank_to_group 4 | 5 | from pipegoose.distributed._initializers.initialize_expert import ( 6 | ExpertDataParallelGroupInitializer, 7 | ) 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.testing.utils import spawn 10 | 11 | GROUPS_IN_WORLD_SIZE_1 = [0] 12 | GROUPS_IN_WORLD_SIZE_8 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] 13 | 14 | 15 | def init_tensor_parallel_group( 16 | rank, world_size, host, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups 17 | ): 18 | init_method = f"tcp://{host}:{port}" 19 | expected_ranks = map_rank_to_group(rank, groups) 20 | 21 | dist.init_process_group( 22 | rank=rank, 23 | world_size=world_size, 24 | backend="gloo", 25 | init_method=init_method, 26 | ) 27 | 28 | result = 
ExpertDataParallelGroupInitializer( 29 | rank, 30 | world_size, 31 | tensor_parallel_size=tensor_parallel_size, 32 | pipeline_parallel_size=pipeline_parallel_size, 33 | data_parallel_size=data_parallel_size, 34 | ).init_dist_group() 35 | 36 | assert 0 <= result["local_rank"] < result["local_world_size"] 37 | assert result["local_rank"] < tensor_parallel_size 38 | 39 | assert result["local_world_size"] == tensor_parallel_size 40 | 41 | assert isinstance(result["process_group"], dist.ProcessGroup) 42 | 43 | assert result["ranks_in_group"] == expected_ranks 44 | assert dist.get_process_group_ranks(result["process_group"]) == expected_ranks 45 | 46 | assert result["parallel_mode"] == ParallelMode.EXPERT_DATA 47 | 48 | dist.barrier() 49 | dist.destroy_process_group(result["process_group"]) 50 | dist.barrier() 51 | dist.destroy_process_group() 52 | 53 | 54 | @pytest.mark.parametrize( 55 | "world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups", 56 | [(1, 1, 1, 1, GROUPS_IN_WORLD_SIZE_1), (8, 2, 2, 2, GROUPS_IN_WORLD_SIZE_8)], 57 | ) 58 | def test_init_tensor_parallel_group(world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups): 59 | spawn( 60 | init_tensor_parallel_group, 61 | world_size=world_size, 62 | host="localhost", 63 | tensor_parallel_size=tensor_parallel_size, 64 | pipeline_parallel_size=pipeline_parallel_size, 65 | data_parallel_size=data_parallel_size, 66 | groups=groups, 67 | ) 68 | -------------------------------------------------------------------------------- /tests/distributed/_initializers/test_initialize_pipeline_parallel_group.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch.distributed as dist 3 | from utils import map_rank_to_group 4 | 5 | from pipegoose.distributed._initializers.initialize_pipeline import ( 6 | PipelineParallelGroupInitializer, 7 | ) 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.testing.utils import spawn 10 | 11 | GROUPS_IN_WORLD_SIZE_1 = [0] 12 | # TODO: is this correct? 
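# If ranks are laid out with tensor parallelism innermost (stride 1), then data parallelism
# (stride tensor_parallel_size = 2), then pipeline parallelism outermost (stride
# tensor_parallel_size * data_parallel_size = 4) -- the layout implied by the tensor- and
# data-parallel group tests above -- each pipeline group should be {r, r + 4} for r in 0..3,
# which is exactly the list below.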
13 | GROUPS_IN_WORLD_SIZE_8 = [[0, 4], [1, 5], [2, 6], [3, 7]] 14 | 15 | 16 | def init_tensor_parallel_group( 17 | rank, world_size, host, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups 18 | ): 19 | init_method = f"tcp://{host}:{port}" 20 | expected_ranks = map_rank_to_group(rank, groups) 21 | 22 | dist.init_process_group( 23 | rank=rank, 24 | world_size=world_size, 25 | backend="gloo", 26 | init_method=init_method, 27 | ) 28 | 29 | result = PipelineParallelGroupInitializer( 30 | rank, 31 | world_size, 32 | tensor_parallel_size=tensor_parallel_size, 33 | pipeline_parallel_size=pipeline_parallel_size, 34 | data_parallel_size=data_parallel_size, 35 | ).init_dist_group() 36 | 37 | assert 0 <= result["local_rank"] < result["local_world_size"] 38 | assert result["local_rank"] < tensor_parallel_size 39 | 40 | assert result["local_world_size"] == tensor_parallel_size 41 | 42 | assert isinstance(result["process_group"], dist.ProcessGroup) 43 | 44 | assert result["ranks_in_group"] == expected_ranks 45 | assert dist.get_process_group_ranks(result["process_group"]) == expected_ranks 46 | 47 | assert result["parallel_mode"] == ParallelMode.PIPELINE 48 | 49 | dist.barrier() 50 | dist.destroy_process_group(result["process_group"]) 51 | dist.barrier() 52 | dist.destroy_process_group() 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups", 57 | [(1, 1, 1, 1, GROUPS_IN_WORLD_SIZE_1), (8, 2, 2, 2, GROUPS_IN_WORLD_SIZE_8)], 58 | ) 59 | def test_init_tensor_parallel_group(world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups): 60 | spawn( 61 | init_tensor_parallel_group, 62 | world_size=world_size, 63 | host="localhost", 64 | tensor_parallel_size=tensor_parallel_size, 65 | pipeline_parallel_size=pipeline_parallel_size, 66 | data_parallel_size=data_parallel_size, 67 | groups=groups, 68 | ) 69 | -------------------------------------------------------------------------------- /tests/distributed/_initializers/test_initialize_tensor_parallel_group.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch.distributed as dist 3 | from utils import map_rank_to_group 4 | 5 | from pipegoose.distributed._initializers.initialize_tensor import ( 6 | TensorParallelGroupInitializer, 7 | ) 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.testing.utils import spawn 10 | 11 | GROUPS_IN_WORLD_SIZE_1 = [0] 12 | GROUPS_IN_WORLD_SIZE_8 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] 13 | 14 | 15 | def init_tensor_parallel_group( 16 | rank, world_size, host, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups 17 | ): 18 | init_method = f"tcp://{host}:{port}" 19 | expected_ranks = map_rank_to_group(rank, groups) 20 | 21 | dist.init_process_group( 22 | rank=rank, 23 | world_size=world_size, 24 | backend="gloo", 25 | init_method=init_method, 26 | ) 27 | 28 | result = TensorParallelGroupInitializer( 29 | rank, 30 | world_size, 31 | tensor_parallel_size=tensor_parallel_size, 32 | pipeline_parallel_size=pipeline_parallel_size, 33 | data_parallel_size=data_parallel_size, 34 | ).init_dist_group() 35 | 36 | assert 0 <= result["local_rank"] < result["local_world_size"] 37 | assert result["local_rank"] < tensor_parallel_size 38 | 39 | assert result["local_world_size"] == tensor_parallel_size 40 | 41 | assert isinstance(result["process_group"], 
dist.ProcessGroup) 42 | 43 | assert result["ranks_in_group"] == expected_ranks 44 | assert dist.get_process_group_ranks(result["process_group"]) == expected_ranks 45 | 46 | assert result["parallel_mode"] == ParallelMode.TENSOR 47 | 48 | dist.barrier() 49 | dist.destroy_process_group(result["process_group"]) 50 | dist.barrier() 51 | dist.destroy_process_group() 52 | 53 | 54 | @pytest.mark.parametrize( 55 | "world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups", 56 | [(1, 1, 1, 1, GROUPS_IN_WORLD_SIZE_1), (8, 2, 2, 2, GROUPS_IN_WORLD_SIZE_8)], 57 | ) 58 | def test_init_tensor_parallel_group(world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, groups): 59 | spawn( 60 | init_tensor_parallel_group, 61 | world_size=world_size, 62 | host="localhost", 63 | tensor_parallel_size=tensor_parallel_size, 64 | pipeline_parallel_size=pipeline_parallel_size, 65 | data_parallel_size=data_parallel_size, 66 | groups=groups, 67 | ) 68 | -------------------------------------------------------------------------------- /tests/distributed/_initializers/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def map_rank_to_group(rank: int, groups: List[int]) -> List[int]: 5 | if len(groups) == 1: 6 | return groups 7 | else: 8 | rank_to_group = {r: g for g in groups for r in g} 9 | return rank_to_group[rank] 10 | -------------------------------------------------------------------------------- /tests/distributed/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pipegoose.distributed.parallel_context import ParallelContext 4 | 5 | 6 | @pytest.fixture(scope="session") 7 | def parallel_context(): 8 | TENSOR_PARALLEL_SIZE = 1 9 | PIPELINE_PARALLEL_SIZE = 1 10 | DATA_PARALLEL_SIZE = 1 11 | SEED = 69 12 | RANK = 0 13 | WORLD_SIZE = 1 14 | HOST = "localhost" 15 | PORT = 12355 16 | 17 | parallel_context = ParallelContext( 18 | rank=RANK, 19 | local_rank=RANK, 20 | world_size=WORLD_SIZE, 21 | local_world_size=WORLD_SIZE, 22 | host=HOST, 23 | port=PORT, 24 | backend="gloo", 25 | seed=SEED, 26 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 27 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 28 | data_parallel_size=DATA_PARALLEL_SIZE, 29 | ) 30 | 31 | return parallel_context 32 | -------------------------------------------------------------------------------- /tests/distributed/test_p2p.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pipegoose.distributed.functional import recv, send 4 | from pipegoose.distributed.parallel_mode import ParallelMode 5 | from pipegoose.testing.utils import init_parallel_context, spawn 6 | 7 | 8 | def run_p2p(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 9 | parallel_context = init_parallel_context( 10 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 11 | ) 12 | rank = parallel_context.get_local_rank(parallel_mode=ParallelMode.PIPELINE) 13 | 14 | data = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, dtype=torch.float) 15 | send(data, src=0, dst=1, parallel_context=parallel_context) 16 | 17 | received_data = recv(src=0, dst=1, parallel_context=parallel_context) 18 | 19 | if rank == 1: 20 | assert torch.allclose(data, received_data) 21 | assert received_data.requires_grad == data.requires_grad 22 | assert received_data.dtype == data.dtype 23 | 
else: 24 | assert received_data is None 25 | 26 | 27 | def test_send_recv_p2p(): 28 | TENSOR_PARALLEL_SIZE = 1 29 | PIPELINE_PARALLEL_SIZE = 2 30 | DATA_PARALLEL_SIZE = 1 31 | 32 | WORLD_SIZE = TENSOR_PARALLEL_SIZE * DATA_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE 33 | 34 | spawn( 35 | run_p2p, 36 | world_size=WORLD_SIZE, 37 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 38 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 39 | data_parallel_size=DATA_PARALLEL_SIZE, 40 | ) 41 | -------------------------------------------------------------------------------- /tests/distributed/test_parallel_mode.py: -------------------------------------------------------------------------------- 1 | from pipegoose.distributed.parallel_mode import ParallelMode 2 | 3 | 4 | def test_parallel_mode(): 5 | assert hasattr(ParallelMode, "GLOBAL") 6 | assert hasattr(ParallelMode, "TENSOR") 7 | assert hasattr(ParallelMode, "PIPELINE") 8 | assert hasattr(ParallelMode, "DATA") 9 | assert hasattr(ParallelMode, "EXPERT_DATA") 10 | 11 | assert ParallelMode.GLOBAL == ParallelMode.GLOBAL 12 | assert ParallelMode.GLOBAL != ParallelMode.TENSOR 13 | -------------------------------------------------------------------------------- /tests/distributed/test_rpc.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | import torch 5 | import torch.distributed.rpc as rpc 6 | 7 | from pipegoose.distributed.parallel_context import ParallelContext 8 | from pipegoose.testing.utils import spawn 9 | 10 | skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") 11 | 12 | backend = ["gloo", pytest.param("nccl", marks=skip_if_no_cuda)] 13 | 14 | 15 | RPC_RECEIVE_QUEUE = list() 16 | 17 | 18 | def recv_rpc_call(value): 19 | tensor = torch.Tensor(value) 20 | RPC_RECEIVE_QUEUE.append(tensor) 21 | 22 | 23 | def run_send_rcv_rpc( 24 | rank, world_size, seed, backend, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, rpc_type 25 | ): 26 | VALUE = 69 27 | 28 | RPC_TYPE_TO_FUNC = {"rpc_sync": rpc.rpc_sync, "rpc_async": rpc.rpc_async} 29 | rpc_func = RPC_TYPE_TO_FUNC[rpc_type] 30 | 31 | parallel_context = ParallelContext( 32 | rank=rank, 33 | local_rank=rank, 34 | world_size=world_size, 35 | local_world_size=world_size, 36 | host="localhost", 37 | port=port, 38 | seed=seed, 39 | backend=backend, 40 | tensor_parallel_size=tensor_parallel_size, 41 | pipeline_parallel_size=pipeline_parallel_size, 42 | data_parallel_size=data_parallel_size, 43 | ) 44 | 45 | assert isinstance(parallel_context.get_worker_name(rank), str) 46 | 47 | if world_size > 1: 48 | assert rpc._is_current_rpc_agent_set() is True 49 | 50 | if rank == 0: 51 | tensor = torch.tensor(VALUE) 52 | 53 | fut = rpc_func(to=parallel_context.get_worker_name(rank=1), func=recv_rpc_call, args=(tensor,)) 54 | 55 | if rpc_func == rpc.rpc_async: 56 | fut.wait() 57 | 58 | else: 59 | while len(RPC_RECEIVE_QUEUE) < 1: 60 | time.sleep(0.1) 61 | 62 | tensor = RPC_RECEIVE_QUEUE.pop() 63 | 64 | assert tensor == VALUE 65 | 66 | parallel_context.destroy() 67 | 68 | if world_size > 1: 69 | assert rpc._is_current_rpc_agent_set() is False 70 | 71 | 72 | @pytest.mark.parametrize("rpc_type", ["rpc_sync", "rpc_async"]) 73 | def test_send_rcv_rpc(rpc_type): 74 | TENSOR_PARALLEL_SIZE = 1 75 | PIPELINE_PARALLEL_SIZE = 2 76 | DATA_PARALLEL_SIZE = 1 77 | 78 | SEED = 69 79 | BACKEND = "gloo" 80 | 81 | spawn( 82 | run_send_rcv_rpc, 83 | world_size=PIPELINE_PARALLEL_SIZE, 84 | seed=SEED, 85 | 
backend=BACKEND, 86 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 87 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 88 | data_parallel_size=DATA_PARALLEL_SIZE, 89 | rpc_type=rpc_type, 90 | ) 91 | -------------------------------------------------------------------------------- /tests/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/nn/__init__.py -------------------------------------------------------------------------------- /tests/nn/data_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/nn/data_parallel/__init__.py -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_expert_context.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.expert_parallel.expert_context import ExpertContext 2 | 3 | 4 | def test_expert_context(): 5 | expert_context = ExpertContext.get_instance() 6 | 7 | expert_context.push_aux_loss(1.01) 8 | expert_context.push_z_loss(2.01) 9 | 10 | expert_context.push_aux_loss(1.02) 11 | expert_context.push_z_loss(2.02) 12 | 13 | # make sure that we have a singleton! 14 | expert_context = ExpertContext.get_instance() 15 | 16 | assert expert_context.pop_all_aux_loss() == [1.01, 1.02] 17 | assert expert_context.pop_all_aux_loss() == [] 18 | 19 | assert expert_context.pop_all_z_loss() == [2.01, 2.02] 20 | assert expert_context.pop_all_z_loss() == [] 21 | -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_expert_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from pipegoose.nn.expert_parallel import ExpertLoss 6 | from pipegoose.nn.expert_parallel.expert_context import ExpertContext 7 | 8 | 9 | def test_expert_loss(): 10 | torch.manual_seed(42) 11 | logits = torch.randn((10, 5)) 12 | gt = torch.randn((10, 5)) 13 | 14 | loss_func = nn.MSELoss() 15 | 16 | expert_loss = ExpertLoss(loss_func, aux_weight=0.1, z_weight=0.2) 17 | expert_context = ExpertContext.get_instance() 18 | 19 | assert expert_loss.aux_weight == 0.1 20 | assert expert_loss.z_weight == 0.2 21 | assert expert_loss.loss_func == loss_func 22 | assert expert_loss.aux_loss == [] 23 | assert expert_loss.z_loss == [] 24 | 25 | expert_context.push_aux_loss(1.01) 26 | expert_context.push_z_loss(2.01) 27 | 28 | expert_context.push_aux_loss(1.02) 29 | expert_context.push_z_loss(2.02) 30 | 31 | assert expert_loss.aux_loss == [1.01, 1.02] 32 | assert expert_loss.z_loss == [2.01, 2.02] 33 | 34 | expected_loss = F.mse_loss(logits, gt) + 0.1 * (1.01 + 1.02) + 0.2 * (2.01 + 2.02) 35 | loss = expert_loss(logits, gt) 36 | 37 | assert torch.allclose(loss, expected_loss) 38 | 39 | assert expert_context.aux_loss == [] 40 | assert expert_context.z_loss == [] 41 | -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_expert_parallel_mapping.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoModelForCausalLM 3 | 4 | from pipegoose.nn.expert_parallel.parallel_mapping import ExpertParallelMapping 5 | 6 | MODEL_NAME 
= "Muennighoff/bloom-tiny-random" 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def model(): 11 | return AutoModelForCausalLM.from_pretrained(MODEL_NAME) 12 | 13 | 14 | @pytest.mark.skip(reason="Not implemented yet.") 15 | def test_is_mlp_mapping(model): 16 | BLOOM_MLP_NAME = "transformer.h.{}.mlp" 17 | mappings = {} 18 | 19 | for name, _ in model.named_modules(): 20 | # if "transformer.h.0.mlp" in name: 21 | # assert 1 == 1 22 | 23 | mappings[name] = ExpertParallelMapping.is_mlp(name) 24 | 25 | for layer_idx in range(len(model.transformer.h)): 26 | assert mappings[BLOOM_MLP_NAME.format(layer_idx)] is True 27 | -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_expert_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pipegoose.nn.expert_parallel.utils import get_num_local_experts 4 | from pipegoose.testing.utils import init_parallel_context, spawn 5 | 6 | 7 | def run_get_num_local_experts( 8 | rank, 9 | world_size, 10 | port, 11 | tensor_parallel_size, 12 | pipeline_parallel_size, 13 | data_parallel_size, 14 | num_experts, 15 | ref_num_local_experts, 16 | ): 17 | parallel_context = init_parallel_context( 18 | rank, 19 | world_size, 20 | port, 21 | tensor_parallel_size, 22 | pipeline_parallel_size, 23 | data_parallel_size, 24 | ) 25 | 26 | num_local_experts = get_num_local_experts(num_experts, parallel_context) 27 | 28 | assert num_local_experts == ref_num_local_experts 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "tensor_parallel_size, num_experts, expected", 33 | [ 34 | (1, 16, 16), 35 | (2, 16, 8), 36 | (4, 16, 4), 37 | (8, 16, 2), 38 | ], 39 | ) 40 | def test_get_num_local_experts(tensor_parallel_size, num_experts, expected): 41 | DATA_PARALLEL_SIZE = 1 42 | PIPELINE_PARALLEL_SIZE = 1 43 | WORLD_SIZE = tensor_parallel_size * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE 44 | 45 | spawn( 46 | run_get_num_local_experts, 47 | world_size=WORLD_SIZE, 48 | tensor_parallel_size=tensor_parallel_size, 49 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 50 | data_parallel_size=DATA_PARALLEL_SIZE, 51 | num_experts=num_experts, 52 | ref_num_local_experts=expected, 53 | ) 54 | -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_experts.py: -------------------------------------------------------------------------------- 1 | def test_experts(): 2 | # NOTE: test memory difference 3 | pass 4 | -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_layers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from pipegoose.distributed.functional import all_reduce 6 | from pipegoose.distributed.parallel_mode import ParallelMode 7 | from pipegoose.nn.expert_parallel.layers import ExpertLayer 8 | from pipegoose.testing.utils import count_model_parameters, init_parallel_context, spawn 9 | from pipegoose.nn.expert_parallel.routers import RouterOutput 10 | 11 | 12 | class DummyRouter: 13 | def __init__(self, num_experts): 14 | self.num_experts = num_experts 15 | 16 | def __call__(self, inputs): 17 | n_tokens = inputs.shape[0] * inputs.shape[1] 18 | return RouterOutput( 19 | torch.randint(0, self.num_experts, (n_tokens,)), 20 | None, 21 | None, 22 | None 23 | ) 24 | 25 | 26 | def run_expert_layer( 27 | rank, 28 | world_size, 29 | port, 30 | tensor_parallel_size, 31 
| pipeline_parallel_size, 32 | data_parallel_size, 33 | inputs, 34 | num_experts, 35 | expert, 36 | router, 37 | enable_tensor_parallel, 38 | ): 39 | parallel_context = init_parallel_context( 40 | rank, 41 | world_size, 42 | port, 43 | tensor_parallel_size, 44 | pipeline_parallel_size, 45 | data_parallel_size, 46 | ) 47 | 48 | torch.manual_seed(42) 49 | torch.cuda.manual_seed(42) 50 | 51 | expert_layer = ExpertLayer( 52 | num_experts, 53 | expert, 54 | router, 55 | enable_tensor_parallel, 56 | parallel_context 57 | ) 58 | 59 | local_param_count = count_model_parameters(expert_layer) 60 | total_param_count = all_reduce( 61 | torch.tensor(local_param_count), parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR 62 | ) 63 | assert total_param_count == count_model_parameters(expert) * num_experts 64 | assert all(isinstance(x, type(expert)) for x in expert_layer.experts) 65 | 66 | outputs = expert_layer(inputs) 67 | 68 | assert outputs.shape == inputs.shape 69 | assert not (outputs == 0).all(dim=-1).any(), "There is at least one input embedding that doesn't go through any experts." 70 | 71 | 72 | @pytest.mark.parametrize("tensor_parallel_size, num_experts", [(1, 1), (2, 2), (2, 4), (8, 8)]) 73 | @pytest.mark.parametrize("enable_tensor_parallel", [False]) 74 | def test_expert_layer(tensor_parallel_size, num_experts, enable_tensor_parallel): 75 | PIPELINE_PARALLEL_SIZE = 1 76 | DATA_PARALLEL_SIZE = 1 77 | WORLD_SIZE = tensor_parallel_size * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE 78 | BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE = 5, 10, 64 79 | 80 | inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE) 81 | expert = nn.Sequential( 82 | nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE * 4), 83 | nn.ReLU(), 84 | nn.Linear(HIDDEN_SIZE * 4, HIDDEN_SIZE), 85 | ) 86 | router = DummyRouter(num_experts) 87 | 88 | spawn( 89 | run_expert_layer, 90 | world_size=WORLD_SIZE, 91 | tensor_parallel_size=tensor_parallel_size, 92 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 93 | data_parallel_size=DATA_PARALLEL_SIZE, 94 | inputs=inputs, 95 | num_experts=num_experts, 96 | expert=expert, 97 | router=router, 98 | enable_tensor_parallel=enable_tensor_parallel, 99 | ) 100 | -------------------------------------------------------------------------------- /tests/nn/expert_parallel/test_routers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from pipegoose.nn.expert_parallel import SwitchNoisePolicy, Top1Router, Top2Router 5 | 6 | 7 | def run_topk_router(router, batch_size, seq_len, d_model, num_experts, top_k): 8 | router.train() 9 | 10 | input = torch.randn(batch_size, seq_len, d_model, requires_grad=True) 11 | 12 | router_output = router(input) 13 | 14 | assert router_output.dispatching_order.shape == (batch_size * seq_len, num_experts) 15 | assert router_output.weight.shape == (batch_size * seq_len, num_experts) 16 | assert router_output.aux_loss.shape == () 17 | assert router_output.z_loss.shape == () 18 | 19 | total_tokens = batch_size * seq_len 20 | 21 | if hasattr(router, "_expert_capacity") and router.expert_capacity: 22 | expert_capacity = router._expert_capacity(total_tokens) 23 | 24 | for expert_id in range(num_experts): 25 | assert router_output.dispatching_order[..., expert_id].sum().item() < expert_capacity 26 | 27 | for token_id in range(total_tokens): 28 | assert router_output.dispatching_order[token_id, ...].sum().item() <= top_k 29 | 30 | else: 31 | for token_id in range(total_tokens): 32 | assert 
router_output.dispatching_order[token_id, ...].sum().item() == top_k 33 | 34 | # test backward pass 35 | 36 | target_weight = torch.randn_like(router_output.weight) # Random target for testing 37 | 38 | loss = router_output.aux_loss + router_output.z_loss 39 | loss += F.mse_loss(router_output.weight, target_weight) 40 | 41 | loss.backward() 42 | 43 | # check the gradients 44 | assert input.grad is not None, "Input gradient should not be None" 45 | assert not torch.all(input.grad == 0), "Input gradient should not be all zeros" 46 | for param in router.parameters(): 47 | assert param.grad is not None, "Parameter gradient should not be None" 48 | assert not torch.all(param.grad == 0), "Parameter gradient should not be all zeros" 49 | 50 | 51 | def test_top1_router(): 52 | NUM_EXPERTS = 5 53 | BATCH_SIZE, SEQ_LEN, D_MODEL = 5, 10, 64 54 | 55 | noise_policy = SwitchNoisePolicy() 56 | top1_router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL) 57 | 58 | run_topk_router(top1_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=1) 59 | 60 | 61 | def test_top1_router_with_expert_capacity(): 62 | NUM_EXPERTS = 5 63 | BATCH_SIZE, SEQ_LEN, D_MODEL = 5, 10, 64 64 | 65 | noise_policy = SwitchNoisePolicy() 66 | top1_router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL, expert_capacity=(1.0, 2.0)) 67 | 68 | run_topk_router(top1_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=1) 69 | 70 | 71 | def test_top2_router(): 72 | NUM_EXPERTS = 5 73 | BATCH_SIZE, SEQ_LEN, D_MODEL = 5, 10, 64 74 | 75 | noise_policy = SwitchNoisePolicy() 76 | top2_router = Top2Router(noise_policy, NUM_EXPERTS, D_MODEL) 77 | 78 | run_topk_router(top2_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=2) 79 | 80 | 81 | def test_top2_router_with_expert_capacity(): 82 | NUM_EXPERTS = 5 83 | BATCH_SIZE, SEQ_LEN, D_MODEL = 5, 10, 64 84 | 85 | noise_policy = SwitchNoisePolicy() 86 | top2_router = Top2Router(noise_policy, NUM_EXPERTS, D_MODEL, expert_capacity=(1.0, 2.0)) 87 | 88 | run_topk_router(top2_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=2) 89 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/conftest.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from pipegoose.nn.pipeline_parallel._job.creator import create_job 8 | from pipegoose.nn.pipeline_parallel._job.job_type import JobType 9 | from pipegoose.nn.pipeline_parallel._package import Metadata, Package, TrainingMetadata 10 | from pipegoose.testing.utils import init_parallel_context 11 | 12 | # NOTE: it should be compatible to perform 13 | # matrix multiplication with the job's function 14 | INPUT_SHAPE = ( 15 | 4, 16 | 2, 17 | ) 18 | LINEAR_SHAPE = ( 19 | 2, 20 | 4, 21 | ) 22 | 23 | 24 | @pytest.fixture 25 | def training_info(): 26 | return { 27 | "n_microbatches": 5, 28 | } 29 | 30 | 31 | @pytest.fixture(scope="session") 32 | def parallel_context(): 33 | TENSOR_PARALLEL_SIZE = 1 34 | PIPELINE_PARALLEL_SIZE = 1 35 | DATA_PARALLEL_SIZE = 1 36 | RANK = 0 37 | WORLD_SIZE = 1 38 | PORT = 12355 39 | 40 | parallel_context = init_parallel_context( 41 | rank=RANK, 42 | world_size=WORLD_SIZE, 43 | port=PORT, 44 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 45 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 46 | data_parallel_size=DATA_PARALLEL_SIZE, 47 | ) 48 | 49 | return parallel_context 50 | 51 | 52 | @pytest.fixture(scope="session") 53 | def 
pipeline_context(parallel_context): 54 | from pipegoose.nn.pipeline_parallel.pipeline_context import PipelineContext 55 | from pipegoose.nn.pipeline_parallel.scheduler import SchedulerType, get_scheduler 56 | 57 | # N_PARTITIONS = 3 58 | N_PARTITIONS = parallel_context.pipeline_parallel_size 59 | N_MICROBATCHES = 5 60 | 61 | scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS) 62 | pipeline_context = PipelineContext( 63 | scheduler=scheduler, 64 | parallel_context=parallel_context, 65 | ) 66 | 67 | return pipeline_context 68 | 69 | 70 | @pytest.fixture 71 | def base_package(): 72 | MICROBATCH_IDX = 0 73 | PARTITION_IDX = 0 74 | IS_TRAINING = True 75 | IS_GRAD_ENABLED = True 76 | 77 | # NOTE: this is the package of an input microbatch 78 | SRC = 0 79 | DST = 1 80 | 81 | data = torch.randn(*INPUT_SHAPE, requires_grad=IS_GRAD_ENABLED) 82 | 83 | metadata = Metadata( 84 | microbatch_idx=MICROBATCH_IDX, 85 | partition_idx=PARTITION_IDX, 86 | job_type=JobType.FORWARD, 87 | training=TrainingMetadata( 88 | is_training=IS_TRAINING, 89 | is_grad_enabled=IS_GRAD_ENABLED, 90 | ), 91 | src=SRC, 92 | dst=DST, 93 | ) 94 | 95 | return Package(data, metadata) 96 | 97 | 98 | @pytest.fixture 99 | def forward_package(base_package): 100 | # NOTE: package for forward job 101 | base_package.metadata.job_type = JobType.FORWARD 102 | return base_package 103 | 104 | 105 | @pytest.fixture(scope="function") 106 | def backward_package(base_package): 107 | from pipegoose.nn.pipeline_parallel.queue import ( 108 | save_input_activations, 109 | save_output_activations, 110 | ) 111 | 112 | backward_package = deepcopy(base_package) 113 | backward_package.metadata.src = 2 114 | backward_package.metadata.dst = 1 115 | backward_package.metadata.job_type = JobType.BACKWARD 116 | 117 | MICROBATCH_IDX = backward_package.metadata.microbatch_idx 118 | PARTITION_IDX = backward_package.metadata.partition_idx 119 | 120 | input = torch.randn(*INPUT_SHAPE, requires_grad=True) 121 | save_input_activations(input, MICROBATCH_IDX, PARTITION_IDX) 122 | 123 | linear = nn.Linear(*LINEAR_SHAPE) 124 | output = linear(input) 125 | INITIAL_GRADS = torch.ones_like(output) 126 | 127 | # NOTE: stores the output activations that the backward job 128 | # will use to compute the gradients 129 | save_output_activations(output, MICROBATCH_IDX, PARTITION_IDX) 130 | 131 | backward_package.data = torch.ones_like(INITIAL_GRADS) 132 | 133 | return backward_package 134 | 135 | 136 | @pytest.fixture(scope="function") 137 | def backward_job(backward_package, parallel_context, pipeline_context): 138 | def function(): 139 | def backward_function(*args, **kwargs): 140 | return torch.randn(1) 141 | 142 | return backward_function 143 | 144 | job = create_job(function, backward_package, parallel_context, pipeline_context) 145 | return job 146 | 147 | 148 | @pytest.fixture 149 | def forward_function(): 150 | return nn.Linear(*LINEAR_SHAPE) 151 | 152 | 153 | @pytest.fixture(scope="function") 154 | def forward_job(forward_package, forward_function): 155 | from pipegoose.nn.pipeline_parallel._job.forward import ForwardJob 156 | 157 | # return create_job(function, forward_package, parallel_context, pipeline_context) 158 | return ForwardJob(forward_function, forward_package) 159 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/job/test_callback.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from 
pipegoose.nn.pipeline_parallel._job.callback import Callback 4 | from pipegoose.nn.pipeline_parallel._job.job import Job 5 | 6 | 7 | # NOTE: We don't want to rely on the behavior of other jobs, like forward job 8 | # and backward job. so we create a dummy job solely to test callbacks 9 | class DummyJob(Job): 10 | def run_compute(self): 11 | return self.function(self.input.data) 12 | 13 | 14 | @pytest.fixture 15 | def job(forward_package): 16 | def function(*args, **kwargs): 17 | pass 18 | 19 | return DummyJob(function, forward_package) 20 | 21 | 22 | @pytest.fixture(scope="function") 23 | def cbs(): 24 | class Callback1(Callback): 25 | pass 26 | 27 | class Callback2(Callback): 28 | pass 29 | 30 | class Callback3(Callback): 31 | pass 32 | 33 | return [Callback1(), Callback2, Callback3] 34 | 35 | 36 | def test_a_callback_access_job_attributes(job): 37 | QUEUE = [] 38 | 39 | class AccessJobAttributesCallback(Callback): 40 | def after_compute(self): 41 | QUEUE.append(self.job.key) 42 | 43 | job.add_cb(AccessJobAttributesCallback) 44 | job.compute() 45 | 46 | assert len(QUEUE) == 1 47 | assert QUEUE == [job.key] 48 | 49 | 50 | def test_run_callbacks_by_order(job): 51 | QUEUE = [] 52 | 53 | class Callback1(Callback): 54 | order = 0 55 | 56 | def after_compute(self): 57 | QUEUE.append(1) 58 | 59 | class Callback2(Callback): 60 | order = 1 61 | 62 | def after_compute(self): 63 | QUEUE.append(2) 64 | 65 | class Callback3(Callback): 66 | order = 2 67 | 68 | def after_compute(self): 69 | QUEUE.append(3) 70 | 71 | job.add_cbs([Callback3, Callback1, Callback2]) 72 | job.compute() 73 | 74 | assert QUEUE == [1, 2, 3] 75 | 76 | 77 | def test_create_and_run_a_callback(job): 78 | QUEUE = [] 79 | 80 | class AddToQueueCallback(Callback): 81 | def after_compute(self): 82 | QUEUE.append(69) 83 | 84 | cb = AddToQueueCallback() 85 | 86 | assert isinstance(cb.order, int) 87 | 88 | job.add_cb(cb) 89 | job.compute() 90 | 91 | assert QUEUE == [69] 92 | 93 | 94 | def test_add_and_remove_a_callback(job): 95 | class ToyCallback(Callback): 96 | pass 97 | 98 | N_ORIG_CBS = len(job.cbs) 99 | cb = ToyCallback() 100 | 101 | job.add_cb(cb) 102 | assert len(job.cbs) == 1 + N_ORIG_CBS 103 | 104 | job.remove_cb(cb) 105 | assert len(job.cbs) == N_ORIG_CBS 106 | 107 | 108 | def test_add_and_remove_a_list_of_callback(job, cbs): 109 | N_ORIG_CBS = len(job.cbs) 110 | 111 | job.add_cbs(cbs) 112 | assert len(job.cbs) == 3 + N_ORIG_CBS 113 | 114 | job.remove_cbs(cbs) 115 | assert len(job.cbs) == N_ORIG_CBS 116 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/job/test_creator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch.distributed as dist 3 | from torch import nn 4 | 5 | from pipegoose.distributed.parallel_mode import ParallelMode 6 | from pipegoose.nn.pipeline_parallel._job.backward import BackwardJob 7 | from pipegoose.nn.pipeline_parallel._job.creator import create_job 8 | from pipegoose.nn.pipeline_parallel._job.forward import ForwardJob 9 | from pipegoose.nn.pipeline_parallel._job.job import JobStatus 10 | from pipegoose.nn.pipeline_parallel.sync.handshake import ProgressTracker 11 | from pipegoose.nn.pipeline_parallel.sync.progress_tracker import ( 12 | get_progresses_from_pipeline_context, 13 | ) 14 | from pipegoose.testing.utils import init_pipeline_context, spawn 15 | 16 | # NOTE: use for creating a forward job 17 | function = nn.Linear(2, 4) 18 | 19 | 20 | # 
@pytest.mark.parametrize("package", ["forward_package", "backward_package"]) 21 | @pytest.mark.skip 22 | def test_backward_job(backward_package, parallel_context, pipeline_context): 23 | # package = request.getfixturevalue(package) 24 | job = create_job(function, backward_package, parallel_context, pipeline_context) 25 | 26 | job.compute() 27 | 28 | assert job.status == JobStatus.EXECUTED 29 | 30 | 31 | def run_create_a_job_from_package( 32 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, package, job_cls 33 | ): 34 | MASTER_RANK = 0 35 | pipeline_context, parallel_context = init_pipeline_context( 36 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 37 | ) 38 | pipeline_context.forward() 39 | tracker = ProgressTracker(MASTER_RANK, parallel_context=parallel_context, parallel_mode=ParallelMode.GLOBAL) 40 | progresses = get_progresses_from_pipeline_context(pipeline_context) 41 | tracker.initiate(progresses) 42 | 43 | dist.barrier() 44 | 45 | job = create_job(function, package, parallel_context, pipeline_context) 46 | 47 | assert isinstance(job, job_cls) 48 | assert isinstance(job.key, str) 49 | assert callable(job.function) is True 50 | assert job.status == JobStatus.PENDING 51 | 52 | job.compute() 53 | 54 | assert job.status == JobStatus.EXECUTED 55 | 56 | 57 | @pytest.mark.skip 58 | @pytest.mark.parametrize("package, job_cls", [("forward_package", ForwardJob), ("backward_package", BackwardJob)]) 59 | def test_create_a_job_from_package(request, package, job_cls): 60 | TENSOR_PARALLEL_SIZE = 1 61 | PIPELINE_PARALLEL_SIZE = 2 62 | DATA_PARALLEL_SIZE = 1 63 | WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE 64 | 65 | package = request.getfixturevalue(package) 66 | 67 | spawn( 68 | run_create_a_job_from_package, 69 | world_size=WORLD_SIZE, 70 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 71 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 72 | data_parallel_size=DATA_PARALLEL_SIZE, 73 | package=package, 74 | job_cls=job_cls, 75 | ) 76 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/job/test_hybrid_job.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/nn/pipeline_parallel/job/test_hybrid_job.py -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/job/test_job.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xrsrke/pipegoose/fe6bcfc2ad4d592fcb11beda41481d9ce8cfc28c/tests/nn/pipeline_parallel/job/test_job.py -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/job/test_register.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | 3 | from pipegoose.nn.pipeline_parallel._job.job import Job 4 | from pipegoose.nn.pipeline_parallel._job.register import add_job_to_queue 5 | 6 | 7 | class Dummyjob(Job): 8 | def run_compute(self): 9 | pass 10 | 11 | 12 | def test_register_job(): 13 | input = 1 14 | 15 | def function(input): 16 | return input 17 | 18 | job = Dummyjob(function, input) 19 | JOB_QUEUE = Queue() 20 | 21 | add_job_to_queue(job, JOB_QUEUE) 22 | 23 | assert JOB_QUEUE.qsize() == 1 24 | 
-------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/sync/test_handshake.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import pytest 5 | import torch 6 | 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | from pipegoose.nn.pipeline_parallel.sync.handshake import ParallelGroupHandshake 9 | from pipegoose.testing.utils import init_parallel_context, spawn 10 | 11 | 12 | def run_parallel_group_handshake( 13 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, parallel_mode, shared_counter 14 | ): 15 | def do_random_delay(): 16 | rand_time = random.uniform(0, 3) 17 | time.sleep(rand_time) 18 | 19 | parallel_context = init_parallel_context( 20 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 21 | ) 22 | 23 | # NOTE: simulate some random delay in different ranks 24 | # before the handshake 25 | do_random_delay() 26 | 27 | handshake = ParallelGroupHandshake( 28 | parallel_context, 29 | parallel_mode=parallel_mode, 30 | ) 31 | handshake.initiate() 32 | 33 | do_random_delay() 34 | handshake.confirm() 35 | shared_counter.add_(rank) 36 | 37 | handshake.barrier() 38 | 39 | # NOTE: since each process adds its rank to the shared counter, the final value should equal the sum of all ranks in the group, i.e. sum(range(local_world_size)) 40 | local_world_size = parallel_context.get_world_size(parallel_mode) 41 | assert shared_counter.item() == sum(x for x in range(local_world_size)) 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "parallel_mode, tensor_parallel_size, pipeline_parallel_size, data_parallel_size", 46 | [ 47 | # (ParallelMode.GLOBAL, 2, 2, 2), 48 | # (ParallelMode.TENSOR, 2, 1, 1), 49 | (ParallelMode.PIPELINE, 1, 2, 1), 50 | # (ParallelMode.DATA, 1, 1, 2), 51 | ], 52 | ) 53 | def test_parallel_group_handshake(parallel_mode, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 54 | WORLD_SIZE = tensor_parallel_size * pipeline_parallel_size * data_parallel_size 55 | 56 | shared_counter = torch.tensor(0) 57 | shared_counter.share_memory_() 58 | 59 | spawn( 60 | run_parallel_group_handshake, 61 | world_size=WORLD_SIZE, 62 | tensor_parallel_size=tensor_parallel_size, 63 | pipeline_parallel_size=pipeline_parallel_size, 64 | data_parallel_size=data_parallel_size, 65 | parallel_mode=parallel_mode, 66 | shared_counter=shared_counter, 67 | ) 68 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_comm.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | 5 | from pipegoose.nn.pipeline_parallel._comm import RECV_QUEUE, send_package 6 | from pipegoose.nn.pipeline_parallel._package import Package 7 | from pipegoose.testing.utils import init_parallel_context, spawn 8 | 9 | 10 | def run_send_recv_package(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, package): 11 | PACKAGE_SRC_RANK = package.metadata.src 12 | PACKAGE_DST_RANK = package.metadata.dst 13 | 14 | # MICROBATCH_IDX = package.metadata.microbatch_idx 15 | # PARTITION_IDX = package.metadata.partition_idx 16 | 17 | parallel_context = init_parallel_context( 18 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 19 | ) 20 | 21 | if rank == PACKAGE_SRC_RANK: 22 | send_package(package, parallel_context=parallel_context) 23 | elif rank == PACKAGE_DST_RANK: 24 
| time.sleep(1) 25 | received_package = RECV_QUEUE.get() 26 | # received_package = RECV_QUEUE[(MICROBATCH_IDX, PARTITION_IDX)] 27 | 28 | assert isinstance(received_package, Package) 29 | 30 | 31 | @pytest.mark.parametrize("pipeline_parallel_size", [2]) 32 | def test_run_send_recv_package(forward_package, pipeline_parallel_size): 33 | TENSOR_PARALLEL_SIZE = 1 34 | DATA_PARALLEL_SIZE = 1 35 | 36 | spawn( 37 | run_send_recv_package, 38 | world_size=pipeline_parallel_size, 39 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 40 | pipeline_parallel_size=pipeline_parallel_size, 41 | data_parallel_size=DATA_PARALLEL_SIZE, 42 | package=forward_package, 43 | ) 44 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_microbatch.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from pipegoose.nn.pipeline_parallel import microbatch 4 | 5 | MODEL_NAME = "sshleifer/tiny-gpt2" 6 | 7 | 8 | def test_split_a_mini_batch_to_microbatches(): 9 | BATCH_SIZE = 36 10 | N_MICROBATCHES = 6 11 | 12 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 13 | tokenizer.pad_token = tokenizer.eos_token 14 | 15 | text = "Persistence is all you need." 16 | batch_sentences = [text for _ in range(BATCH_SIZE)] 17 | inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt") 18 | 19 | microbatches = microbatch.split(inputs, n_microbatches=N_MICROBATCHES) 20 | 21 | assert isinstance(microbatches, list) 22 | assert len(microbatches) == N_MICROBATCHES 23 | assert all(set(batch.keys()) == set(inputs.keys()) for batch in microbatches) is True 24 | 25 | total_sentences = sum(microbatch["input_ids"].size(0) for microbatch in microbatches) 26 | assert total_sentences == BATCH_SIZE 27 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_package.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from pipegoose.nn.pipeline_parallel._job.job_type import JobType 5 | from pipegoose.nn.pipeline_parallel._package import Metadata, Package, TrainingMetadata 6 | 7 | 8 | @pytest.mark.parametrize("job_type", [JobType.FORWARD, JobType.BACKWARD]) 9 | def test_package(job_type): 10 | MICROBATCH_IDX = 1 11 | PARTITION_IDX = 2 12 | IS_TRAINING = True 13 | IS_GRAD_ENABLED = False 14 | 15 | SRC = 0 16 | DST = 1 17 | 18 | data = torch.randn(2, 4) 19 | metadata = Metadata( 20 | microbatch_idx=MICROBATCH_IDX, 21 | partition_idx=PARTITION_IDX, 22 | job_type=job_type, 23 | training=TrainingMetadata( 24 | is_training=IS_TRAINING, 25 | is_grad_enabled=IS_GRAD_ENABLED, 26 | ), 27 | src=SRC, 28 | dst=DST, 29 | ) 30 | 31 | package = Package(data, metadata) 32 | 33 | assert package.metadata.microbatch_idx == MICROBATCH_IDX 34 | assert package.metadata.partition_idx == PARTITION_IDX 35 | 36 | assert package.metadata.job_type == job_type 37 | assert package.metadata.training.is_training == IS_TRAINING 38 | assert package.metadata.training.is_grad_enabled == IS_GRAD_ENABLED 39 | 40 | assert package.metadata.src == SRC 41 | assert package.metadata.dst == DST 42 | 43 | assert torch.equal(package.data, data) 44 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_partitioner.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers 
import ( 4 | AutoTokenizer, 5 | BloomConfig, 6 | BloomForCausalLM, 7 | GPT2Config, 8 | GPT2LMHeadModel, 9 | ) 10 | 11 | from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner 12 | from pipegoose.testing.utils import init_parallel_context, spawn 13 | 14 | 15 | def get_gpt2_and_tokenizer(): 16 | model = GPT2LMHeadModel(config=GPT2Config(n_layer=6)) 17 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 18 | return model, tokenizer 19 | 20 | 21 | def get_bloom_and_tokenizer_with_6_layers(): 22 | return BloomForCausalLM(BloomConfig(n_layer=6)), AutoTokenizer.from_pretrained("bigscience/bloom-560m") 23 | 24 | 25 | # TODO: Also add a function for a generic nn.Transformer model 26 | def run_model_partitioner( 27 | rank, 28 | world_size, 29 | port, 30 | tensor_parallel_size, 31 | pipeline_parallel_size, 32 | data_parallel_size, 33 | model_retrieval_func, 34 | ): 35 | parallel_context = init_parallel_context( 36 | rank, 37 | world_size, 38 | port, 39 | tensor_parallel_size, 40 | pipeline_parallel_size, 41 | data_parallel_size, 42 | ) 43 | 44 | torch.manual_seed(0) 45 | batch_sentences = ["hello world from pipegoose"] 46 | model, tokenizer = model_retrieval_func() 47 | model.eval() 48 | tokenizer.pad_token = tokenizer.eos_token 49 | inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt") 50 | gt_logits = model(**inputs).logits 51 | 52 | partitioned_model = UniformPartitioner(model, parallel_context).split() 53 | assert ( 54 | len(partitioned_model) == pipeline_parallel_size 55 | ), f"Received model with {len(partitioned_model)} instead of {pipeline_parallel_size}" 56 | 57 | print("Start printing partitioned model") 58 | for i, shard in enumerate(partitioned_model): 59 | shard_param_count = 0 60 | print("==================") 61 | print(f"Shard {i + 1}") 62 | for _, module in shard.named_children(): 63 | # Sum the parameters of each module in the shard 64 | shard_param_count += sum(p.numel() for p in module.parameters()) 65 | print(f"Layer type: {type(module).__name__}") 66 | print(module) 67 | print(f"Total parameters in Shard {i + 1}: {shard_param_count}") 68 | print("==================") 69 | print("End printing partitioned model") 70 | 71 | partitioned_model_result = inputs 72 | for partition_id in range(pipeline_parallel_size): 73 | if type(partitioned_model_result) in (list, tuple): 74 | partitioned_model_result = partitioned_model[partition_id](*partitioned_model_result) 75 | else: 76 | partitioned_model_result = partitioned_model[partition_id](**partitioned_model_result) 77 | 78 | assert torch.allclose(gt_logits, partitioned_model_result), "Results are not close" 79 | 80 | 81 | @pytest.mark.parametrize("pipeline_parallel_size", [2, 3, 4, 5, 6]) 82 | @pytest.mark.parametrize( 83 | "model_retrieval_func", 84 | [ 85 | get_gpt2_and_tokenizer, 86 | get_bloom_and_tokenizer_with_6_layers, 87 | ], 88 | ) 89 | def test_naive_partitioning(pipeline_parallel_size, model_retrieval_func): 90 | TENSOR_PARALLEL_SIZE = 1 91 | DATA_PARALLEL_SIZE = 1 92 | print( 93 | f"Running test with pipeline_parallel_size={pipeline_parallel_size}, tensor_parallel_size={TENSOR_PARALLEL_SIZE}, data_parallel_size={DATA_PARALLEL_SIZE}" 94 | ) 95 | spawn( 96 | run_model_partitioner, 97 | world_size=pipeline_parallel_size, 98 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 99 | pipeline_parallel_size=pipeline_parallel_size, 100 | data_parallel_size=DATA_PARALLEL_SIZE, 101 | model_retrieval_func=model_retrieval_func, 102 | ) 103 | 
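The partitioner test above hinges on a single invariant: running the pipeline shards back to back must reproduce the logits of the unpartitioned model. Below is a minimal standalone sketch of that invariant; it hand-splits a toy nn.Sequential instead of calling UniformPartitioner, so the stage boundaries and the names full_model, stage_0, stage_1 are illustrative assumptions rather than pipegoose APIs.

# Minimal sketch of the invariant asserted in run_model_partitioner: chaining
# the shards end to end must give the same output as the original, full model.
# Assumption: a toy nn.Sequential stands in for the transformer being split.
import torch
from torch import nn

torch.manual_seed(0)
full_model = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 4),
)

# Hand-picked "partitions": the first three modules form stage 0, the rest stage 1.
# A real partitioner (e.g. UniformPartitioner) chooses these boundaries itself.
stage_0 = nn.Sequential(*list(full_model.children())[:3])
stage_1 = nn.Sequential(*list(full_model.children())[3:])

x = torch.randn(2, 8)
ref_output = full_model(x)

# Run the shards sequentially, feeding each stage's output into the next stage,
# exactly as the loop over partition_id does in the test above.
hidden = x
for stage in (stage_0, stage_1):
    hidden = stage(hidden)

assert torch.allclose(ref_output, hidden), "chained shards must match the full model"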
-------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_pipeline_engine.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import reduce 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pipegoose.nn.pipeline_parallel._utils import get_partition_idx, is_last_stage 8 | from pipegoose.nn.pipeline_parallel._worker import WorkerManager 9 | from pipegoose.nn.pipeline_parallel.pipeline_engine import PipelineEngine 10 | from pipegoose.nn.pipeline_parallel.scheduler import GPipeScheduler 11 | from pipegoose.testing.utils import init_parallel_context, spawn 12 | 13 | 14 | def run_pipeline_engine( 15 | rank, 16 | world_size, 17 | port, 18 | tensor_parallel_size, 19 | pipeline_parallel_size, 20 | data_parallel_size, 21 | n_microbatches, 22 | model, 23 | inputs, 24 | ref_outputs, 25 | ref_grads, 26 | ): 27 | parallel_context = init_parallel_context( 28 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 29 | ) 30 | partition_idx = get_partition_idx(parallel_context) 31 | partition = model[partition_idx] 32 | scheduler = GPipeScheduler(n_microbatches, pipeline_parallel_size) 33 | worker_manager = WorkerManager() 34 | pipeline_engine = PipelineEngine( 35 | module=partition, 36 | scheduler=scheduler, 37 | worker_manager=worker_manager, 38 | parallel_context=parallel_context, 39 | ) 40 | outputs = pipeline_engine.run(inputs) 41 | 42 | if is_last_stage(parallel_context): 43 | assert torch.allclose(torch.cat(outputs, dim=0), ref_outputs) 44 | 45 | for output in outputs: 46 | output.sum().backward(retain_graph=True) 47 | 48 | for p, ref_grad in zip(partition.parameters(), ref_grads[partition_idx]): 49 | assert p.grad is not None 50 | assert torch.allclose(p.grad, ref_grad) 51 | 52 | 53 | def test_pipeline_engine(): 54 | TENSOR_PARALLEL_SIZE = 1 55 | PIPELINE_PARALLEL_SIZE = 4 56 | DATA_PARALLEL_SIZE = 1 57 | 58 | BATCH_SIZE = 32 59 | N_MICROBATCHES = 6 60 | SEQ_LEN = 10 61 | HIDDEN_DIM = 5 62 | WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE 63 | 64 | inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, requires_grad=False) 65 | model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(PIPELINE_PARALLEL_SIZE)]) 66 | ORIG_MODEL = deepcopy(model) 67 | outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs) 68 | 69 | outputs.sum().backward() 70 | 71 | grads = [[p.grad for p in layer.parameters()] for layer in model] 72 | 73 | spawn( 74 | run_pipeline_engine, 75 | world_size=WORLD_SIZE, 76 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 77 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 78 | data_parallel_size=DATA_PARALLEL_SIZE, 79 | n_microbatches=N_MICROBATCHES, 80 | model=ORIG_MODEL, 81 | inputs=inputs.detach(), 82 | ref_outputs=outputs.detach(), 83 | ref_grads=grads, 84 | ) 85 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_pp_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pipegoose.nn.pipeline_parallel._utils import get_partition_idx 4 | from pipegoose.testing.utils import init_parallel_context, spawn 5 | 6 | # NOTE: a mapping from global rank to partition index in pipeline parallelism 7 | # (tensor_parallel_size, pipeline_parallel_size, data_parallel_size) = {rank: partition_idx} 8 | 
RANK_TO_PARTITION_IDX = { 9 | (2, 4, 2): { 10 | 0: 0, 11 | 1: 0, 12 | 2: 0, 13 | 3: 0, 14 | 4: 1, 15 | 5: 1, 16 | 6: 1, 17 | 7: 1, 18 | 8: 2, 19 | 9: 2, 20 | 10: 2, 21 | 11: 2, 22 | 12: 3, 23 | 13: 3, 24 | 14: 3, 25 | 15: 3, 26 | }, 27 | (1, 4, 1): {0: 0, 1: 1, 2: 2, 3: 3}, 28 | } 29 | 30 | 31 | def run_get_partition_idx( 32 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, rank_to_partition_idx 33 | ): 34 | parallel_context = init_parallel_context( 35 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 36 | ) 37 | 38 | partition_idx = get_partition_idx(parallel_context) 39 | 40 | assert partition_idx == rank_to_partition_idx[rank] 41 | 42 | 43 | @pytest.mark.parametrize("tensor_parallel_size, pipeline_parallel_size, data_parallel_size", [(2, 4, 2), (1, 4, 1)]) 44 | def test_get_partition_idx(tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 45 | world_size = tensor_parallel_size * pipeline_parallel_size * data_parallel_size 46 | 47 | rank_to_partition_idx = RANK_TO_PARTITION_IDX[(tensor_parallel_size, pipeline_parallel_size, data_parallel_size)] 48 | 49 | spawn( 50 | run_get_partition_idx, 51 | world_size=world_size, 52 | tensor_parallel_size=tensor_parallel_size, 53 | pipeline_parallel_size=pipeline_parallel_size, 54 | data_parallel_size=data_parallel_size, 55 | rank_to_partition_idx=rank_to_partition_idx, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_queue.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from pipegoose.nn.pipeline_parallel.queue import ( 5 | SavedActivation, 6 | get_input_activations, 7 | get_output_activations, 8 | save_input_activations, 9 | save_output_activations, 10 | ) 11 | 12 | 13 | def test_save_and_retrieve_activations(): 14 | MICROBATCH_IDX = 1 15 | PARTITION_IDX = 0 16 | 17 | activations = torch.randn(2, 4) 18 | 19 | key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX) 20 | SavedActivation.save_activations(key, activations) 21 | 22 | saved_activations = SavedActivation.get_saved_activations(key) 23 | 24 | assert torch.equal(activations, saved_activations) 25 | 26 | 27 | def test_save_and_get_output_activations(): 28 | MICROBATCH_IDX = 1 29 | PARTITION_IDX = 0 30 | 31 | input = torch.randn(2, 4) 32 | linear = nn.Linear(4, 2) 33 | output = linear(input) 34 | 35 | save_output_activations(output, MICROBATCH_IDX, PARTITION_IDX) 36 | 37 | retrieved_output = get_output_activations(MICROBATCH_IDX, PARTITION_IDX) 38 | assert torch.equal(output, retrieved_output) 39 | # NOTE: for the pipeline engine to run the backward pass, the output should require grad 40 | assert retrieved_output.requires_grad is True 41 | 42 | 43 | def test_save_and_get_input_activations(): 44 | MICROBATCH_IDX = 1 45 | PARTITION_IDX = 0 46 | 47 | input = torch.randn(2, 4) 48 | save_input_activations(input, MICROBATCH_IDX, PARTITION_IDX) 49 | 50 | retrieved_input = get_input_activations(MICROBATCH_IDX, PARTITION_IDX) 51 | assert torch.equal(input, retrieved_input) 52 | # NOTE: for the pipeline engine to run the backward pass, the input should require grad 53 | assert retrieved_input.requires_grad is True 54 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_scheduler.py: -------------------------------------------------------------------------------- 1 | from
pipegoose.nn.pipeline_parallel._job.job_type import JobType 2 | from pipegoose.nn.pipeline_parallel.scheduler import GPipeScheduler 3 | 4 | N_MICROBATCHES = 4 5 | N_PARTITIONS = 3 6 | 7 | 8 | def test_generate_forward_and_backward_schedules_using_gpipe_scheduler(): 9 | TOTAL_CLOCK_CYCLES_IN_FORWARD = N_MICROBATCHES + N_PARTITIONS - 1 10 | TOTAL_CLOCK_CYCLES = TOTAL_CLOCK_CYCLES_IN_FORWARD * 2 11 | JOB_TYPES = [JobType.FORWARD, JobType.BACKWARD] 12 | 13 | scheduler = GPipeScheduler(N_MICROBATCHES, N_PARTITIONS) 14 | 15 | assert scheduler.total_clock_cycles == TOTAL_CLOCK_CYCLES 16 | 17 | schedules = scheduler.get_schedules() 18 | assert len(schedules) == TOTAL_CLOCK_CYCLES 19 | 20 | for tasks in schedules: 21 | assert isinstance(tasks, list) 22 | 23 | for task in tasks: 24 | assert task.job_type in JOB_TYPES 25 | assert isinstance(task.partition_idx, int) 26 | assert isinstance(task.microbatch_idx, int) 27 | 28 | 29 | def test_generate_forward_schedules_using_gpipe_scheduler(): 30 | TOTAL_CLOCK_CYCLES_IN_FORWARD = N_MICROBATCHES + N_PARTITIONS - 1 31 | 32 | scheduler = GPipeScheduler(N_MICROBATCHES, N_PARTITIONS) 33 | schedules = scheduler.get_forward_schedules() 34 | 35 | assert len(schedules) == TOTAL_CLOCK_CYCLES_IN_FORWARD 36 | 37 | for tasks in schedules: 38 | assert isinstance(tasks, list) 39 | 40 | for task in tasks: 41 | assert task.job_type == JobType.FORWARD 42 | assert isinstance(task.partition_idx, int) 43 | assert isinstance(task.microbatch_idx, int) 44 | 45 | 46 | def test_generate_backward_schedules_using_gpipe_scheduler(): 47 | TOTAL_CLOCK_CYCLES_IN_BACKWARD = N_MICROBATCHES + N_PARTITIONS - 1 48 | 49 | scheduler = GPipeScheduler(N_MICROBATCHES, N_PARTITIONS) 50 | schedules = scheduler.get_backward_schedules() 51 | 52 | assert len(schedules) == TOTAL_CLOCK_CYCLES_IN_BACKWARD 53 | 54 | for tasks in schedules: 55 | assert isinstance(tasks, list) 56 | 57 | for task in tasks: 58 | assert task.job_type == JobType.BACKWARD 59 | assert isinstance(task.partition_idx, int) 60 | assert isinstance(task.microbatch_idx, int) 61 | -------------------------------------------------------------------------------- /tests/nn/pipeline_parallel/test_worker.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | 3 | from pipegoose.nn.pipeline_parallel._utils import sleep 4 | from pipegoose.nn.pipeline_parallel._worker import WorkerManager 5 | 6 | NUM_WORKERS = 5 7 | MIN_WORKERS = 2 8 | MAX_WORKERS = 7 9 | 10 | 11 | def test_worker_manager(): 12 | worker_manager = WorkerManager(num_workers=NUM_WORKERS, min_workers=MIN_WORKERS, max_workers=MAX_WORKERS) 13 | worker_manager.spawn() 14 | 15 | # wait for workers to spawn 16 | sleep(1.69) 17 | 18 | assert worker_manager.num_workers == NUM_WORKERS 19 | assert worker_manager.min_workers == MIN_WORKERS 20 | assert worker_manager.max_workers == MAX_WORKERS 21 | 22 | assert len(worker_manager.worker_pool) >= MIN_WORKERS 23 | assert len(worker_manager.worker_pool) <= MAX_WORKERS 24 | assert isinstance(worker_manager.pending_jobs, Queue) 25 | assert isinstance(worker_manager.selected_jobs, Queue) 26 | 27 | # NOTE: since we don't have any jobs, all workers should be idle 28 | for worker in worker_manager.worker_pool: 29 | assert worker.is_running is False 30 | assert worker.is_alive() is True 31 | 32 | 33 | def test_destroy_worker_manager(): 34 | pass 35 | 36 | 37 | def test_execute_a_job_from_selected_job_queue(): 38 | PENDING_JOBS = Queue() 39 | SELECTED_JOBS = Queue() 40 | QUEUE = [] 41 | 42 | 
class FakeJob: 43 | def compute(self): 44 | QUEUE.append(1) 45 | 46 | job = FakeJob() 47 | worker_manager = WorkerManager( 48 | pending_jobs=PENDING_JOBS, 49 | selected_jobs=SELECTED_JOBS, 50 | num_workers=NUM_WORKERS, 51 | min_workers=MIN_WORKERS, 52 | max_workers=MAX_WORKERS, 53 | ) 54 | worker_manager.spawn() 55 | 56 | PENDING_JOBS.put(job) 57 | assert PENDING_JOBS.qsize() == 1 58 | assert SELECTED_JOBS.qsize() == 0 59 | 60 | # NOTE: wait for the job selector to pick up the job 61 | sleep(2) 62 | 63 | assert QUEUE == [1] 64 | assert PENDING_JOBS.qsize() == 0 65 | assert SELECTED_JOBS.qsize() == 0 66 | 67 | 68 | def test_construct_a_job_from_received_package(): 69 | pass 70 | 71 | 72 | def test_putting_a_job_into_the_pending_job_queue(): 73 | pass 74 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoModelForCausalLM 3 | 4 | MODEL_NAME = "Muennighoff/bloom-tiny-random" 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def model(): 9 | return AutoModelForCausalLM.from_pretrained(MODEL_NAME) 10 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/test_embedding.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | from pipegoose.nn.tensor_parallel.embedding import ParallelEmbedding 9 | from pipegoose.testing.utils import get_partition, init_parallel_context, spawn 10 | 11 | 12 | def run_parallel_embedding( 13 | rank, 14 | world_size, 15 | port, 16 | tensor_parallel_size, 17 | pipeline_parallel_size, 18 | data_parallel_size, 19 | input, 20 | output, 21 | orig_weight, 22 | ref_weight, 23 | grads, 24 | ): 25 | parallel_context = init_parallel_context( 26 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 27 | ) 28 | local_rank = parallel_context.get_local_rank(parallel_mode=ParallelMode.TENSOR) 29 | ranks_in_group = parallel_context.get_ranks_in_group(parallel_mode=ParallelMode.TENSOR) 30 | 31 | if local_rank in ranks_in_group: 32 | NUM_EMBEDDING = orig_weight.shape[0] 33 | EMBEDDING_DIM = orig_weight.shape[1] 34 | parallel_embedding = ParallelEmbedding(NUM_EMBEDDING, EMBEDDING_DIM, parallel_context=parallel_context) 35 | 36 | parallel_embedding.weight.data = get_partition(orig_weight, dim=0, parallel_context=parallel_context) 37 | parallel_output = parallel_embedding(input) 38 | 39 | assert torch.allclose(parallel_output, output) 40 | 41 | parallel_output.sum().backward() 42 | 43 | REF_GRAD = get_partition(grads, dim=0, parallel_context=parallel_context) 44 | assert torch.allclose(parallel_embedding.weight.grad.data, REF_GRAD) 45 | 46 | REF_WEIGHT = get_partition(ref_weight, dim=0, parallel_context=parallel_context) 47 | assert torch.allclose(parallel_embedding.weight.data, REF_WEIGHT) 48 | 49 | 50 | @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) 51 | def test_parallel_embedding(tensor_parallel_size): 52 | PIPELINE_PARALLEL_SIZE = 1 53 | DATA_PARALLEL_SIZE = 1 54 | 55 | NUM_EMBEDDING = 100 56 | EMBEDDING_DIM = 10 57 | 58 | input = torch.randint(0, NUM_EMBEDDING, (10, 5)) 59 | embedding = nn.Embedding(NUM_EMBEDDING, EMBEDDING_DIM) 60 | weight = deepcopy(embedding.weight.data) 61 | output =
embedding(input) 62 | output.sum().backward() 63 | grads = deepcopy(embedding.weight.grad.data) 64 | 65 | ref_weight = deepcopy(embedding.weight.data) 66 | 67 | spawn( 68 | run_parallel_embedding, 69 | world_size=tensor_parallel_size, 70 | tensor_parallel_size=tensor_parallel_size, 71 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 72 | data_parallel_size=DATA_PARALLEL_SIZE, 73 | input=input.detach(), 74 | output=output.detach(), 75 | orig_weight=weight.detach(), 76 | ref_weight=ref_weight, 77 | grads=grads.detach(), 78 | ) 79 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/test_functional_.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from pipegoose.distributed.parallel_mode import ParallelMode 5 | from pipegoose.testing.utils import init_parallel_context, spawn 6 | 7 | 8 | @pytest.fixture 9 | def parallel_modes(): 10 | return [ParallelMode.GLOBAL, ParallelMode.TENSOR, ParallelMode.PIPELINE, ParallelMode.DATA] 11 | 12 | 13 | PARAMETRIZE = pytest.mark.parametrize( 14 | "world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size", [(1, 1, 1, 1), (8, 2, 2, 2)] 15 | ) 16 | 17 | 18 | def run_parallel_test( 19 | rank, world_size, port, parallel_modes, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, test_logic 20 | ): 21 | parallel_context = init_parallel_context( 22 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 23 | ) 24 | 25 | for parallel_mode in parallel_modes: 26 | rank = parallel_context.get_local_rank(parallel_mode) 27 | ranks_in_group = parallel_context.get_ranks_in_group(parallel_mode) 28 | test_logic(rank, ranks_in_group, parallel_context, parallel_mode) 29 | 30 | parallel_context.destroy() 31 | 32 | 33 | def run_broadcast(rank, world_size, port, parallel_modes, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 34 | parallel_context = init_parallel_context( 35 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 36 | ) 37 | 38 | for parallel_mode in parallel_modes: 39 | rank = parallel_context.get_local_rank(parallel_mode) 40 | ranks_in_group = parallel_context.get_ranks_in_group(parallel_mode) 41 | 42 | if rank in ranks_in_group: 43 | src = parallel_context.get_ranks_in_group(parallel_mode)[-1] 44 | if rank == src: 45 | x = torch.tensor(6.9, dtype=torch.float32, requires_grad=True) 46 | else: 47 | x = torch.tensor(4.2, dtype=torch.float32) 48 | 49 | # Broadcast.apply(x, src=src, parallel_context=parallel_context, parallel_mode=parallel_mode) 50 | 51 | assert torch.equal(x, torch.tensor(6.9)) 52 | assert x.dtype == torch.float32 53 | assert x.requires_grad is True 54 | 55 | parallel_context.destroy() 56 | 57 | 58 | @pytest.mark.skip 59 | @pytest.mark.parametrize( 60 | "world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size", [(1, 1, 1, 1), (8, 2, 2, 2)] 61 | ) 62 | def test_broadcast(parallel_modes, world_size, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 63 | spawn( 64 | run_broadcast, 65 | world_size=world_size, 66 | parallel_modes=parallel_modes, 67 | tensor_parallel_size=tensor_parallel_size, 68 | pipeline_parallel_size=pipeline_parallel_size, 69 | data_parallel_size=data_parallel_size, 70 | ) 71 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/test_layer_norm.py:
-------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from pipegoose.nn.tensor_parallel.layer_norm import LayerNorm 8 | from pipegoose.testing.utils import init_parallel_context, spawn 9 | 10 | 11 | def run_layer_norm( 12 | rank, 13 | world_size, 14 | port, 15 | tensor_parallel_size, 16 | pipeline_parallel_size, 17 | data_parallel_size, 18 | normalized_shape, 19 | input, 20 | output, 21 | weight, 22 | bias, 23 | weight_grad, 24 | bias_grad, 25 | ): 26 | parallel_context = init_parallel_context( 27 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 28 | ) 29 | p_layer_norm = LayerNorm(normalized_shape=normalized_shape, parallel_context=parallel_context) 30 | p_layer_norm.weight.data = weight 31 | p_layer_norm.bias.data = bias 32 | 33 | p_output = p_layer_norm(input) 34 | 35 | assert torch.allclose(output, p_output) 36 | 37 | p_output.sum().backward() 38 | 39 | assert torch.allclose(p_layer_norm.weight.grad, weight_grad) 40 | assert torch.allclose(p_layer_norm.bias.grad, bias_grad) 41 | 42 | 43 | @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) 44 | @pytest.mark.parametrize("hidden_size", [20]) 45 | @pytest.mark.parametrize("normalized_shape", [20, (20,)]) 46 | def test_layer_norm(tensor_parallel_size, hidden_size, normalized_shape): 47 | PIPELINE_PARALLEL_SIZE = 1 48 | DATA_PARALLEL_SIZE = 1 49 | 50 | BATCH_SIZE = 5 51 | SEQ_LEN = 10 52 | EPS = 1e-5 53 | 54 | input = torch.randn(BATCH_SIZE, SEQ_LEN, hidden_size, requires_grad=True) 55 | 56 | layer_norm = nn.LayerNorm(normalized_shape, eps=EPS) 57 | 58 | # NOTE: since we assign the weight and bias to the parallel layer norm, 59 | # we do a deepcopy to make sure that if the parallel layer norm does a backward pass 60 | # it won't affect the original layer norm's weight and bias 61 | weight = deepcopy(layer_norm.weight) 62 | bias = deepcopy(layer_norm.bias) 63 | output = layer_norm(input) 64 | output.sum().backward() 65 | 66 | weight_grad = deepcopy(layer_norm.weight.grad) 67 | bias_grad = deepcopy(layer_norm.bias.grad) 68 | 69 | spawn( 70 | run_layer_norm, 71 | world_size=tensor_parallel_size, 72 | tensor_parallel_size=tensor_parallel_size, 73 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 74 | data_parallel_size=DATA_PARALLEL_SIZE, 75 | normalized_shape=normalized_shape, 76 | input=input.detach(), 77 | output=output.detach(), 78 | weight=weight.detach(), 79 | bias=bias.detach(), 80 | weight_grad=weight_grad.detach(), 81 | bias_grad=bias_grad.detach(), 82 | ) 83 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/test_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from pipegoose.distributed.parallel_context import ParallelContext 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | from pipegoose.nn.tensor_parallel.loss import VocabParallelCrossEntropy 9 | from pipegoose.testing.utils import spawn 10 | 11 | 12 | def check_equal(A, B): 13 | assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) or torch.allclose(A, B) 14 | 15 | 16 | def run_parallel_cross_entropy( 17 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, logits, targets, loss, grads 18 | ): 19 | def get_partition(logits): 20 | local_world_size =
parallel_context.get_world_size(parallel_mode=ParallelMode.TENSOR) 21 | per_partition = N_LABELS // local_world_size 22 | chunks = torch.split(logits, per_partition, dim=-1) 23 | return chunks[local_rank] 24 | 25 | parallel_context = ParallelContext( 26 | rank=rank, 27 | local_rank=rank, 28 | world_size=world_size, 29 | local_world_size=world_size, 30 | host="localhost", 31 | port=port, 32 | seed=69, 33 | backend="gloo", 34 | tensor_parallel_size=tensor_parallel_size, 35 | pipeline_parallel_size=pipeline_parallel_size, 36 | data_parallel_size=data_parallel_size, 37 | ) 38 | 39 | local_rank = parallel_context.get_local_rank(parallel_mode=ParallelMode.TENSOR) 40 | ranks_in_group = parallel_context.get_ranks_in_group(parallel_mode=ParallelMode.TENSOR) 41 | 42 | if local_rank in ranks_in_group: 43 | N_LABELS = logits.shape[-1] 44 | parallel_logits = get_partition(logits) 45 | parallel_logits.requires_grad = True 46 | 47 | parallel_cross_entropy = VocabParallelCrossEntropy(parallel_context=parallel_context) 48 | parallel_loss = parallel_cross_entropy(parallel_logits, targets) 49 | 50 | assert torch.allclose(parallel_loss, loss) 51 | 52 | # parallel_loss.backward() 53 | # assert torch.allclose(parallel_logits.grad.data, get_partition(grads)) 54 | 55 | 56 | @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) 57 | def test_parallel_cross_entropy(tensor_parallel_size): 58 | PIPELINE_PARALLEL_SIZE = 1 59 | DATA_PARALLEL_SIZE = 1 60 | 61 | BATCH_SIZE = 1 62 | SEQ_LEN = 2 63 | VOCAB_SIZE = 4 64 | 65 | torch.manual_seed(69) 66 | 67 | logits = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE, requires_grad=True) 68 | targets = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)) 69 | 70 | loss = F.cross_entropy( 71 | rearrange(logits, "batch_size seq_len vocab_size -> (batch_size seq_len) vocab_size"), 72 | rearrange(targets, "batch_size seq_len -> (batch_size seq_len)"), 73 | ) 74 | 75 | loss.backward() 76 | grads = logits.grad.data 77 | 78 | spawn( 79 | run_parallel_cross_entropy, 80 | world_size=tensor_parallel_size, 81 | tensor_parallel_size=tensor_parallel_size, 82 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 83 | data_parallel_size=DATA_PARALLEL_SIZE, 84 | logits=logits.detach(), 85 | targets=targets, 86 | loss=loss.detach(), 87 | grads=grads.detach(), 88 | ) 89 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/test_parallel_mapping.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.tensor_parallel.parallel_mapping import TensorParallelMapping 2 | 3 | 4 | def test_is_column_parallel_mapping(model): 5 | BLOOM_DENSE_H_TO_4H_NAME = "transformer.h.{}.mlp.dense_h_to_4h" 6 | BLOOM_QKV_NAME = "transformer.h.{}.self_attention.query_key_value" 7 | mappings = {} 8 | 9 | for name, _ in model.named_modules(): 10 | mappings[name] = TensorParallelMapping.is_column_parallel(name) 11 | 12 | for layer_idx in range(len(model.transformer.h)): 13 | assert mappings[BLOOM_DENSE_H_TO_4H_NAME.format(layer_idx)] is True 14 | assert mappings[BLOOM_QKV_NAME.format(layer_idx)] is True 15 | 16 | 17 | def test_is_row_parallel_mapping(model): 18 | BLOOM_DENSE_4H_TO_H_NAME = "transformer.h.{}.mlp.dense_4h_to_h" 19 | BLOOM_ATTENTION_DENSE_NAME = "transformer.h.{}.self_attention.dense" 20 | 21 | mappings = {} 22 | 23 | for name, _ in model.named_modules(): 24 | mappings[name] = TensorParallelMapping.is_row_parallel(name) 25 | 26 | for layer_idx in range(len(model.transformer.h)): 27 | assert ( 28 | 
mappings[BLOOM_DENSE_4H_TO_H_NAME.format(layer_idx)] is True 29 | ), f"{BLOOM_DENSE_4H_TO_H_NAME.format(layer_idx)} is not row parallelized" 30 | assert ( 31 | mappings[BLOOM_ATTENTION_DENSE_NAME.format(layer_idx)] is True 32 | ), f"{BLOOM_ATTENTION_DENSE_NAME.format(layer_idx)} is not row parallelized" 33 | 34 | 35 | def test_is_lm_head_mapping(model): 36 | BLOOM_LM_HEAD_NAME = "lm_head" 37 | 38 | mappings = {} 39 | 40 | for name, _ in model.named_modules(): 41 | mappings[name] = TensorParallelMapping.is_lm_head(name) 42 | 43 | assert mappings[BLOOM_LM_HEAD_NAME] is True, f"{BLOOM_LM_HEAD_NAME} is not the language model head" 44 | -------------------------------------------------------------------------------- /tests/nn/tensor_parallel/test_utils.py: -------------------------------------------------------------------------------- 1 | from pipegoose.nn.tensor_parallel._utils import VocabUtility 2 | 3 | 4 | def test_get_vocab_range_from_global_vocab_size_from_vocab_utility(): 5 | world_size = 2 6 | rank = 1 7 | vocab_size = 10 8 | 9 | vocab_start_idx, vocab_end_idx = VocabUtility.get_vocab_range_from_global_vocab_size(world_size, rank, vocab_size) 10 | 11 | assert vocab_start_idx == 5 12 | assert vocab_end_idx == 10 13 | 14 | 15 | def test_get_vocab_range_from_partition_size_from_vocab_utility(): 16 | rank = 1 17 | partition_size = 5 18 | 19 | vocab_start_idx, vocab_end_idx = VocabUtility.get_vocab_range_idx_from_partition_size(partition_size, rank) 20 | 21 | assert vocab_start_idx == 5 22 | assert vocab_end_idx == 10 23 | -------------------------------------------------------------------------------- /tests/nn/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from torch import nn 5 | 6 | from pipegoose.constants import CHECKPOINT_WEIGHTS_NAME 7 | from pipegoose.distributed.parallel_mode import ParallelMode 8 | from pipegoose.nn.utils import from_pretrained, save_pretrained 9 | from pipegoose.testing.utils import init_parallel_context, spawn 10 | 11 | 12 | class SimpleModel(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | self.fc = nn.Linear(10, 2) 16 | 17 | def forward(self, x): 18 | return self.fc(x) 19 | 20 | 21 | def run_save_and_load_pretrained(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 22 | # TODO: automatically create the directory if it does not exist, 23 | # and then delete it after the test is done 24 | 25 | CHECKPOINT_WEIGHTS_PATH = "."
if os.getenv("GITHUB_ACTIONS") == "true" else "./downloads" 26 | 27 | def zero_weights(m): 28 | """Sets all model weights to zero.""" 29 | if isinstance(m, nn.Module): 30 | if hasattr(m, "weight"): 31 | nn.init.constant_(m.weight, 0) 32 | if hasattr(m, "bias"): 33 | nn.init.constant_(m.bias, 0) 34 | 35 | for layer in m.children(): 36 | zero_weights(layer) 37 | 38 | parallel_context = init_parallel_context( 39 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 40 | ) 41 | 42 | model = SimpleModel() 43 | 44 | save_pretrained(model, ckp_path=CHECKPOINT_WEIGHTS_PATH, parallel_context=parallel_context) 45 | 46 | tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR) 47 | pp_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE) 48 | 49 | assert os.path.exists(os.path.join(CHECKPOINT_WEIGHTS_PATH, CHECKPOINT_WEIGHTS_NAME.format(tp_rank, pp_rank))) 50 | 51 | zero_weights(model) 52 | assert model.fc.weight.sum() == 0 53 | 54 | from_pretrained(model, ckp_path=CHECKPOINT_WEIGHTS_PATH, parallel_context=parallel_context) 55 | assert model.fc.weight.sum() != 0 56 | 57 | 58 | @pytest.mark.parametrize("tensor_parallel_size, pipeline_parallel_size, data_parallel_size", [(1, 1, 1), (2, 2, 2)]) 59 | def test_save_and_load_pretrained(tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 60 | world_size = tensor_parallel_size * pipeline_parallel_size * data_parallel_size 61 | spawn( 62 | run_save_and_load_pretrained, 63 | world_size=world_size, 64 | tensor_parallel_size=tensor_parallel_size, 65 | pipeline_parallel_size=pipeline_parallel_size, 66 | data_parallel_size=data_parallel_size, 67 | ) 68 | -------------------------------------------------------------------------------- /tests/optim/zero/test_optim.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from torch import nn, optim 6 | 7 | from pipegoose.optim import DistributedOptimizer 8 | from pipegoose.testing.utils import init_parallel_context, spawn 9 | 10 | 11 | def count_parameters(optimizer): 12 | return sum(p.numel() for group in optimizer.param_groups for p in group["params"]) 13 | 14 | 15 | def run_dist_optim( 16 | rank, 17 | world_size, 18 | port, 19 | tensor_parallel_size, 20 | pipeline_parallel_size, 21 | data_parallel_size, 22 | input, 23 | model, 24 | updated_model, 25 | ref_grads, 26 | optimizer, 27 | ): 28 | ORIG_UPDATED_MODEL = deepcopy(updated_model) 29 | ORIG_OPTIM = deepcopy(optimizer) 30 | REF_GRADS = ref_grads 31 | 32 | parallel_context = init_parallel_context( 33 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 34 | ) 35 | optimizer = optim.Adam(model.parameters()) 36 | dist_optimizer = DistributedOptimizer(optimizer, parallel_context) 37 | 38 | assert dist_optimizer.defaults == optimizer.defaults 39 | assert len(dist_optimizer.param_groups) == len(optimizer.param_groups) 40 | assert dist_optimizer.state_dict().keys() == optimizer.state_dict().keys() 41 | # NOTE: test whether the optimizer partitions the parameters across data parallel dimension 42 | assert count_parameters(dist_optimizer) < count_parameters(ORIG_OPTIM) 43 | 44 | dist_optimizer.zero_grad() 45 | model(input).sum().backward() 46 | dist_optimizer.step() 47 | 48 | # NOTE: test whether the model parameters are updated correctly 49 | for p1, p2 in zip(model.parameters(), ORIG_UPDATED_MODEL.parameters()): 50 | assert torch.allclose(p1, p2), f"p1: {p1}, p2: 
{p2}" 51 | 52 | # NOTE: make sure the optimizer keeps the gradients after .step() 53 | # it's up to the user to call .zero_grad() or not 54 | # NOTE: dist_grads are just the gradients of the model parameters 55 | dist_grads = [p.grad for p in model.parameters()] 56 | for p1, p2 in zip(dist_grads, REF_GRADS): 57 | assert p1 is not None 58 | assert torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}" 59 | 60 | dist_optimizer.zero_grad() 61 | 62 | for p in model.parameters(): 63 | assert p.grad is None 64 | 65 | 66 | @pytest.mark.parametrize("data_parallel_size", [2, 4]) 67 | def test_dist_optim(data_parallel_size): 68 | TENSOR_PARALLEL_SIZE = 1 69 | PIPELINE_PARALLEL_SIZE = 1 70 | WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * data_parallel_size 71 | 72 | BATCH_SIZE = 500 73 | HIDDEN_SIZE = 1000 74 | OUTPUT_SIZE = 100 75 | 76 | input = torch.randn(BATCH_SIZE, HIDDEN_SIZE) 77 | model = nn.Sequential(nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE), nn.ReLU(), nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE)) 78 | ORIG_INPUT = deepcopy(input) 79 | ORIG_MODEL = deepcopy(model) 80 | optimizer = optim.Adam(model.parameters()) 81 | optimizer.zero_grad() 82 | 83 | model(input).sum().backward() 84 | optimizer.step() 85 | REF_GRADS = [p.grad for p in model.parameters()] 86 | 87 | spawn( 88 | run_dist_optim, 89 | world_size=WORLD_SIZE, 90 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 91 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 92 | data_parallel_size=data_parallel_size, 93 | input=ORIG_INPUT, 94 | model=ORIG_MODEL, 95 | updated_model=model, 96 | ref_grads=REF_GRADS, 97 | optimizer=optimizer, 98 | ) 99 | -------------------------------------------------------------------------------- /tests/optim/zero/test_optim_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import torch 4 | 5 | from pipegoose.optim.zero.utils import ( 6 | copy_flatten_tensor_to_unflatten_tensors, 7 | flatten_a_list_tensor, 8 | ) 9 | 10 | 11 | def test_flatten_a_list_tensor(): 12 | tensor_list = [torch.rand(2, 3) for _ in range(5)] 13 | 14 | flat_tensor = flatten_a_list_tensor(tensor_list) 15 | 16 | assert flat_tensor.numel() == sum(t.numel() for t in tensor_list) 17 | assert flat_tensor.shape == (sum(t.numel() for t in tensor_list),) 18 | 19 | original_elements = torch.tensor(list(chain.from_iterable(t.tolist() for t in tensor_list))) 20 | assert torch.equal(original_elements.view(-1), flat_tensor) 21 | 22 | 23 | def test_copy_flatten_tensor_to_unflatten_tensors(): 24 | tensor_list = [torch.rand(2, 3) for _ in range(5)] 25 | flat_tensor = flatten_a_list_tensor(tensor_list) 26 | new_tensor_list = [torch.randn_like(t) for t in tensor_list] 27 | 28 | copy_flatten_tensor_to_unflatten_tensors(flat_tensor, new_tensor_list) 29 | 30 | for original, copied in zip(tensor_list, new_tensor_list): 31 | assert torch.equal(original, copied) 32 | -------------------------------------------------------------------------------- /tests/optim/zero/test_sharding.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | from torch import nn 5 | from torch.optim import SGD 6 | from transformers import AutoModel 7 | 8 | from pipegoose.distributed.parallel_mode import ParallelMode 9 | from pipegoose.optim.zero.sharding import OptimizerStateSharding 10 | from pipegoose.testing.utils import init_parallel_context, spawn 11 | 12 | 13 | def run_optimizer_states_sharding( 14 | rank, world_size, port,
tensor_parallel_size, pipeline_parallel_size, data_parallel_size, model 15 | ): 16 | def calculate_total_sharded_elements(sharded_params): 17 | total = 0 18 | num_params_per_partition = [] 19 | for param_groups in sharded_params: 20 | local_total = 0 21 | for param_group in param_groups: 22 | for param in param_group["params"]: 23 | local_total += param.numel() 24 | 25 | num_params_per_partition.append(local_total) 26 | total += local_total 27 | return total, num_params_per_partition 28 | 29 | parallel_context = init_parallel_context( 30 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 31 | ) 32 | world_size = parallel_context.get_world_size(ParallelMode.DATA) 33 | 34 | ORIG_MODEL = deepcopy(model) 35 | optim = SGD(model.parameters(), lr=0.01) 36 | param_groups = optim.param_groups 37 | 38 | sharder = OptimizerStateSharding(param_groups, parallel_context, ParallelMode.DATA) 39 | sharded_params = sharder.shard() 40 | 41 | assert len(sharded_params) == world_size 42 | 43 | for rank, shard in enumerate(sharded_params): 44 | assert isinstance(shard, list) 45 | for param_group in shard: 46 | assert len(param_group["params"]) > 0 47 | for param in param_group["params"]: 48 | assert isinstance(param, nn.Parameter) 49 | 50 | # NOTE: each rank, expect to have the same number of parameter groups 51 | assert len(shard) == len(optim.param_groups) 52 | 53 | total_elements = sum(param.numel() for param in ORIG_MODEL.parameters()) 54 | total_sharded_elements, num_params_per_partition = calculate_total_sharded_elements(sharded_params) 55 | assert total_sharded_elements == total_elements 56 | # NOTE: each partition, expect to have less than the total number of parameters 57 | for num_param in num_params_per_partition: 58 | assert num_param < total_elements 59 | 60 | 61 | @pytest.mark.parametrize("data_parallel_size", [2, 5]) 62 | def test_optimizer_states_sharding(data_parallel_size): 63 | model = AutoModel.from_pretrained("gpt2") 64 | 65 | TENSOR_PARALLEL_SIZE = 1 66 | PIPELINE_PARALLEL_SIZE = 1 67 | WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * data_parallel_size 68 | 69 | spawn( 70 | run_optimizer_states_sharding, 71 | world_size=WORLD_SIZE, 72 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 73 | pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, 74 | data_parallel_size=data_parallel_size, 75 | model=model, 76 | ) 77 | -------------------------------------------------------------------------------- /tests/partitioning/test_profile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from pipegoose.partitioning.profile import ProfileByMemory 6 | 7 | skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.") 8 | 9 | 10 | # @skip_if_no_cuda 11 | @pytest.mark.skip("consider remove this module") 12 | def test_profile_by_memory(): 13 | # model = AutoModel.from_pretrained("gpt2") 14 | # tokenizer = AutoTokenizer.from_pretrained("gpt2") 15 | # text = [ 16 | # "Persistence is all you need.", 17 | # "Persistence is all you need.", 18 | # "Persistence is all you need." 
19 | # ] 20 | # token_ids = tokenizer(text, return_tensors="pt")["input_ids"] 21 | 22 | model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) 23 | sample = torch.rand(7, 1) 24 | 25 | NUM_MODULES = sum(1 for _ in model.children()) 26 | 27 | profiler = ProfileByMemory(model, torch.device("cpu")) 28 | sizes = profiler.profile(sample) 29 | 30 | assert len(sizes) == NUM_MODULES 31 | assert all(isinstance(size, int) for size in sizes) 32 | -------------------------------------------------------------------------------- /tests/test_hybrid.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from torch.optim import Adam 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel 9 | from pipegoose.testing.utils import ( 10 | get_partition, 11 | init_parallel_context, 12 | skip_in_github_actions, 13 | spawn, 14 | ) 15 | 16 | MODEL_NAME = "bigscience/bloom-560m" 17 | 18 | 19 | def run_hybrid_parallelism(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, kwargs): 20 | parallel_context = init_parallel_context( 21 | rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size 22 | ) 23 | 24 | model = kwargs["model"] 25 | model = TensorParallel(model, parallel_context).parallelize() 26 | # model = DataParallel(model, parallel_context).parallelize() 27 | 28 | optim = Adam(model.parameters()) 29 | # dist_optim = DistributedOptimizer(optim, parallel_context) 30 | 31 | output = model(**kwargs["input"], labels=kwargs["labels"]) 32 | loss = output.loss 33 | 34 | optim.zero_grad() 35 | loss.backward() 36 | optim.step() 37 | 38 | for p1, p2 in zip(model.parameters(), kwargs["updated_model"].parameters()): 39 | assert torch.allclose(p1, get_partition(p2, dim=0, parallel_context=parallel_context), rtol=1e-1) 40 | 41 | 42 | @skip_in_github_actions 43 | @pytest.mark.parametrize("tensor_parallel_size", [2]) 44 | @pytest.mark.parametrize("pipeline_parallel_size", [1]) 45 | @pytest.mark.parametrize("data_parallel_size", [1]) 46 | def test_hybrid_parallelism(tensor_parallel_size, pipeline_parallel_size, data_parallel_size): 47 | WORLD_SIZE = tensor_parallel_size * pipeline_parallel_size * data_parallel_size 48 | 49 | model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) 50 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 51 | optim = Adam(model.parameters()) 52 | ORIG_MODEL = deepcopy(model) 53 | 54 | text = "Persistence is all you need." 55 | input = tokenizer(text, return_tensors="pt") 56 | labels = input["input_ids"] 57 | 58 | output = model(**input, labels=labels) 59 | loss = output.loss 60 | optim.zero_grad() 61 | loss.backward() 62 | optim.step() 63 | 64 | kwargs = { 65 | "model": ORIG_MODEL, 66 | "updated_model": model, 67 | "input": input, 68 | "labels": labels, 69 | } 70 | 71 | spawn( 72 | run_hybrid_parallelism, 73 | world_size=WORLD_SIZE, 74 | tensor_parallel_size=tensor_parallel_size, 75 | pipeline_parallel_size=pipeline_parallel_size, 76 | data_parallel_size=data_parallel_size, 77 | kwargs=kwargs, 78 | ) 79 | --------------------------------------------------------------------------------
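A recurring pattern in the tensor-parallel tests above (test_embedding.py, test_hybrid.py) is comparing a sharded parameter against the matching slice of a single-process reference. The sketch below isolates that comparison; split_for_rank is a hypothetical stand-in for pipegoose.testing.utils.get_partition, which takes a parallel_context (and derives the rank and group size from it) rather than explicit rank/world-size arguments.

# Standalone sketch of the "compare the shard against a slice of the reference"
# pattern. `split_for_rank` is a hypothetical helper, not part of pipegoose.
import torch


def split_for_rank(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    """Return the chunk of `tensor` along `dim` that the given rank would own."""
    chunks = torch.chunk(tensor, world_size, dim=dim)
    return chunks[rank]


reference_weight = torch.randn(8, 4)  # full, unsharded weight from the reference model
world_size = 2

for rank in range(world_size):
    local_weight = split_for_rank(reference_weight, rank, world_size, dim=0)
    # In the real tests, `local_weight` would come from the parallelized module
    # and is checked against this slice of the reference with torch.allclose.
    assert local_weight.shape == (4, 4)
    assert torch.equal(local_weight, reference_weight[rank * 4 : (rank + 1) * 4])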