├── .coveragerc ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── .ruff.toml ├── LICENSE ├── README.md ├── _test ├── README.md ├── __init__.py ├── conftest.py ├── core │ ├── __init__.py │ ├── test_hook_function.py │ ├── test_hook_function_group_manager.py │ └── test_wrapper.py ├── distributed │ ├── __init__.py │ └── test_initialize.py ├── multi_gpu │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── test_dp_wrappers_vs_single_gpu.py │ │ └── test_fairscale_mpu_vs_ddp.py │ ├── distributed │ │ ├── __init__.py │ │ ├── test_initialize.py │ │ └── test_mappings.py │ ├── registry.py │ ├── run_multi_gpu_tests_slurm.py │ └── testing_utils.py └── traverse │ ├── __init__.py │ ├── test_nodes.py │ └── test_ops.py ├── demos ├── README.md ├── __init__.py ├── example_plots │ ├── induction_loss.png │ └── induction_score_by_head.png └── induction_heads_multigpu.py ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── demo_links │ └── induction.rst │ ├── demos.rst │ ├── example_links │ ├── fsdp_example.rst │ ├── megatron_example.rst │ └── single_gpu_example.rst │ ├── examples.rst │ ├── flex_model.core.rst │ ├── flex_model.distributed.rst │ ├── flex_model.rst │ ├── flex_model.traverse.rst │ ├── index.rst │ ├── intro.rst │ └── modules.rst ├── examples ├── README.md ├── __init__.py ├── fsdp_example.py ├── megatron_example.py └── single_gpu_example.py ├── flex_model ├── __init__.py ├── core │ ├── __init__.py │ ├── core_utils.py │ ├── hook_function.py │ └── wrapper.py ├── distributed │ ├── __init__.py │ ├── distributed_state.py │ ├── mappings.py │ └── strategies.py ├── package_info.py ├── traverse │ ├── __init__.py │ ├── nodes.py │ └── ops.py └── utils.py ├── profiling ├── README.md ├── __init__.py ├── profile_hooks.py └── utils.py ├── pyproject.toml ├── requirements ├── dev_requirements.txt ├── examples_requirements.txt ├── requirements.txt └── test_requirements.txt └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = 4 | test 5 | flex_model 6 | 7 | [report] 8 | ; Regexes for lines to exclude from consideration 9 | exclude_also = 10 | ; Don't complain about missing debug-only code: 11 | def __repr__ 12 | if self\.debug 13 | 14 | ; Don't complain if tests don't hit defensive assertion code: 15 | raise AssertionError 16 | raise NotImplementedError 17 | 18 | ; Don't complain if non-runnable code isn't run: 19 | if 0: 20 | if __name__ == .__main__.: 21 | 22 | ; Don't complain about abstract methods, they aren't run: 23 | @(abc\.)?abstractmethod 24 | 25 | ignore_errors = True 26 | 27 | [html] 28 | directory = coverage_html_report 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG][...]:" 5 | labels: BUG 6 | assignees: MChoi-git 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Provide minimal code example to reproduce the bug, along with steps for running the code example. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 
21 | 22 | **System Info** 23 | Please provide information on number of GPUs used, wrapped model, distributed config, etc. 24 | 25 | Other system info can be obtained by running PyTorch's [collect_env](https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py) script. This can be run via: 26 | ``` 27 | wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py 28 | python collect_env.py 29 | ``` 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[ENHANCEMENT][...]:" 5 | labels: ENHANCEMENT 6 | assignees: MChoi-git 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # vim 2 | *.swp 3 | *.swo 4 | 5 | # python 6 | __pycache__/ 7 | .mypy_cache/ 8 | 9 | # Stubs and logs 10 | llama/ 11 | *.logits 12 | *.log 13 | .ipynb_checkpoints/ 14 | stubs/ 15 | transformers/ 16 | Megatron-LM/ 17 | wandb/ 18 | profiles/ 19 | 20 | docs/build/ 21 | docs/source/generated 22 | 23 | # slurm 24 | *.err 25 | *.out 26 | *.slrm 27 | *_logs/ 28 | 29 | # Jupyter 30 | *.ipynb 31 | 32 | # Coverage 33 | .coverage 34 | coverage_html_report/ 35 | 36 | # PIP builds 37 | dist/ 38 | *.egg-info/ 39 | build/ 40 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.1.5 4 | hooks: 5 | - id: ruff 6 | args: [ --fix ] 7 | - id: ruff-format 8 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | # formats: 24 | # - pdf 25 | # - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 
| python: 31 | install: 32 | - method: pip 33 | path: . 34 | extra_requirements: 35 | - all 36 | -------------------------------------------------------------------------------- /.ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".mypy_cache", 10 | ".nox", 11 | ".pants.d", 12 | ".pytype", 13 | ".ruff_cache", 14 | ".svn", 15 | ".tox", 16 | ".venv", 17 | "__pypackages__", 18 | "_build", 19 | "buck-out", 20 | "build", 21 | "dist", 22 | "node_modules", 23 | "venv", 24 | ] 25 | 26 | # Same as Black. 27 | line-length = 80 28 | indent-width = 4 29 | 30 | # Assume Python 3.9 31 | target-version = "py39" 32 | 33 | # Avoid removing imports from __init__.py files. 34 | ignore-init-module-imports = true 35 | 36 | # Ignore `F401` (unused imports) in all `__init__.py` files. 37 | [lint.per-file-ignores] 38 | "__init__.py" = ["F401"] 39 | 40 | [lint] 41 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 42 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 43 | # McCabe complexity (`C901`) by default. 44 | select = ["E4", "E7", "E9", "F"] 45 | ignore = [] 46 | 47 | # Allow fix for all enabled rules (when `--fix` is provided). 48 | fixable = ["ALL"] 49 | unfixable = [] 50 | 51 | # Allow unused variables when underscore-prefixed. 52 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 53 | 54 | [format] 55 | # Like Black, use double quotes for strings. 56 | quote-style = "double" 57 | 58 | # Like Black, indent with spaces, rather than tabs. 59 | indent-style = "space" 60 | 61 | # Like Black, respect magic trailing commas. 62 | skip-magic-trailing-comma = false 63 | 64 | # Like Black, automatically detect the appropriate line ending. 65 | line-ending = "auto" 66 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Matthew Choi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlexModel 2 | [![Documentation Status](https://readthedocs.org/projects/flexmodel/badge/?version=latest)](https://flexmodel.readthedocs.io/en/latest/) 3 | 4 | # Installation 5 | 6 | ## Using pip 7 | Run the command: `pip install flex-model`. 8 | 9 | ## From Source 10 | Clone the repository and run `pip install .[all]` to install all dependencies. 11 | For more fine-grained options, you can install directly from the requirements 12 | files found in `requirements/`. 13 | 14 | ## Current WIPs 15 | * ~~Adding more rigorous tests for hook function behaviours~~ 16 | * Implement strategies for handling distributed `save_ctx` and `trainable_modules` 17 | * Visualizations of model architecture showing where hooks can be placed 18 | * Visualizations of model architecture showing where hooks are currently 19 | located, and what groups they are tagged in 20 | * Editing function presets 21 | * Distributed debugging tools and user-provided editing function parsing 22 | * Remove the need to pass an `expected_shape` 23 | 24 | ## Introduction 25 | `FlexModel` is a tool designed for distributed interpretability of Large 26 | Language Models (LLMs). `FlexModel` allows you to retrieve and/or edit 27 | **unsharded** activations within your LLM (among other things). 28 | 29 | For more detailed information, check out our [docs](https://flexmodel.readthedocs.io/en/latest/) 30 | or read the paper (coming soon!). There are also a myriad of examples/demos 31 | to showcase what you can do with `FlexModel`. Feel free to raise a GitHub issue 32 | for new features/bugs! 33 | 34 | ## Introduction: The `FlexModel` Wrapper 35 | `FlexModel` wraps any `nn.Module`, and replaces the typical PyTorch 36 | `nn.Module` hook registration functions. It contains all the state necessary 37 | for doing model surgery, while leaving the wrapped module unchanged. 38 | 39 | ## Introduction: The `HookFunction` 40 | The replaced hook registration functions now receive `HookFunction` instances 41 | as input. It is the `HookFunction`'s job to retrieve activations and/or 42 | edit them within the wrapped model. To edit an activation (which will affect 43 | subsequent model operation), you can simply provide your `HookFunction` with 44 | an editing function. The best part is that the editing function can contain 45 | arbitrary code, and runs single-threaded. So you don't have to worry about any 46 | SPMD parallelism in your editing function! 47 | 48 | ## What's a hook function? 49 | PyTorch exposes endpoints in each `torch.nn.Module`, which call your 50 | "hook function" at a specified time during module operation. There's a great 51 | introductory blogpost 52 | [here](https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/). 53 | 54 | ## Why not just use PyTorch hooks? 55 | Vanilla PyTorch hooks work great for single-gpu/process models. However, if you 56 | need access to full activations for retrieval/editing, then you'll need to 57 | figure out how to unshard them in each hook function. Given the parallelism 58 | dimensions, `FlexModel` can figure out which collectives to call if necessary, 59 | so your activations are always unsharded. For example, `FlexModel` integrates 60 | easily with distributed frameworks like DDP, FSDP, Fairscale Megatron and 61 | Megatron-LM. 62 | 63 | ## What can I hook?
64 | You can attach hooks to anything that native PyTorch would allow you to 65 | hook into! `FlexModel` simply intercepts the `nn.Module` hook function 66 | registration API to inject our own logic. Concretely, we support the following 67 | hook function registration functions: 68 | * `nn.Module.register_forward_hook(...)`: [Usage](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook) 69 | * `nn.Module.register_full_backward_hook(...)`: [Usage](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook) 70 | * `torch.Tensor.register_hook(...)`: [Usage](https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html#torch-tensor-register-hook) 71 | * Note that this hook is not well-supported in the multi-gpu case, as parameter tensors are often custom-handled by frameworks like DDP/FSDP. 72 | * `nn.Module.register_forward_pre_hook(...)`: [Usage](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook) 73 | * `nn.Module.register_full_backward_pre_hook(...)`: [Usage](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_pre_hook) 74 | 75 | # Usage 76 | Here's a short example of how to use the `FlexModel` and `HookFunction` 77 | classes. When wrapping distributed models that use FSDP or Megatron layers, the 78 | `FlexModel` class requires specification of data parallel (DP), tensor parallel 79 | (TP), and pipeline parallel (PP) sizes. 80 | ```python 81 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 82 | from flex_model import FlexModel, HookFunction 83 | 84 | # Load some model 85 | model = FSDP(model, ...) 86 | inputs = ... 87 | 88 | # Activations will be dumped here 89 | output_dict = {} 90 | 91 | # Wrap the model 92 | model = FlexModel( 93 | model, 94 | output_dict, 95 | tensor_parallel_size=1, 96 | pipeline_parallel_size=1, 97 | data_parallel_size=4, # For FSDP over 4 GPUs 98 | ) 99 | 100 | # Only need to provide a shape hint on the dimension which may be sharded 101 | # across gpus. 102 | # (batch, sequence, hidden) -> sharded along hidden 103 | expected_shape = (None, None, hidden_dim) 104 | 105 | # Define a hook function 106 | def editing_function(current_module, inputs, save_ctx, modules): 107 | # Save input activation tensor for later use 108 | save_ctx.activation = inputs 109 | 110 | # Edit the activation tensor 111 | edited_activations = inputs * 2 112 | 113 | # Apply torch.nn.Modules to the activation tensor (can generate grads) 114 | edited_activations = modules["mlp"].forward(edited_activations) 115 | 116 | return edited_activations 117 | 118 | # Create hook function 119 | hook_function = HookFunction( 120 | module_name="layers.7.feed_forward.w1", 121 | expected_shape=expected_shape, 122 | editing_function=editing_function, 123 | ) 124 | 125 | # Register the hook function into our model 126 | model.register_forward_hook(hook_function) 127 | 128 | # Run a forward pass, activations will be placed in the output_dict 129 | model(inputs) 130 | ``` 131 | 132 | 133 | # `HookFunction` Groups 134 | `HookFunction`s can be associated with group tags. Using these tags, you can 135 | choose which groups are run during a given forward pass. There are two primary 136 | ways of interacting with `HookFunction` groups: 137 | 1. 
`FlexModel.create_hook_group`: This function creates a collection of uniform 138 | `HookFunction` instances, and tags them under the same group name. Let's 139 | inspect the function signature: 140 | ```python 141 | def create_hook_group( 142 | self, 143 | group_name: str, 144 | group_constructor: str, 145 | expected_shape: Optional[Tuple[Optional[int], ...]] = None, 146 | editing_function: Optional[Callable] = None, 147 | unpack_idx: Optional[int] = 0, 148 | ``` 149 | - `group_name`: Name of the group to tag the created `HookFunction`s under. 150 | - `group_constructor`: String pattern which is used to match against 151 | submodule names. For example, setting this to "self_attn" will match any 152 | submodule with "self_attn" `in` its name. If 10 submodules match this, then 153 | 10 `HookFunction` instances will be created, one registered on each matching 154 | submodule. 155 | - `expected_shape`, `editing_function` and `unpack_idx` will all be the 156 | same for each `HookFunction` created. 157 | 2. `FlexModel.update_hook_groups`: This function updates the group tags for 158 | existing `HookFunction` instances already registered. It takes either a list of 159 | `HookFunction`s to tag, a single `HookFunction` to tag, or a string to pattern- 160 | match against submodule names to automatically tag any associated `HookFunction`s. 161 | 162 | ## Enabling/Disabling Certain Groups 163 | Note that `HookFunction` groups follow `set` semantics. When running forward passes, all 164 | `HookFunction`s are enabled by default (i.e. all `HookFunction`s are members of 165 | the `all` group). Specifying the groups to run as a list of strings in the 166 | model's forward function will enable the union set of `HookFunction`s within 167 | the groups. You can also set the `complement` argument, which will enable 168 | all hooks **not** in the union set. See the example at the end of this README. 169 | 170 | ## Adding/Removing Group Tags 171 | Each `HookFunction` instance can be tagged in as many groups as you'd like. 172 | `HookFunction`s can also be removed from groups via `remove_hook_groups` with 173 | similar semantics to the `update_hook_groups` method. 174 | 175 | Note that you **cannot** remove the `all` group tag from any `HookFunction` 176 | instance; attempting to do so will raise an exception. 177 | 178 | 179 | # Running Tests 180 | Running single-gpu tests from the project folder using `pytest` can be done with 181 | the command: 182 | ``` 183 | torchrun --nnodes 1 --nproc_per_node 1 -m pytest --ignore=_test/multi_gpu _test/ 184 | ``` 185 | 186 | Multi-gpu tests are run via `submitit` on a `slurm` cluster. Navigate to 187 | `_test/multi_gpu` and run the command: 188 | ``` 189 | python run_multi_gpu_tests_slurm.py 190 | ``` 191 | The multi-gpu tests require 4 GPUs to run. 192 | 193 | 194 | # Important Notes 195 | - Make sure to replace any instances of `module.forward(inputs)` with 196 | `module(inputs)`. The forward hooks are not run by PyTorch if you directly call 197 | the forward function of a module (this is the case with LLaMA). 198 | - If you would like to create `HookFunction` entrypoints arbitrarily in the 199 | wrapped model, you can place `DummyModule`s with identity forward functions 200 | which can be hooked into. `DummyModule` is located in the `core/core_utils.py` 201 | file.
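# Example: Running `HookFunction` Groups
The sketch below pulls the group API described above into one place, mirroring the group tests in `_test/core/test_wrapper.py`. The `"self_attn"` pattern and the `inputs` tensor are illustrative and depend on the model you wrap.
```python
from flex_model.core import FlexModel, HookFunction

# Assume `model` is an already-loaded nn.Module and `inputs` is a batch of
# token ids, as in the Usage section above.
activations = {}
model = FlexModel(model, activations)

# Create one HookFunction per submodule whose name contains "self_attn",
# all tagged under the group "attn_group" (and the default "all" group).
model.create_hook_group(
    group_name="attn_group",
    group_constructor="self_attn",
)

# Run only the hooks tagged "attn_group".
_ = model(inputs, groups="attn_group")

# Run every hook *not* in "attn_group".
_ = model(inputs, groups="attn_group", complement=True)

# With no `groups` argument, all registered hooks (the "all" group) run.
_ = model(inputs)
```
`groups` also accepts a list of group names, in which case the union of those groups is enabled (or its complement, when `complement=True`).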
202 | -------------------------------------------------------------------------------- /_test/README.md: -------------------------------------------------------------------------------- 1 | # Running Tests 2 | Run tests using the command: 3 | ``` 4 | torchrun --nnodes 1 --nproc_per_node 1 -m pytest --ignore=_test/multi_gpu _test/ 5 | ``` 6 | These tests require a single GPU to run. Tests which require multiple GPUs to 7 | run (see the `multi_gpu` folder) are run using `submitit` instead of `pytest`. 8 | 9 | # Test Coverage 10 | To generate code coverage reports, run the following command from the top-level 11 | `flex_model` directory (i.e. `cd ..`): 12 | ``` 13 | coverage run -m pytest --ignore=_test/multi_gpu _test/ 14 | ``` 15 | You can then read the coverage report by running: 16 | ``` 17 | coverage report 18 | ``` 19 | -------------------------------------------------------------------------------- /_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/_test/__init__.py -------------------------------------------------------------------------------- /_test/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer 5 | 6 | 7 | @pytest.fixture 8 | def llama_13b() -> nn.Module: 9 | """Fixture to construct a Llama-2 13B model.""" 10 | model = AutoModelForCausalLM.from_pretrained( 11 | "/model-weights/Llama-2-13b-hf", 12 | local_files_only=True, 13 | torch_dtype=torch.bfloat16, 14 | low_cpu_mem_usage=True, 15 | ) 16 | 17 | return model 18 | 19 | 20 | @pytest.fixture 21 | def llama_tokenizer() -> LlamaTokenizer: 22 | tokenizer = AutoTokenizer.from_pretrained( 23 | "/model-weights/Llama-2-13b-hf", 24 | local_files_only=True, 25 | ) 26 | tokenizer.pad_token_id = 0 27 | tokenizer.padding_side = "right" 28 | tokenizer.model_max_length = 128 29 | 30 | return tokenizer 31 | 32 | 33 | @pytest.fixture 34 | def make_opt_350m() -> nn.Module: 35 | def _make_opt_350m(): 36 | model = AutoModelForCausalLM.from_pretrained( 37 | "/model-weights/opt-350m", 38 | local_files_only=True, 39 | torch_dtype=torch.bfloat16, 40 | ) 41 | return model 42 | 43 | yield _make_opt_350m 44 | 45 | 46 | @pytest.fixture(scope="session") 47 | def opt_350m_module_names(): 48 | model = AutoModelForCausalLM.from_pretrained( 49 | "/model-weights/opt-350m", 50 | local_files_only=True, 51 | torch_dtype=torch.bfloat16, 52 | ) 53 | return [n for n, _ in model.named_modules()] 54 | 55 | 56 | @pytest.fixture 57 | def opt_tokenizer(): 58 | tokenizer = AutoTokenizer.from_pretrained( 59 | "/model-weights/opt-350m", 60 | local_files_only=True, 61 | ) 62 | 63 | return tokenizer 64 | 65 | 66 | # TODO: Parameterize models as a sweep.
67 | @pytest.fixture 68 | def model(): 69 | raise NotImplementedError 70 | -------------------------------------------------------------------------------- /_test/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/_test/core/__init__.py -------------------------------------------------------------------------------- /_test/core/test_hook_function.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from flex_model.core import FlexModel, HookFunction 7 | 8 | 9 | # could be any MLP layer and the code won't break. The test doesn't generalize 10 | # to other kinds of layers 11 | MODULE_NAME = "model.decoder.layers.9.fc2" 12 | 13 | 14 | # For getting submodule by name. 15 | def rgetattr(module, attr): 16 | def _getattr(module, attr): 17 | return getattr(module, attr) 18 | 19 | return reduce(_getattr, [module] + attr.split(".")) 20 | 21 | 22 | def test_forward_hooks(make_opt_350m): 23 | dist.init_process_group("nccl") 24 | model = make_opt_350m().eval().cuda() 25 | 26 | inputs = torch.randint(0, 6400, size=(4, 32)).cuda() 27 | 28 | acc = {} 29 | 30 | # Test pytorch hook. 31 | def hook_fn(module, inputs, outputs): 32 | acc["torch"] = outputs.detach().cpu() 33 | return outputs * 2 34 | 35 | submodule_to_hook = rgetattr(model, MODULE_NAME) 36 | handle = submodule_to_hook.register_forward_hook(hook_fn) 37 | gt_retval = model(inputs).logits 38 | 39 | handle.remove() 40 | 41 | # Test flexmodel hook. 42 | def editing_fn(module, outputs, save_ctx, trainable_modules): 43 | return outputs * 2 44 | 45 | flexmodel = FlexModel(model, acc) 46 | flexmodel.register_forward_hook( 47 | HookFunction(MODULE_NAME, editing_function=editing_fn) 48 | ) 49 | fm_retval = flexmodel(inputs).logits 50 | 51 | torch.testing.assert_close(gt_retval, fm_retval) 52 | torch.testing.assert_close(acc["torch"], acc[MODULE_NAME][0]) 53 | dist.destroy_process_group() 54 | 55 | 56 | def test_full_backward_hooks(make_opt_350m): 57 | dist.init_process_group("nccl") 58 | inputs = torch.randint(0, 6400, size=(4, 32)).cuda() 59 | 60 | acc = {} 61 | 62 | # Test pytorch hook. 63 | gt_grad_model = make_opt_350m().cuda() 64 | 65 | def hook_fn(module, grad_inputs, grad_outputs): 66 | acc["torch"] = grad_inputs[0].detach().cpu() 67 | grad_inputs = (grad_inputs[0] * 2, *grad_inputs[1:]) 68 | return grad_inputs 69 | 70 | submodule_to_hook = rgetattr(gt_grad_model, MODULE_NAME) 71 | handle = submodule_to_hook.register_full_backward_hook(hook_fn) 72 | gt_retval = gt_grad_model(inputs).logits 73 | gt_retval.mean().backward() 74 | 75 | handle.remove() 76 | 77 | # Test flexmodel hook. 
78 | fm_grad_model = make_opt_350m().cuda() 79 | 80 | def editing_fn(module, grad_inputs, save_ctx, trainable_modules): 81 | return grad_inputs * 2 82 | 83 | flexmodel = FlexModel(fm_grad_model, acc) 84 | flexmodel.register_full_backward_hook( 85 | HookFunction(MODULE_NAME, editing_function=editing_fn) 86 | ) 87 | fm_retval = flexmodel(inputs).logits 88 | fm_retval.mean().backward() 89 | 90 | torch.testing.assert_close(gt_retval, fm_retval) 91 | torch.testing.assert_close(acc["torch"], acc[MODULE_NAME][0]) 92 | for (gt_name, gt_grad), (fm_name, fm_grad) in zip( 93 | gt_grad_model.named_parameters(), fm_grad_model.named_parameters() 94 | ): 95 | assert gt_name == fm_name 96 | torch.testing.assert_close(gt_grad, fm_grad) 97 | 98 | dist.destroy_process_group() 99 | 100 | 101 | def test_tensor_hooks(make_opt_350m): 102 | dist.init_process_group("nccl") 103 | inputs = torch.randint(0, 6400, size=(4, 32)).cuda() 104 | 105 | acc = {} 106 | 107 | # Test pytorch hook. 108 | gt_grad_model = make_opt_350m().cuda() 109 | 110 | def hook_fn(grad): 111 | acc["torch"] = grad.detach().cpu() 112 | return grad 113 | 114 | submodule_to_hook = rgetattr(gt_grad_model, f"{MODULE_NAME}.weight") 115 | handle = submodule_to_hook.register_hook(hook_fn) 116 | gt_retval = gt_grad_model(inputs).logits 117 | gt_retval.mean().backward() 118 | 119 | handle.remove() 120 | 121 | # Test flexmodel hook. 122 | fm_grad_model = make_opt_350m().cuda() 123 | 124 | def editing_fn(module, grad, save_ctx, trainable_modules): 125 | return grad * 2 126 | 127 | flexmodel = FlexModel(fm_grad_model, acc) 128 | flexmodel.register_hook( 129 | HookFunction(f"{MODULE_NAME}.weight", editing_function=editing_fn) 130 | ) 131 | fm_retval = flexmodel(inputs).logits 132 | fm_retval.mean().backward() 133 | 134 | torch.testing.assert_close(gt_retval, fm_retval) 135 | torch.testing.assert_close(acc["torch"], acc[f"{MODULE_NAME}.weight"][0]) 136 | for (gt_name, gt_grad), (fm_name, fm_grad) in zip( 137 | gt_grad_model.named_parameters(), fm_grad_model.named_parameters() 138 | ): 139 | assert gt_name == fm_name 140 | torch.testing.assert_close(gt_grad, fm_grad) 141 | 142 | dist.destroy_process_group() 143 | 144 | 145 | def test_forward_pre_hooks(make_opt_350m): 146 | dist.init_process_group("nccl") 147 | model = make_opt_350m().eval().cuda() 148 | 149 | inputs = torch.randint(0, 6400, size=(4, 32)).cuda() 150 | 151 | acc = {} 152 | 153 | # Test pytorch hook. 154 | def hook_fn(module, inputs_): 155 | acc["torch"] = inputs_[0].detach().cpu() 156 | inputs_ = (inputs_[0] * 2, *inputs_[1:]) 157 | return inputs_ 158 | 159 | submodule_to_hook = rgetattr(model, MODULE_NAME) 160 | handle = submodule_to_hook.register_forward_pre_hook(hook_fn) 161 | gt_retval = model(inputs).logits 162 | 163 | handle.remove() 164 | 165 | # Test flexmodel hook. 166 | def editing_fn(module, inputs_, save_ctx, trainable_modules): 167 | return inputs_ * 2 168 | 169 | flexmodel = FlexModel(model, acc) 170 | flexmodel.register_forward_pre_hook( 171 | HookFunction(MODULE_NAME, editing_function=editing_fn) 172 | ) 173 | fm_retval = flexmodel(inputs).logits 174 | 175 | torch.testing.assert_close(gt_retval, fm_retval) 176 | torch.testing.assert_close(acc["torch"], acc[MODULE_NAME][0]) 177 | 178 | dist.destroy_process_group() 179 | 180 | 181 | def test_full_backward_pre_hooks(make_opt_350m): 182 | dist.init_process_group("nccl") 183 | inputs = torch.randint(0, 6400, size=(4, 32)).cuda() 184 | 185 | acc = {} 186 | 187 | # Test pytorch hook. 
188 | gt_grad_model = make_opt_350m().cuda() 189 | 190 | def hook_fn(module, grad_outputs): 191 | acc["torch"] = grad_outputs[0].detach().cpu() 192 | grad_outputs = (grad_outputs[0] * 2, *grad_outputs[1:]) 193 | return grad_outputs 194 | 195 | submodule_to_hook = rgetattr(gt_grad_model, MODULE_NAME) 196 | handle = submodule_to_hook.register_full_backward_pre_hook(hook_fn) 197 | gt_retval = gt_grad_model(inputs).logits 198 | gt_retval.mean().backward() 199 | 200 | handle.remove() 201 | 202 | # Test flexmodel hook. 203 | fm_grad_model = make_opt_350m().cuda() 204 | 205 | def editing_fn(module, grad_outputs, save_ctx, trainable_modules): 206 | return grad_outputs * 2 207 | 208 | flexmodel = FlexModel(fm_grad_model, acc) 209 | flexmodel.register_full_backward_pre_hook( 210 | HookFunction(MODULE_NAME, editing_function=editing_fn) 211 | ) 212 | fm_retval = flexmodel(inputs).logits 213 | fm_retval.mean().backward() 214 | 215 | torch.testing.assert_close(gt_retval, fm_retval) 216 | torch.testing.assert_close(acc["torch"], acc[MODULE_NAME][0]) 217 | for (gt_name, gt_grad), (fm_name, fm_grad) in zip( 218 | gt_grad_model.named_parameters(), fm_grad_model.named_parameters() 219 | ): 220 | assert gt_name == fm_name 221 | torch.testing.assert_close(gt_grad, fm_grad) 222 | 223 | dist.destroy_process_group() 224 | -------------------------------------------------------------------------------- /_test/core/test_hook_function_group_manager.py: -------------------------------------------------------------------------------- 1 | from flex_model.core import HookFunction 2 | from flex_model.core.wrapper import _HookFunctionGroupManager 3 | 4 | 5 | def test_HookFunctionGroupManager_create(opt_350m_module_names): 6 | manager = _HookFunctionGroupManager() 7 | 8 | new_hook_fns = manager.create( 9 | "new_group", 10 | "self_attn", 11 | all_names=opt_350m_module_names, 12 | ) 13 | new_hook_fns = set(new_hook_fns) 14 | 15 | assert "new_group" in manager.groups 16 | 17 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 18 | if hook_fn in new_hook_fns: 19 | assert "new_group" in groups 20 | assert "self_attn" in hook_fn.module_name 21 | else: 22 | assert "new_group" not in groups 23 | assert "self_attn" not in hook_fn.module_name 24 | 25 | 26 | def test_HookFunctionGroupManager_update_by_list(opt_350m_module_names): 27 | manager = _HookFunctionGroupManager() 28 | 29 | original_hf_group = manager.create( 30 | "new_group", 31 | "self_attn", 32 | all_names=opt_350m_module_names, 33 | ) 34 | new_hf_group = [ 35 | hf for hf in original_hf_group if "q_proj" in hf.module_name 36 | ] 37 | 38 | manager.update( 39 | new_hf_group, 40 | group_name="q_proj", 41 | ) 42 | 43 | assert "new_group" in manager.groups 44 | assert "q_proj" in manager.groups 45 | 46 | original_hf_group = set(original_hf_group) 47 | new_hf_group = set(new_hf_group) 48 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 49 | if hook_fn in original_hf_group: 50 | assert "new_group" in groups 51 | assert "self_attn" in hook_fn.module_name 52 | else: 53 | assert "new_group" not in groups 54 | assert "self_attn" not in hook_fn.module_name 55 | 56 | if hook_fn in new_hf_group: 57 | assert "q_proj" in groups 58 | # Don't assert "q_proj" in module name since test does this. 
59 | else: 60 | assert "q_proj" not in groups 61 | 62 | 63 | def test_HookFunctionGroupManager_update_by_hook_fn(opt_350m_module_names): 64 | manager = _HookFunctionGroupManager() 65 | 66 | hook_function = HookFunction( 67 | "model.decoder.layers.12", 68 | ) 69 | 70 | manager.update(hook_function, group_name="new_group") 71 | 72 | assert "new_group" in manager.groups 73 | 74 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 75 | if hook_fn is hook_function: 76 | assert "new_group" in groups 77 | else: 78 | assert "new_group" not in groups 79 | 80 | 81 | def test_HookFunctionGroupManager_update_by_string(opt_350m_module_names): 82 | manager = _HookFunctionGroupManager() 83 | 84 | new_hf_group = manager.create( 85 | "new_group", 86 | "self_attn", 87 | all_names=opt_350m_module_names, 88 | ) 89 | 90 | manager.update( 91 | "k_proj", 92 | group_name="k_proj_group", 93 | ) 94 | 95 | assert "new_group" in manager.groups 96 | assert "k_proj_group" in manager.groups 97 | 98 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 99 | if hook_fn in new_hf_group: 100 | assert "new_group" in groups 101 | else: 102 | assert "new_group" not in groups 103 | 104 | if "k_proj" in hook_fn.module_name: 105 | assert "k_proj_group" in groups 106 | else: 107 | assert "k_proj_group" not in groups 108 | 109 | 110 | def test_HookFunctionGroupManager_remove_by_list(opt_350m_module_names): 111 | manager = _HookFunctionGroupManager() 112 | 113 | new_hf_group = manager.create( 114 | "new_group", 115 | "self_attn", 116 | all_names=opt_350m_module_names, 117 | ) 118 | _other_hf_group = manager.create( 119 | "other_group", 120 | "q_proj", 121 | all_names=opt_350m_module_names, 122 | ) 123 | 124 | assert "new_group" in manager.groups 125 | assert "other_group" in manager.groups 126 | 127 | manager.remove(new_hf_group, "new_group") 128 | 129 | assert "new_group" not in manager.groups 130 | 131 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 132 | if hook_fn in new_hf_group: 133 | assert "new_group" not in groups 134 | assert "other_group" not in groups 135 | else: 136 | assert "other_group" in groups 137 | assert "new_group" not in groups 138 | 139 | 140 | def test_HookFunctionGroupManager_remove_by_hook_fn(opt_350m_module_names): 141 | manager = _HookFunctionGroupManager() 142 | 143 | new_hf_group = manager.create( 144 | "new_group", 145 | "self_attn", 146 | all_names=opt_350m_module_names, 147 | ) 148 | hf_to_remove = new_hf_group[0] 149 | 150 | _other_hf_group = manager.create( 151 | "other_group", 152 | "q_proj", 153 | all_names=opt_350m_module_names, 154 | ) 155 | 156 | manager.remove(hf_to_remove, "new_group") 157 | 158 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 159 | if hook_fn is hf_to_remove: 160 | assert "new_group" not in groups 161 | assert "other_group" not in groups 162 | continue 163 | 164 | if hook_fn in new_hf_group: 165 | assert "new_group" in groups 166 | assert "other_group" not in groups 167 | else: 168 | assert "other_group" in groups 169 | assert "new_group" not in groups 170 | 171 | 172 | def test_HookFunctionGroupManager_remove_by_string(opt_350m_module_names): 173 | manager = _HookFunctionGroupManager() 174 | 175 | new_hf_group = manager.create( 176 | "new_group", 177 | "self_attn", 178 | all_names=opt_350m_module_names, 179 | ) 180 | 181 | _other_hf_group = manager.create( 182 | "other_group", 183 | "q_proj", 184 | all_names=opt_350m_module_names, 185 | ) 186 | 187 | assert "new_group" in manager.groups 188 | assert "other_group" in 
manager.groups 189 | 190 | manager.remove("k_proj", "new_group") 191 | 192 | for hook_fn, groups in manager.hook_fn_to_groups_map.items(): 193 | if "k_proj" in hook_fn.module_name: 194 | assert "new_group" not in groups 195 | assert "other_group" not in groups 196 | else: 197 | if hook_fn in new_hf_group: 198 | assert "new_group" in groups 199 | assert "other_group" not in groups 200 | else: 201 | assert "new_group" not in groups 202 | assert "other_group" in groups 203 | 204 | 205 | def test_HookFunctionGroupManager_bisect(opt_350m_module_names): 206 | manager = _HookFunctionGroupManager() 207 | 208 | _new_hf_group = manager.create( 209 | "new_group", 210 | "self_attn", 211 | all_names=opt_350m_module_names, 212 | ) 213 | 214 | other_hf_group = manager.create( 215 | "other_group", 216 | "fc1", 217 | all_names=opt_350m_module_names, 218 | ) 219 | 220 | active, inactive = manager.bisect("new_group") 221 | 222 | assert active.isdisjoint(inactive) 223 | assert set(other_hf_group) & inactive == inactive 224 | 225 | active, inactive = manager.bisect(["new_group", "other_group"]) 226 | 227 | assert active.isdisjoint(inactive) 228 | assert inactive == set() 229 | -------------------------------------------------------------------------------- /_test/core/test_wrapper.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from flex_model.core import FlexModel, HookFunction 7 | import _test.multi_gpu.testing_utils as utils 8 | 9 | MODULE_NAME_1 = "model.decoder.layers.17.fc2" 10 | MODULE_NAME_2 = "model.decoder.layers.18.fc2" 11 | PROMPTS = [ 12 | "It's a nice day we're having", 13 | "The capital of Canada is", 14 | "What should I eat for dinner tonight?", 15 | "There's about three people going to", 16 | ] 17 | 18 | 19 | def test_register_forward_hook(make_opt_350m): 20 | """ 21 | Tests if a hook function is registered correctly, and if the fields are set 22 | appropriately. 23 | """ 24 | utils.init_process_group() 25 | model = make_opt_350m().cuda() 26 | 27 | activations = {} 28 | model = FlexModel( 29 | model, 30 | activations, 31 | ) 32 | 33 | my_hook_function = HookFunction(MODULE_NAME_1) 34 | 35 | model.register_forward_hook(my_hook_function) 36 | 37 | assert my_hook_function._shared_state.output_ptr is activations 38 | assert my_hook_function._shared_state.save_ctx is model.save_ctx 39 | assert my_hook_function._shared_state.modules is model.trainable_modules 40 | 41 | 42 | def test_register_trainable_module(make_opt_350m): 43 | """ 44 | Tests if a trainable module is registered correctly, and that all hook 45 | functions (regardless of when they're added), have a pointer to this 46 | module. 
47 | """ 48 | utils.init_process_group() 49 | model = make_opt_350m().cuda() 50 | 51 | activations = {} 52 | trainable_module = nn.Linear(420, 69, bias=False).cuda() 53 | model = FlexModel(model, activations) 54 | 55 | my_hook_function_1 = HookFunction( 56 | MODULE_NAME_1, 57 | ) 58 | my_hook_function_2 = HookFunction( 59 | MODULE_NAME_2, 60 | ) 61 | 62 | model.register_forward_hook(my_hook_function_1) 63 | model.register_trainable_module("test", trainable_module) 64 | model.register_forward_hook(my_hook_function_2) 65 | 66 | assert "test" in my_hook_function_1._shared_state.modules 67 | assert "test" in my_hook_function_2._shared_state.modules 68 | assert my_hook_function_1._shared_state.modules["test"] is trainable_module 69 | assert my_hook_function_2._shared_state.modules["test"] is trainable_module 70 | 71 | 72 | def test_trainable_module_gradient(make_opt_350m): 73 | utils.init_process_group() 74 | model = make_opt_350m().cuda() 75 | 76 | activations = {} 77 | fc = nn.Linear(1024, 1024, bias=False, dtype=model.dtype).cuda() 78 | model = FlexModel(model, activations) 79 | 80 | inputs = torch.randint( 81 | low=0, 82 | high=15000, 83 | size=(4, 64), 84 | ).cuda() 85 | 86 | model.register_trainable_module("test", fc) 87 | 88 | def _apply_test_fc(m, inputs, save_ctx, trainable_modules): 89 | return trainable_modules["test"](inputs) 90 | 91 | hook_func = HookFunction( 92 | MODULE_NAME_1, expected_shape=None, editing_function=_apply_test_fc 93 | ) 94 | 95 | model.register_forward_hook(hook_func) 96 | 97 | outputs = model(inputs) 98 | loss = outputs.logits.mean() 99 | loss.backward() 100 | 101 | for n, p in model.named_parameters(): 102 | assert p.grad is not None, f"Parameter: {n} has None grad field." 103 | assert ( 104 | torch.count_nonzero(p.grad) != 0 105 | ), f"Parameter: {n} has all-zero grad field." 106 | 107 | 108 | def test_destroy(make_opt_350m): 109 | """ 110 | Tests the destroy method to ensure everything is cleared appropriately. 111 | """ 112 | utils.init_process_group() 113 | model = make_opt_350m().cuda() 114 | 115 | activations = {} 116 | trainable_module_1 = nn.Linear(420, 69, bias=False).cuda().requires_grad_() 117 | trainable_module_2 = nn.Linear(420, 69, bias=False).cuda().requires_grad_() 118 | model = FlexModel( 119 | model, 120 | activations, 121 | ) 122 | model.register_trainable_module("test1", trainable_module_1) 123 | model.register_trainable_module("test2", trainable_module_2) 124 | 125 | my_hook_function = HookFunction( 126 | MODULE_NAME_1, 127 | ) 128 | 129 | model.register_forward_hook(my_hook_function) 130 | model = model.module # Calls finalizer. 
131 | 132 | assert not isinstance(model, FlexModel) 133 | assert not hasattr(model, "hook_functions") 134 | assert not hasattr(model, "_hook_function_handles") 135 | assert not hasattr(model, "_hooks_active") 136 | assert not hasattr(model, "output_ptr") 137 | assert not hasattr(model, "save_ctx") 138 | assert not hasattr(model, "trainable_modules") 139 | 140 | hook_types = {"_forward", "_forward_pre", "_backward"} 141 | for m in model.modules(): 142 | for hook_type in hook_types: 143 | attr = hook_type + "_hooks" 144 | assert len(getattr(m, attr)) == 0 145 | 146 | 147 | def test_save_ctx(make_opt_350m, opt_tokenizer): 148 | utils.init_process_group() 149 | model = make_opt_350m().cuda() 150 | 151 | tokenizer = opt_tokenizer 152 | 153 | activations = {} 154 | model = FlexModel(model, activations) 155 | 156 | prompts = [ 157 | "It's a nice day we're having", 158 | "The capital of Canada is", 159 | "What should I eat for dinner tonight?", 160 | "There's about three people going to", 161 | ] 162 | 163 | inputs = tokenizer(prompts, padding=True, return_tensors="pt")[ 164 | "input_ids" 165 | ].cuda() 166 | 167 | # Function to save an activation tensor for later use. The same activation 168 | # tensor is also saved into the `activations` dict we passed initially to 169 | # the `FlexModel.__init__()`. Hence we can verify that the `save_ctx` and 170 | # `activations` dict versions of the same tensor are indeed `torch.equal`. 171 | def retrieve_fn(current_module, inputs, save_ctx, modules): 172 | # Detach activation tensor and dump to cpu 173 | save_ctx.activation = inputs.detach().cpu() 174 | return inputs 175 | 176 | # Function to verify we still have access to the saved tensor 177 | def verify_fn(current_module, inputs, save_ctx, modules, act_dict): 178 | act_dict["save_ctx_activation"] = save_ctx.activation 179 | return inputs 180 | 181 | retrieve_hook_fn = HookFunction( 182 | "model.decoder.layers.12", 183 | editing_function=retrieve_fn, 184 | ) 185 | verify_hook_fn = HookFunction( 186 | "model.decoder.layers.18", 187 | editing_function=partial(verify_fn, act_dict=activations), 188 | ) 189 | model.register_forward_hook(retrieve_hook_fn) 190 | model.register_forward_hook(verify_hook_fn) 191 | 192 | _ = model(inputs) 193 | 194 | # Verify that the two versions of the same tensor are equal 195 | assert torch.equal( 196 | activations["save_ctx_activation"], 197 | activations["model.decoder.layers.12"][0], 198 | ) 199 | 200 | 201 | def test_FlexModel_group_all(make_opt_350m): 202 | utils.init_process_group() 203 | model = make_opt_350m().cuda() 204 | 205 | activations = {} 206 | model = FlexModel(model, activations) 207 | 208 | layers = [ 209 | f"model.decoder.layers.{i}" 210 | for i in range(len(model.module.model.decoder.layers)) 211 | ] 212 | hook_functions = [HookFunction(name) for name in layers] 213 | for hf in hook_functions: 214 | model.register_forward_hook(hf) 215 | 216 | manager = model._hook_fn_group_manager 217 | assert len(manager.hook_fn_to_groups_map) == len(hook_functions) 218 | for hf, group in manager.hook_fn_to_groups_map.items(): 219 | assert group == set(["all"]) 220 | 221 | 222 | def test_FlexModel_group_creation(make_opt_350m, opt_tokenizer): 223 | utils.init_process_group() 224 | model = make_opt_350m().cuda() 225 | prompts = [ 226 | "It's a nice day we're having", 227 | "The capital of Canada is", 228 | "What should I eat for dinner tonight?", 229 | "There's about three people going to", 230 | ] 231 | inputs = opt_tokenizer(prompts, padding=True, return_tensors="pt")[ 232 | 
"input_ids" 233 | ].cuda() 234 | 235 | activations = {} 236 | model = FlexModel(model, activations) 237 | 238 | # Run the model forward pass on a group, on the hook functions not in the 239 | # group, and on all hook functions. 240 | model.create_hook_group( 241 | group_name="new_group", 242 | group_constructor="self_attn", 243 | ) 244 | 245 | _ = model(inputs) 246 | 247 | all_group_tensors = {**activations} 248 | activations.clear() 249 | 250 | _ = model(inputs, groups="new_group") 251 | 252 | for name, tensor in activations.items(): 253 | assert "self_attn" in name 254 | 255 | new_group_tensors = {**activations} 256 | activations.clear() 257 | 258 | _ = model(inputs, groups="new_group", complement=True) 259 | 260 | non_new_group_tensors = {**activations} 261 | activations.clear() 262 | 263 | assert len(all_group_tensors) == len(new_group_tensors) + len( 264 | non_new_group_tensors 265 | ) 266 | for name, tensor in all_group_tensors.items(): 267 | assert name in new_group_tensors or name in non_new_group_tensors 268 | if name in new_group_tensors: 269 | new_ten = new_group_tensors.pop(name) 270 | assert torch.allclose(tensor, new_ten) 271 | else: 272 | non_new_ten = non_new_group_tensors.pop(name) 273 | assert torch.allclose(tensor, non_new_ten) 274 | 275 | assert len(new_group_tensors) == 0 276 | assert len(non_new_group_tensors) == 0 277 | 278 | for ( 279 | hook_fn, 280 | groups, 281 | ) in model._hook_fn_group_manager.hook_fn_to_groups_map.items(): 282 | if "self_attn" in hook_fn.module_name: 283 | assert "new_group" in groups 284 | assert "all" in groups 285 | -------------------------------------------------------------------------------- /_test/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/_test/distributed/__init__.py -------------------------------------------------------------------------------- /_test/distributed/test_initialize.py: -------------------------------------------------------------------------------- 1 | from flex_model.distributed.distributed_state import _BaseGPUDeviceMesh 2 | 3 | 4 | def test_GPUDeviceMesh(): 5 | cases = { 6 | (1, 1, 1): [[[0]], [[0]], [[0]]], 7 | (4, 1, 1): [[[0, 1, 2, 3]], [[0], [1], [2], [3]], [[0], [1], [2], [3]]], 8 | (1, 4, 1): [[[0], [1], [2], [3]], [[0, 1, 2, 3]], [[0], [1], [2], [3]]], 9 | (1, 1, 4): [[[0], [1], [2], [3]], [[0], [1], [2], [3]], [[0, 1, 2, 3]]], 10 | (2, 2, 1): [[[0, 1], [2, 3]], [[0, 2], [1, 3]], [[0], [1], [2], [3]]], 11 | (1, 2, 2): [[[0], [1], [2], [3]], [[0, 2], [1, 3]], [[0, 1], [2, 3]]], 12 | (2, 2, 2): [ 13 | [[0, 1], [2, 3], [4, 5], [6, 7]], 14 | [[0, 4], [1, 5], [2, 6], [3, 7]], 15 | [[0, 2], [1, 3], [4, 6], [5, 7]], 16 | ], 17 | } 18 | 19 | for case, solution in cases.items(): 20 | tp = case[0] 21 | pp = case[1] 22 | dp = case[2] 23 | gpu_device_mesh = _BaseGPUDeviceMesh(tp, pp, dp) 24 | assert gpu_device_mesh.tp_group_ranks == solution[0], f"{case}" 25 | assert gpu_device_mesh.pp_group_ranks == solution[1], f"{case}" 26 | assert gpu_device_mesh.dp_group_ranks == solution[2], f"{case}" 27 | -------------------------------------------------------------------------------- /_test/multi_gpu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/_test/multi_gpu/__init__.py 
-------------------------------------------------------------------------------- /_test/multi_gpu/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_fairscale_mpu_vs_ddp import ( 2 | test_fairscale, 3 | test_forward_hooks_fairscale, 4 | test_full_backward_hooks_fairscale, 5 | test_forward_pre_hooks_fairscale, 6 | test_full_backward_pre_hooks_fairscale, 7 | ) 8 | from .test_dp_wrappers_vs_single_gpu import ( 9 | test_forward_hooks_wrapped, 10 | test_full_backward_hooks_wrapped, 11 | test_forward_pre_hooks_wrapped, 12 | test_full_backward_pre_hooks_wrapped, 13 | ) 14 | -------------------------------------------------------------------------------- /_test/multi_gpu/core/test_dp_wrappers_vs_single_gpu.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from flex_model.core import FlexModel 8 | import _test.multi_gpu.testing_utils as utils 9 | from _test.multi_gpu.registry import SlurmJobResourceSpec, make_test_registry 10 | 11 | 12 | register_wrapper_test, get_wrapper_test = make_test_registry( 13 | "wrappers", 14 | SlurmJobResourceSpec(), 15 | ) 16 | 17 | 18 | MODULE_NAME = "fc1" 19 | MODULE_SHAPE = None 20 | FSDP_WRAP_LAYER = torch.nn.Linear 21 | 22 | 23 | WRAPPERS = { 24 | "ddp": utils.wrap_ddp, 25 | "fsdp": partial(utils.wrap_fsdp, layer_to_wrap=FSDP_WRAP_LAYER), 26 | } 27 | 28 | 29 | def _setup_model_and_inputs(wrap_fn): 30 | rank = dist.get_rank() 31 | world_size = dist.get_world_size() 32 | 33 | # Construct single-gpu (regular) and multi-gpu (ddp) models. 34 | base_model = utils.TestModel() 35 | 36 | wrapped_model = copy.deepcopy(base_model) 37 | wrapped_model = wrap_fn(wrapped_model, pg=None) 38 | 39 | # Dummy inputs. 40 | base_inputs = torch.randn(4, 10, device="cuda", dtype=torch.float32) 41 | wrapped_inputs = base_inputs.chunk(world_size)[rank] 42 | 43 | # Set requires grad for backward hooks testing. 44 | base_inputs.requires_grad = True 45 | wrapped_inputs.requires_grad = True 46 | 47 | return base_model, base_inputs, wrapped_model, wrapped_inputs 48 | 49 | 50 | def _run_base_model(base_model, base_inputs, register_fn, acc): 51 | rank = dist.get_rank() 52 | base_pg = dist.new_group([rank]) 53 | 54 | fm_base_model = FlexModel(base_model, acc, process_group=base_pg) 55 | register_fn(fm_base_model) 56 | 57 | base_outputs = fm_base_model(base_inputs) 58 | 59 | return base_outputs 60 | 61 | 62 | def _run_wrapped_model(wrapped_model, wrapped_inputs, register_fn, acc): 63 | world_size = dist.get_world_size() 64 | 65 | # Use default pg for whole world. 66 | fm_base_model = FlexModel(wrapped_model, acc, data_parallel_size=world_size) 67 | register_fn(fm_base_model) 68 | 69 | wrapped_outputs = fm_base_model(wrapped_inputs) 70 | 71 | return wrapped_outputs 72 | 73 | 74 | def _run_base_and_wrapped_models( 75 | acc, edit_fn, hook_type, wrap_fn, run_backward=False 76 | ): 77 | ( 78 | base_model, 79 | base_inputs, 80 | wrapped_model, 81 | wrapped_inputs, 82 | ) = _setup_model_and_inputs(wrap_fn) 83 | 84 | register_fn = partial( 85 | utils.register_hook_functions, 86 | editing_function=edit_fn, 87 | hook_type=hook_type, 88 | module_name_to_shape_map={MODULE_NAME: MODULE_SHAPE}, 89 | ) 90 | 91 | # Regular model. 
92 | base_register_fn = partial(register_fn, module_prefix="") 93 | base_outputs = _run_base_model( 94 | base_model, 95 | base_inputs, 96 | base_register_fn, 97 | acc, 98 | ) 99 | if run_backward: 100 | base_outputs.mean().backward() 101 | 102 | # Wrapped model. 103 | wrapped_register_fn = partial(register_fn, module_prefix="module.") 104 | wrapped_outputs = _run_wrapped_model( 105 | wrapped_model, 106 | wrapped_inputs, 107 | wrapped_register_fn, 108 | acc, 109 | ) 110 | if run_backward: 111 | wrapped_outputs.mean().backward() 112 | wrapped_outputs = utils.all_gather(wrapped_outputs, dim=0) 113 | 114 | return base_outputs, wrapped_outputs, base_model, wrapped_model 115 | 116 | 117 | @register_wrapper_test 118 | def test_forward_hooks_wrapped(): 119 | utils.init_process_group() 120 | 121 | acc = {} 122 | 123 | def _edit(module, outputs, save_ctx, trainable_modules): 124 | return outputs * 2 125 | 126 | for wrapper_name, wrapper_fn in WRAPPERS.items(): 127 | acc[wrapper_name] = {} 128 | base_outputs, wrapped_outputs, _, _ = _run_base_and_wrapped_models( 129 | acc[wrapper_name], 130 | _edit, 131 | "register_forward_hook", 132 | wrapper_fn, 133 | ) 134 | 135 | # Validation. 136 | torch.testing.assert_close(base_outputs, wrapped_outputs) 137 | torch.testing.assert_close( 138 | acc[wrapper_name][MODULE_NAME][0], 139 | acc[wrapper_name][f"module.{MODULE_NAME}"][0], 140 | ) 141 | utils.print_success(f"test_forward_hooks [{wrapper_name}]") 142 | 143 | 144 | @register_wrapper_test 145 | def test_full_backward_hooks_wrapped(): 146 | utils.init_process_group() 147 | 148 | acc = {} 149 | 150 | def _edit(module, grad_inputs, save_ctx, trainable_modules): 151 | return grad_inputs * 2 152 | 153 | for wrapper_name, wrapper_fn in WRAPPERS.items(): 154 | acc[wrapper_name] = {} 155 | ( 156 | base_outputs, 157 | wrapped_outputs, 158 | base_model, 159 | wrapped_model, 160 | ) = _run_base_and_wrapped_models( 161 | acc[wrapper_name], 162 | _edit, 163 | "register_full_backward_hook", 164 | wrapper_fn, 165 | run_backward=True, 166 | ) 167 | 168 | # Validate. 169 | torch.testing.assert_close(base_outputs, wrapped_outputs) 170 | 171 | torch.testing.assert_close( 172 | acc[wrapper_name][MODULE_NAME][0], 173 | acc[wrapper_name][f"module.{MODULE_NAME}"][0], 174 | ) 175 | 176 | if wrapper_name == "ddp": 177 | for (base_name, base_param), (wrapped_name, wrapped_param) in zip( 178 | base_model.named_parameters(), wrapped_model.named_parameters() 179 | ): 180 | torch.testing.assert_close(base_param, wrapped_param) 181 | torch.testing.assert_close(base_param.grad, wrapped_param.grad) 182 | 183 | utils.print_success(f"test_full_backward_hooks [{wrapper_name}]") 184 | 185 | 186 | @register_wrapper_test 187 | def test_forward_pre_hooks_wrapped(): 188 | utils.init_process_group() 189 | 190 | acc = {} 191 | 192 | def _edit(module, inputs, save_ctx, trainable_modules): 193 | return inputs * 2 194 | 195 | for wrapper_name, wrapper_fn in WRAPPERS.items(): 196 | acc[wrapper_name] = {} 197 | base_outputs, wrapped_outputs, _, _ = _run_base_and_wrapped_models( 198 | acc[wrapper_name], 199 | _edit, 200 | "register_forward_hook", 201 | wrapper_fn, 202 | ) 203 | 204 | # Validation. 
205 | torch.testing.assert_close(base_outputs, wrapped_outputs) 206 | torch.testing.assert_close( 207 | acc[wrapper_name][MODULE_NAME][0], 208 | acc[wrapper_name][f"module.{MODULE_NAME}"][0], 209 | ) 210 | utils.print_success(f"test_forward_pre_hooks [{wrapper_name}]") 211 | 212 | 213 | @register_wrapper_test 214 | def test_full_backward_pre_hooks_wrapped(): 215 | utils.init_process_group() 216 | 217 | acc = {} 218 | 219 | def _edit(module, grad_outputs, save_ctx, trainable_modules): 220 | return grad_outputs * 2 221 | 222 | for wrapper_name, wrapper_fn in WRAPPERS.items(): 223 | acc[wrapper_name] = {} 224 | ( 225 | base_outputs, 226 | wrapped_outputs, 227 | base_model, 228 | wrapped_model, 229 | ) = _run_base_and_wrapped_models( 230 | acc[wrapper_name], 231 | _edit, 232 | "register_full_backward_hook", 233 | wrapper_fn, 234 | run_backward=True, 235 | ) 236 | 237 | # Validate. 238 | torch.testing.assert_close(base_outputs, wrapped_outputs) 239 | 240 | torch.testing.assert_close( 241 | acc[wrapper_name][MODULE_NAME][0], 242 | acc[wrapper_name][f"module.{MODULE_NAME}"][0], 243 | ) 244 | 245 | if wrapper_name == "ddp": 246 | for (base_name, base_param), (wrapped_name, wrapped_param) in zip( 247 | base_model.named_parameters(), wrapped_model.named_parameters() 248 | ): 249 | torch.testing.assert_close(base_param, wrapped_param) 250 | torch.testing.assert_close(base_param.grad, wrapped_param.grad) 251 | 252 | utils.print_success( 253 | f"test_full_backward_pre_hooks [{wrapper_name}]" 254 | ) 255 | -------------------------------------------------------------------------------- /_test/multi_gpu/core/test_fairscale_mpu_vs_ddp.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import fairscale.nn.model_parallel as mpu 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from flex_model.core import FlexModel 8 | from _test.multi_gpu.registry import SlurmJobResourceSpec, make_test_registry 9 | import _test.multi_gpu.testing_utils as utils 10 | 11 | 12 | ( 13 | register_fairscale_megatron_test, 14 | get_fairscale_megatron_test, 15 | ) = make_test_registry( 16 | "fairscale_megatron", 17 | SlurmJobResourceSpec(), 18 | ) 19 | 20 | TP_SIZE = 2 21 | PP_SIZE = 1 22 | HIDDEN = 10 23 | VOCAB = 420 24 | EXPANSION = 2 25 | HOOK_ACTIVATIONS = { 26 | "register_forward_hook": { 27 | "vocab_parallel_embedding": None, 28 | "parallel_embedding": None, 29 | "column_parallel_linear": (None, None, HIDDEN * EXPANSION), 30 | "row_parallel_linear": None, 31 | }, 32 | "register_full_backward_hook": { 33 | "vocab_parallel_embedding": None, 34 | "parallel_embedding": None, 35 | "column_parallel_linear": (None, None, HIDDEN * EXPANSION), 36 | "row_parallel_linear": None, 37 | }, 38 | "register_hook": { 39 | "vocab_parallel_embedding.weight": (VOCAB, None), 40 | "parallel_embedding.weight": (None, HIDDEN), 41 | "column_parallel_linear.weight": (HIDDEN * EXPANSION, None), 42 | "row_parallel_linear.weight": (None, HIDDEN * EXPANSION), 43 | }, 44 | "register_forward_pre_hook": { 45 | "vocab_parallel_embedding": None, 46 | "parallel_embedding": None, 47 | "column_parallel_linear": None, 48 | "row_parallel_linear": (None, None, HIDDEN * EXPANSION), 49 | }, 50 | "register_full_backward_pre_hook": { 51 | "vocab_parallel_embedding": None, 52 | "parallel_embedding": None, 53 | "column_parallel_linear": (None, None, HIDDEN * EXPANSION), 54 | "row_parallel_linear": None, 55 | }, 56 | } 57 | 58 | 59 | @register_fairscale_megatron_test 60 | def 
test_fairscale(): 61 | utils.init_fairscale_mpu(1, 1) 62 | 63 | model = utils.TestFairscaleModel(HIDDEN, VOCAB) 64 | inputs = torch.randint(0, 420, size=(4, 42)).cuda() 65 | 66 | outputs = model(inputs) 67 | 68 | utils.destroy_fairscale_mpu() 69 | 70 | utils.init_fairscale_mpu(TP_SIZE, PP_SIZE) 71 | 72 | sharded_inputs = inputs.chunk(mpu.get_data_parallel_world_size())[ 73 | mpu.get_data_parallel_rank() 74 | ] 75 | ddp_tp_model = utils.TestFairscaleModel(HIDDEN, VOCAB) 76 | ddp_tp_model.copy_state_from_unsharded(model, other_tp_world_size=1) 77 | ddp_tp_model = utils.wrap_ddp( 78 | ddp_tp_model, 79 | pg=mpu.get_data_parallel_group(), 80 | ) 81 | 82 | dist_outputs = ddp_tp_model(sharded_inputs) 83 | dist_outputs = utils.all_gather( 84 | dist_outputs, pg=mpu.get_data_parallel_group() 85 | ) 86 | 87 | torch.testing.assert_close(outputs, dist_outputs) 88 | 89 | utils.destroy_fairscale_mpu() 90 | utils.print_success("test_fairscale") 91 | 92 | 93 | def _run_ddp_model(register_fn, acc, run_backward=False): 94 | utils.init_fairscale_mpu(1, 1) 95 | 96 | # Base model. 97 | model = utils.TestFairscaleModel(HIDDEN, VOCAB) 98 | 99 | # DDP model. 100 | ddp_model = utils.wrap_ddp(model, pg=mpu.get_data_parallel_group()) 101 | 102 | fm_ddp_model = FlexModel( 103 | ddp_model, acc, data_parallel_size=dist.get_world_size() 104 | ) 105 | register_fn(fm_ddp_model) 106 | inputs = torch.randint(0, 420, size=(4, 42)).cuda() 107 | sharded_inputs = inputs.chunk(mpu.get_data_parallel_world_size())[ 108 | mpu.get_data_parallel_rank() 109 | ] 110 | 111 | ddp_outputs = fm_ddp_model(sharded_inputs) 112 | if run_backward: 113 | ddp_outputs.mean().backward() 114 | ddp_outputs = utils.all_gather( 115 | ddp_outputs, pg=mpu.get_data_parallel_group() 116 | ) 117 | 118 | return ddp_model, inputs, ddp_outputs 119 | 120 | 121 | def _run_ddp_tp_model(model, inputs, register_fn, acc, run_backward=False): 122 | utils.init_fairscale_mpu(TP_SIZE, PP_SIZE) 123 | 124 | # TP + DDP model. 125 | ddp_tp_model = utils.TestFairscaleModel(HIDDEN, VOCAB) 126 | ddp_tp_model.copy_state_from_unsharded(model, other_tp_world_size=1) 127 | ddp_tp_model = utils.wrap_ddp( 128 | ddp_tp_model, 129 | pg=mpu.get_data_parallel_group(), 130 | ) 131 | fm_ddp_tp_model = FlexModel( 132 | ddp_tp_model, 133 | acc, 134 | tensor_parallel_size=TP_SIZE, 135 | data_parallel_size=dist.get_world_size() // TP_SIZE, 136 | ) 137 | register_fn(fm_ddp_tp_model) 138 | 139 | sharded_inputs = inputs.chunk(mpu.get_data_parallel_world_size())[ 140 | mpu.get_data_parallel_rank() 141 | ] 142 | 143 | ddp_tp_outputs = fm_ddp_tp_model(sharded_inputs) 144 | if run_backward: 145 | ddp_tp_outputs.mean().backward() 146 | dist_outputs = utils.all_gather( 147 | ddp_tp_outputs, pg=mpu.get_data_parallel_group() 148 | ) 149 | 150 | return ddp_tp_model, dist_outputs 151 | 152 | 153 | def _run_ddp_and_ddp_tp_models(acc, edit_fn, hook_type, run_backward=False): 154 | register_fn = partial( 155 | utils.register_hook_functions, 156 | editing_function=edit_fn, 157 | hook_type=hook_type, 158 | module_name_to_shape_map=HOOK_ACTIVATIONS[hook_type], 159 | module_prefix="module.", 160 | ) 161 | 162 | # DDP model outputs and states. 163 | ddp_model, inputs, ddp_outputs = _run_ddp_model( 164 | register_fn, acc, run_backward 165 | ) 166 | ddp_model_states = ddp_model.module.get_unsharded_params_and_grads() 167 | 168 | # Reset MPU states. 169 | utils.destroy_fairscale_mpu() 170 | 171 | # DDP + TP model outputs and states. 
172 | ddp_tp_model, ddp_tp_outputs = _run_ddp_tp_model( 173 | ddp_model.module, inputs, register_fn, acc, run_backward 174 | ) 175 | ddp_tp_model_states = ddp_tp_model.module.get_unsharded_params_and_grads() 176 | 177 | return ddp_model_states, ddp_tp_model_states, ddp_outputs, ddp_tp_outputs 178 | 179 | 180 | @register_fairscale_megatron_test 181 | def test_forward_hooks_fairscale(): 182 | # Fm params. 183 | acc = {} 184 | 185 | def _edit(module, outputs, save_ctx, trainable_modules): 186 | return outputs * 2 187 | 188 | _, _, ddp_outputs, ddp_tp_outputs = _run_ddp_and_ddp_tp_models( 189 | acc, 190 | _edit, 191 | hook_type="register_forward_hook", 192 | ) 193 | 194 | # Validate. 195 | torch.testing.assert_close(ddp_outputs, ddp_tp_outputs) 196 | 197 | for acts in acc.values(): 198 | torch.testing.assert_close(acts[0], acts[1]) 199 | 200 | utils.destroy_fairscale_mpu() 201 | utils.print_success("test_forward_hooks_fairscale") 202 | 203 | 204 | @register_fairscale_megatron_test 205 | def test_full_backward_hooks_fairscale(): 206 | # Fm params. 207 | acc = {} 208 | 209 | def _edit(module, grad_inputs, save_ctx, trainable_modules): 210 | return grad_inputs * 2 211 | 212 | ( 213 | ddp_model_states, 214 | ddp_tp_model_states, 215 | ddp_outputs, 216 | ddp_tp_outputs, 217 | ) = _run_ddp_and_ddp_tp_models( 218 | acc, 219 | _edit, 220 | hook_type="register_full_backward_hook", 221 | run_backward=True, 222 | ) 223 | 224 | # Validate. 225 | torch.testing.assert_close(ddp_outputs, ddp_tp_outputs) 226 | 227 | utils.assert_same_state(ddp_model_states, ddp_tp_model_states) 228 | 229 | utils.destroy_fairscale_mpu() 230 | utils.print_success("test_full_backward_hooks_fairscale") 231 | 232 | 233 | @register_fairscale_megatron_test 234 | def test_forward_pre_hooks_fairscale(): 235 | # Fm params. 236 | acc = {} 237 | 238 | def _edit(module, inputs, save_ctx, trainable_modules): 239 | return torch.where(inputs < 200, inputs + 10, inputs) 240 | 241 | _, _, ddp_outputs, ddp_tp_outputs = _run_ddp_and_ddp_tp_models( 242 | acc, 243 | _edit, 244 | hook_type="register_forward_pre_hook", 245 | ) 246 | 247 | # Validate. 248 | torch.testing.assert_close(ddp_outputs, ddp_tp_outputs) 249 | 250 | for acts in acc.values(): 251 | torch.testing.assert_close(acts[0], acts[1]) 252 | 253 | utils.destroy_fairscale_mpu() 254 | utils.print_success("test_forward_pre_hooks_fairscale") 255 | 256 | 257 | @register_fairscale_megatron_test 258 | def test_full_backward_pre_hooks_fairscale(): 259 | # Fm params. 260 | acc = {} 261 | 262 | def _edit(module, grad_outputs, save_ctx, trainable_modules): 263 | return grad_outputs * 2 264 | 265 | ( 266 | ddp_model_states, 267 | ddp_tp_model_states, 268 | ddp_outputs, 269 | ddp_tp_outputs, 270 | ) = _run_ddp_and_ddp_tp_models( 271 | acc, 272 | _edit, 273 | hook_type="register_full_backward_pre_hook", 274 | run_backward=True, 275 | ) 276 | 277 | # Validate. 
278 | torch.testing.assert_close(ddp_outputs, ddp_tp_outputs) 279 | 280 | utils.assert_same_state(ddp_model_states, ddp_tp_model_states) 281 | 282 | utils.destroy_fairscale_mpu() 283 | utils.print_success("test_full_backward_pre_hooks_fairscale") 284 | -------------------------------------------------------------------------------- /_test/multi_gpu/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_initialize import test_initialize_distributed_state 2 | 3 | from .test_mappings import ( 4 | test_all_gather_data_parallel, 5 | test_all_gather_tensor_parallel, 6 | test_batch_isend_irecv_pipeline_parallel, 7 | test_broadcast_data_parallel, 8 | test_broadcast_tensor_parallel, 9 | test_gather_pipeline_parallel_base, 10 | test_gather_pipeline_parallel_dtypes, 11 | test_scatter_data_parallel, 12 | test_scatter_tensor_parallel, 13 | ) 14 | -------------------------------------------------------------------------------- /_test/multi_gpu/distributed/test_initialize.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.distributed as dist 3 | 4 | from flex_model.distributed.distributed_state import ( 5 | _global_state_is_initialized, 6 | ) 7 | import flex_model.distributed as fm_dist 8 | from _test.multi_gpu.registry import SlurmJobResourceSpec, make_test_registry 9 | import _test.multi_gpu.testing_utils as utils 10 | 11 | register_initialize_test, get_initialize_test = make_test_registry( 12 | "initialize", 13 | SlurmJobResourceSpec(), 14 | ) 15 | 16 | 17 | @register_initialize_test 18 | def test_initialize_distributed_state(): 19 | utils.init_process_group() 20 | 21 | # model = utils.TestModel().cuda() 22 | # model = utils.wrap_ddp(model, rank=dist.get_rank()) 23 | model = nn.Linear(2, 4) 24 | 25 | fmps = fm_dist.initialize_distributed_state( 26 | model, 1, 1, dist.get_world_size() 27 | ) 28 | assert _global_state_is_initialized() 29 | 30 | assert fmps.get_local_rank() == dist.get_rank() 31 | assert fmps.get_local_world_size() == dist.get_world_size() 32 | assert fmps.get_subset_rank() == dist.get_rank() 33 | assert fmps.get_subset_world_size() == dist.get_world_size() 34 | assert fmps.get_tensor_parallel_rank() == 0 35 | assert fmps.get_pipeline_parallel_rank() == 0 36 | assert fmps.get_data_parallel_rank() == dist.get_rank() 37 | assert fmps.get_tensor_parallel_world_size() == 1 38 | assert fmps.get_pipeline_parallel_world_size() == 1 39 | assert fmps.get_data_parallel_world_size() == dist.get_world_size() 40 | 41 | utils.print_success("test_initialize_distributed_state") 42 | 43 | 44 | @register_initialize_test 45 | def test_initialize_multiple_models(): 46 | utils.init_process_group() 47 | 48 | model_1 = nn.Linear(2, 4) 49 | 50 | fmps_1 = fm_dist.initialize_distributed_state( 51 | model_1, 1, 1, dist.get_world_size() 52 | ) 53 | 54 | model_2 = nn.Linear(3, 5) 55 | 56 | fmps_2 = fm_dist.initialize_distributed_state( 57 | model_2, dist.get_world_size(), 1, 1 58 | ) 59 | 60 | assert fmps_1.get_local_rank() == dist.get_rank() 61 | assert fmps_1.get_local_world_size() == dist.get_world_size() 62 | assert fmps_1.get_subset_rank() == dist.get_rank() 63 | assert fmps_1.get_subset_world_size() == dist.get_world_size() 64 | assert fmps_1.get_tensor_parallel_rank() == 0 65 | assert fmps_1.get_pipeline_parallel_rank() == 0 66 | assert fmps_1.get_data_parallel_rank() == dist.get_rank() 67 | assert fmps_1.get_tensor_parallel_world_size() == 1 68 | assert 
fmps_1.get_pipeline_parallel_world_size() == 1 69 | assert fmps_1.get_data_parallel_world_size() == dist.get_world_size() 70 | 71 | assert fmps_2.get_local_rank() == dist.get_rank() 72 | assert fmps_2.get_local_world_size() == dist.get_world_size() 73 | assert fmps_2.get_subset_rank() == dist.get_rank() 74 | assert fmps_2.get_subset_world_size() == dist.get_world_size() 75 | assert fmps_2.get_tensor_parallel_rank() == dist.get_rank() 76 | assert fmps_2.get_pipeline_parallel_rank() == 0 77 | assert fmps_2.get_data_parallel_rank() == 0 78 | assert fmps_2.get_tensor_parallel_world_size() == dist.get_world_size() 79 | assert fmps_2.get_pipeline_parallel_world_size() == 1 80 | assert fmps_2.get_data_parallel_world_size() == 1 81 | 82 | utils.print_success("test_initialize_multiple_models") 83 | -------------------------------------------------------------------------------- /_test/multi_gpu/distributed/test_mappings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import flex_model.distributed as fm_dist 7 | from _test.multi_gpu.registry import SlurmJobResourceSpec, make_test_registry 8 | import _test.multi_gpu.testing_utils as utils 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | _NUM_GPUS = 2 14 | 15 | 16 | register_mappings_test, get_mappings_test = make_test_registry( 17 | "mappings", 18 | SlurmJobResourceSpec( 19 | gpus_per_node=_NUM_GPUS, 20 | ntasks_per_node=_NUM_GPUS, 21 | ), 22 | ) 23 | 24 | 25 | @register_mappings_test 26 | def test_broadcast_tensor_parallel(): 27 | utils.init_process_group() 28 | 29 | model = nn.Linear(2, 4) 30 | fmps = fm_dist.initialize_distributed_state(model, _NUM_GPUS, 1, 1) 31 | 32 | if fmps.get_tensor_parallel_rank() == 0: 33 | tensor_to_bcast = torch.ones((1)).cuda() 34 | else: 35 | tensor_to_bcast = torch.zeros((1,)).cuda() 36 | result = fm_dist.broadcast_tensor_parallel(tensor_to_bcast, fmps) 37 | assert torch.equal(result, torch.ones((1)).cuda()) 38 | utils.print_success("test_broadcast_tensor_parallel") 39 | 40 | 41 | @register_mappings_test 42 | def test_broadcast_data_parallel(): 43 | utils.init_process_group() 44 | 45 | model = nn.Linear(2, 4) 46 | fmps = fm_dist.initialize_distributed_state(model, 1, 1, _NUM_GPUS) 47 | 48 | if fmps.get_data_parallel_rank() == 0: 49 | tensor_to_bcast = torch.ones((1)).cuda() 50 | else: 51 | tensor_to_bcast = torch.zeros((1,)).cuda() 52 | result = fm_dist.broadcast_data_parallel(tensor_to_bcast, fmps) 53 | assert torch.equal(result, torch.ones((1)).cuda()) 54 | utils.print_success("test_broadcast_data_parallel") 55 | 56 | 57 | @register_mappings_test 58 | def test_all_gather_tensor_parallel(): 59 | utils.init_process_group() 60 | 61 | model = nn.Linear(2, 4) 62 | fmps = fm_dist.initialize_distributed_state(model, _NUM_GPUS, 1, 1) 63 | 64 | tensor_to_gather = torch.ones((1)).cuda() * fmps.get_tensor_parallel_rank() 65 | result = fm_dist.all_gather_tensor_parallel(tensor_to_gather, 0, fmps) 66 | assert torch.equal( 67 | result, 68 | torch.arange(fmps.get_tensor_parallel_world_size()).cuda(), 69 | ) 70 | utils.print_success("test_all_gather_tensor_parallel") 71 | 72 | 73 | @register_mappings_test 74 | def test_all_gather_data_parallel(): 75 | utils.init_process_group() 76 | 77 | model = nn.Linear(2, 4) 78 | fmps = fm_dist.initialize_distributed_state(model, 1, 1, _NUM_GPUS) 79 | 80 | tensor_to_gather = torch.ones((1)).cuda() * fmps.get_data_parallel_rank() 81 | result = 
fm_dist.all_gather_data_parallel(tensor_to_gather, 0, fmps) 82 | assert torch.equal( 83 | result, 84 | torch.arange(fmps.get_data_parallel_world_size()).cuda(), 85 | ) 86 | utils.print_success("test_all_gather_data_parallel") 87 | 88 | 89 | @register_mappings_test 90 | def test_scatter_tensor_parallel(): 91 | utils.init_process_group() 92 | 93 | model = nn.Linear(2, 4) 94 | fmps = fm_dist.initialize_distributed_state(model, _NUM_GPUS, 1, 1) 95 | 96 | tensor_to_scatter = torch.arange( 97 | fmps.get_tensor_parallel_world_size() 98 | ).cuda() 99 | result = fm_dist.scatter_tensor_parallel(tensor_to_scatter, 0, fmps) 100 | assert torch.equal( 101 | result, 102 | torch.ones((1)).cuda() * fmps.get_tensor_parallel_rank(), 103 | ) 104 | utils.print_success("test_scatter_tensor_parallel") 105 | 106 | 107 | @register_mappings_test 108 | def test_scatter_data_parallel(): 109 | utils.init_process_group() 110 | 111 | model = nn.Linear(2, 4) 112 | fmps = fm_dist.initialize_distributed_state(model, 1, 1, _NUM_GPUS) 113 | 114 | tensor_to_scatter = torch.arange(fmps.get_data_parallel_world_size()).cuda() 115 | result = fm_dist.scatter_data_parallel(tensor_to_scatter, 0, fmps) 116 | assert torch.equal( 117 | result, 118 | torch.ones((1)).cuda() * fmps.get_data_parallel_rank(), 119 | ) 120 | utils.print_success("test_scatter_data_parallel") 121 | 122 | 123 | @register_mappings_test 124 | def test_batch_isend_irecv_pipeline_parallel(): 125 | utils.init_process_group() 126 | 127 | model = nn.Linear(2, 4) 128 | fmps = fm_dist.initialize_distributed_state(model, 1, _NUM_GPUS, 1) 129 | 130 | rank = fmps.get_pipeline_parallel_rank() 131 | world_size = fmps.get_pipeline_parallel_world_size() 132 | 133 | send_tensors = [torch.ones((1,)).cuda() * rank] 134 | send_to_ranks = [(rank + 1) % world_size] 135 | recv_tensors = [torch.empty((1,)).cuda()] 136 | recv_from_ranks = [(rank + 1) % world_size] 137 | 138 | fm_dist.batch_isend_irecv_pipeline_parallel( 139 | fmps, 140 | recv_tensors, 141 | recv_from_ranks, 142 | send_tensors, 143 | send_to_ranks, 144 | ) 145 | 146 | for tensor in recv_tensors: 147 | assert torch.equal( 148 | tensor, 149 | torch.ones((1,)).cuda() * (rank + 1) % world_size, 150 | ) 151 | utils.print_success("test_batch_isend_irecv_pipeline_parallel") 152 | 153 | 154 | @register_mappings_test 155 | def test_gather_pipeline_parallel_base(): 156 | utils.init_process_group() 157 | 158 | model = nn.Linear(2, 4) 159 | fmps = fm_dist.initialize_distributed_state(model, 1, _NUM_GPUS, 1) 160 | 161 | rank = fmps.get_pipeline_parallel_rank() 162 | world_size = fmps.get_pipeline_parallel_world_size() 163 | 164 | # Test on empty data. 165 | tensor_dict = {} 166 | result = fm_dist.gather_pipeline_parallel_tensor_dicts(fmps, tensor_dict) 167 | assert len(result) == 0 168 | 169 | # Test on multiple tensors per rank. 
170 | tensor_dict = {} 171 | tensors_per_rank = 4 172 | for i in range(tensors_per_rank): 173 | tensor_idx = rank * tensors_per_rank + i 174 | tensor_dict[f"tensor_{tensor_idx}"] = torch.ones((1,)) * tensor_idx 175 | 176 | result = fm_dist.gather_pipeline_parallel_tensor_dicts(fmps, tensor_dict) 177 | 178 | if rank == 0: 179 | assert len(result) == tensors_per_rank * world_size 180 | for tensor_idx in range(world_size * tensors_per_rank): 181 | assert torch.equal( 182 | result[f"tensor_{tensor_idx}"], 183 | torch.ones((1,)) * tensor_idx, 184 | ) 185 | utils.print_success("test_gather_pipeline_parallel_base") 186 | 187 | 188 | @register_mappings_test 189 | def test_gather_pipeline_parallel_dtypes(): 190 | utils.init_process_group() 191 | 192 | model = nn.Linear(2, 4) 193 | fmps = fm_dist.initialize_distributed_state(model, 1, _NUM_GPUS, 1) 194 | 195 | rank = fmps.get_pipeline_parallel_rank() 196 | world_size = fmps.get_pipeline_parallel_world_size() 197 | 198 | tensor_dict = {} 199 | tensors_per_rank = 4 200 | dtypes = [torch.float32, torch.float16, torch.bfloat16] 201 | for dtype in dtypes: 202 | for i in range(tensors_per_rank): 203 | tensor_idx = rank * tensors_per_rank + i 204 | name = f"tensor_{tensor_idx}_{dtype}" 205 | tensor = torch.ones((1,), dtype=dtype) 206 | tensor_dict[name] = tensor 207 | 208 | result = fm_dist.gather_pipeline_parallel_tensor_dicts(fmps, tensor_dict) 209 | 210 | if rank == 0: 211 | assert len(result) == tensors_per_rank * world_size * len(dtypes) 212 | for dtype in dtypes: 213 | for i in range(tensors_per_rank): 214 | tensor_idx = rank * tensors_per_rank + i 215 | name = f"tensor_{tensor_idx}_{dtype}" 216 | tensor = torch.ones((1,), dtype=dtype) 217 | assert torch.equal( 218 | result[name], 219 | tensor, 220 | ) 221 | utils.print_success("test_gather_pipeline_parallel_dtypes") 222 | -------------------------------------------------------------------------------- /_test/multi_gpu/registry.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | _MULTIGPU_TESTS_REGISTRY = None 5 | _MULTIGPU_RESOURCE_SPECS = None 6 | 7 | 8 | @dataclass 9 | class SlurmJobResourceSpec: 10 | """Resource specification for a single slurm job.""" 11 | 12 | partition: str = "a100" 13 | qos: str = "a100_mchoi" 14 | time: int = 5 15 | mem: Optional[str] = None 16 | mem_per_gpu: str = "32G" 17 | nodes: int = 1 18 | gpus_per_node: int = 4 19 | ntasks_per_node: int = 4 20 | cpus_per_task: int = 6 21 | 22 | def __post_init__(self): 23 | # Providing mem overrides mem_per_gpu. 24 | if self.mem is not None: 25 | self.mem_per_gpu = None 26 | 27 | 28 | def make_test_registry( 29 | registry_name, 30 | resource_spec: SlurmJobResourceSpec = None, 31 | ): 32 | global _MULTIGPU_TESTS_REGISTRY 33 | global _MULTIGPU_RESOURCE_SPECS 34 | if _MULTIGPU_TESTS_REGISTRY is None: 35 | _MULTIGPU_TESTS_REGISTRY = {} 36 | if _MULTIGPU_RESOURCE_SPECS is None: 37 | _MULTIGPU_RESOURCE_SPECS = {} 38 | 39 | # Defaults. 
40 | if resource_spec is None: 41 | resource_spec = SlurmJobResourceSpec() 42 | 43 | _MULTIGPU_TESTS_REGISTRY[registry_name] = {} 44 | _MULTIGPU_RESOURCE_SPECS[registry_name] = resource_spec 45 | 46 | def _register_fn(fn): 47 | """Register a test to run in a multi-gpu setting.""" 48 | fn_name = fn.__name__ 49 | _MULTIGPU_TESTS_REGISTRY[registry_name][fn_name] = fn 50 | 51 | return fn 52 | 53 | def _get_fn(): 54 | assert ( 55 | _MULTIGPU_TESTS_REGISTRY is not None 56 | ), "Multi-gpu test registry is uninitialized or empty" 57 | return _MULTIGPU_TESTS_REGISTRY[registry_name] 58 | 59 | return _register_fn, _get_fn 60 | 61 | 62 | def get_multigpu_test_registry(): 63 | global _MULTIGPU_TESTS_REGISTRY 64 | assert _MULTIGPU_TESTS_REGISTRY is not None 65 | return _MULTIGPU_TESTS_REGISTRY 66 | 67 | 68 | def get_multigpu_resource_specs(): 69 | global _MULTIGPU_RESOURCE_SPECS 70 | assert _MULTIGPU_RESOURCE_SPECS is not None 71 | return _MULTIGPU_RESOURCE_SPECS 72 | -------------------------------------------------------------------------------- /_test/multi_gpu/run_multi_gpu_tests_slurm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from dataclasses import asdict 4 | 5 | import submitit 6 | 7 | from _test.multi_gpu.registry import ( 8 | SlurmJobResourceSpec, 9 | get_multigpu_resource_specs, 10 | get_multigpu_test_registry, 11 | ) 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--test_name", type=str) 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | class TorchDistributedTestBatch: 22 | def __init__(self, test_functions): 23 | self.test_functions = test_functions 24 | 25 | def __repr__(self): 26 | repr_ = "" 27 | for name in self.test_functions.keys(): 28 | repr_ = repr_ + "\n\t" + name 29 | return repr_ 30 | 31 | def __call__(self): 32 | # Setup torch distributed environment. 33 | # NOTE: Allow each process to see all GPUs. Else we cannot properly call 34 | # `torch.cuda.set_device(N)` when N > 0 since all processes will only 35 | # have `CUDA_VISIBLE_DEVICES=0`. 36 | dist_env = submitit.helpers.TorchDistributedEnvironment().export( 37 | set_cuda_visible_devices=False 38 | ) 39 | os.environ["NCCL_IB_DISABLE"] = "1" 40 | 41 | # Print distributed environment details. 42 | print(f"MASTER: {dist_env.master_addr}:{dist_env.master_port}") 43 | print(f"RANK: {dist_env.rank}") 44 | print(f"WORLD_SIZE: {dist_env.world_size}") 45 | print(f"LOCAL_RANK: {dist_env.local_rank}") 46 | print(f"LOCAL_WORLD_SIZE: {dist_env.local_world_size}") 47 | 48 | # Run each test in batch. 49 | test_results = {} 50 | for name, fn in self.test_functions.items(): 51 | try: 52 | print(f"Running test: {name}") 53 | _res = fn() 54 | if dist_env.rank == 0: 55 | test_results[name] = 0 # Success. 56 | 57 | # Test failure, record and continue. 58 | except AssertionError as assert_e: 59 | if dist_env.rank == 0: 60 | test_results[name] = 1 # Failure. 61 | print("AssertionError detected!") 62 | print(assert_e) 63 | 64 | # Code or other error, crash. 
65 | except Exception as e: 66 | print("Non-test exception detected!") 67 | print(e) 68 | raise SystemExit 69 | 70 | return test_results 71 | 72 | 73 | class MultiGPUSlurmJob: 74 | def __init__( 75 | self, 76 | test_batch: TorchDistributedTestBatch, 77 | resource_spec: SlurmJobResourceSpec, 78 | log_dir: str = "multi_gpu_test_logs", 79 | ): 80 | self.res_spec = resource_spec 81 | self.test_batch = test_batch 82 | self.log_dir = log_dir 83 | 84 | def run(self): 85 | slurm_params = asdict(self.res_spec) 86 | python = slurm_params.pop("python", None) 87 | 88 | executor = submitit.SlurmExecutor(folder=self.log_dir, python=python) 89 | 90 | executor.update_parameters(**slurm_params) 91 | 92 | job = executor.submit(self.test_batch) 93 | 94 | submitit.helpers.monitor_jobs([job]) 95 | 96 | results = job.results()[0] 97 | 98 | print("\n") 99 | for test_name, result in results.items(): 100 | status = "SUCCESS" if result == 0 else "FAILURE" 101 | print(f"{test_name}: {result} ({status})") 102 | print("\n") 103 | 104 | return 0 105 | 106 | 107 | def make_test_jobs(args): 108 | """Makes the test jobs for launching via submitit. 109 | 110 | First creates the `TorchDistributedTestBatch`, which is comprised of a 111 | collection of test functions to run. Then creates the `MultiGPUSlurmJob` 112 | using the `TorchDistributedTestBatch` and `SlurmJobResourceSpec`. 113 | """ 114 | # Get test fuction registries and their respective resource specs. 115 | 116 | test_registries = get_multigpu_test_registry() 117 | test_resource_specs = get_multigpu_resource_specs() 118 | 119 | slurm_jobs = {} 120 | print("Created test batches:") 121 | for test_reg_name, test_reg_fns in test_registries.items(): 122 | job = MultiGPUSlurmJob( 123 | TorchDistributedTestBatch(test_reg_fns), 124 | test_resource_specs[test_reg_name], 125 | ) 126 | 127 | if args.test_name is None or args.test_name == test_reg_name: 128 | print(f"{test_reg_name}: {job.test_batch}") 129 | slurm_jobs[test_reg_name] = job 130 | 131 | return slurm_jobs 132 | 133 | 134 | def main(args): 135 | # Import folders to register tests. 136 | from _test.multi_gpu import core, distributed # noqa: F401 137 | 138 | jobs = make_test_jobs(args) 139 | 140 | # Run each job. 141 | # TODO: Launch all at once. 
142 | for job in jobs.values(): 143 | job.run() 144 | 145 | 146 | if __name__ == "__main__": 147 | args = parse_args() 148 | main(args) 149 | -------------------------------------------------------------------------------- /_test/multi_gpu/testing_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from typing import Dict, Tuple 4 | 5 | import fairscale.nn.model_parallel as mpu 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | from torch import Tensor 10 | from fairscale.nn.model_parallel.layers import ( 11 | ColumnParallelLinear, 12 | ParallelEmbedding, 13 | RowParallelLinear, 14 | VocabParallelEmbedding, 15 | ) 16 | from torch.distributed.fsdp import BackwardPrefetch, CPUOffload 17 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 18 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 19 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 20 | from torch.nn.parallel import DistributedDataParallel as DDP 21 | 22 | from flex_model.core import HookFunction 23 | 24 | 25 | def print_success(test_name: str): 26 | rank = dist.get_rank() 27 | print(f"Rank{rank}: [{test_name}] - Test successful") 28 | 29 | 30 | def init_process_group(): 31 | torch.manual_seed(0) 32 | if dist.is_initialized(): 33 | return 34 | 35 | dist.init_process_group("nccl") 36 | 37 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 38 | 39 | if not local_rank == torch.cuda.current_device(): 40 | torch.cuda.set_device(local_rank) 41 | 42 | 43 | def init_fairscale_mpu(tp_size, pp_size): 44 | if not dist.is_initialized(): 45 | init_process_group() 46 | 47 | mpu.initialize_model_parallel( 48 | model_parallel_size_=tp_size, 49 | pipeline_length=pp_size, 50 | ) 51 | 52 | 53 | def destroy_fairscale_mpu(): 54 | mpu.destroy_model_parallel() 55 | dist.barrier() 56 | 57 | 58 | def destroy_process_group(): 59 | dist.destroy_process_group() 60 | 61 | 62 | def all_gather(tensor: Tensor, dim: int = 0, pg=None): 63 | world_size = dist.get_world_size(group=pg) 64 | 65 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] 66 | 67 | dist.all_gather(tensor_list, tensor, group=pg) 68 | 69 | return torch.cat(tensor_list, dim=dim) 70 | 71 | 72 | class TestModel(nn.Module): 73 | def __init__(self, device="cuda", dtype=torch.float32): 74 | super().__init__() 75 | self.fc1 = nn.Linear(10, 20, device=device, dtype=dtype) 76 | self.fc2 = nn.Linear(20, 10, device=device, dtype=dtype) 77 | 78 | def forward(self, inputs): 79 | return self.fc2(self.fc1(inputs)) 80 | 81 | 82 | class TestFairscaleModel(nn.Module): 83 | def __init__( 84 | self, 85 | hidden_size, 86 | vocab_size, 87 | expansion=2, 88 | device="cuda", 89 | dtype=torch.float32, 90 | ): 91 | super().__init__() 92 | self.hidden_size = hidden_size 93 | self.vocab_size = vocab_size 94 | self.expansion = expansion 95 | self.device = device 96 | self.dtyp = dtype 97 | 98 | # Vocab parallel and regular embedding 99 | self.vocab_parallel_embedding = ( 100 | VocabParallelEmbedding( 101 | self.vocab_size, 102 | self.hidden_size, 103 | ) 104 | .to(device) 105 | .to(dtype) 106 | ) 107 | 108 | # Parallel embedding and regular embedding 109 | self.parallel_embedding = ( 110 | ParallelEmbedding( 111 | self.vocab_size, 112 | self.hidden_size, 113 | ) 114 | .to(device) 115 | .to(dtype) 116 | ) 117 | 118 | # Column parallel linear and regular linear 119 | self.column_parallel_linear = ( 120 | ColumnParallelLinear( 121 | 
self.hidden_size, 122 | int(self.hidden_size * self.expansion), 123 | bias=False, 124 | gather_output=False, 125 | ) 126 | .to(device) 127 | .to(dtype) 128 | ) 129 | 130 | # Row parallel linear and regular linear 131 | self.row_parallel_linear = ( 132 | RowParallelLinear( 133 | int(self.hidden_size * self.expansion), 134 | self.hidden_size, 135 | bias=False, 136 | input_is_parallel=True, 137 | ) 138 | .to(device) 139 | .to(dtype) 140 | ) 141 | 142 | def get_unsharded_params_and_grads( 143 | self 144 | ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: 145 | """Get full parameter and gradient state tensors for each module weight.""" 146 | tp_group = mpu.get_model_parallel_group() 147 | tp_world_size = mpu.get_model_parallel_world_size() 148 | 149 | layer_to_shard_dim_map = { 150 | VocabParallelEmbedding: 0, 151 | ColumnParallelLinear: 0, 152 | ParallelEmbedding: 1, # Doesn't support negative indexing. 153 | RowParallelLinear: 1, 154 | } 155 | 156 | params = {} 157 | grads = {} 158 | for name, module in self.named_modules(): 159 | if isinstance(module, TestFairscaleModel): 160 | continue 161 | 162 | sharded_dim = layer_to_shard_dim_map[type(module)] 163 | 164 | if tp_world_size > 1: 165 | unsharded_param = all_gather( 166 | module.weight, 167 | dim=sharded_dim, 168 | pg=tp_group, 169 | ) 170 | unsharded_grad = None 171 | if module.weight.grad is not None: 172 | unsharded_grad = all_gather( 173 | module.weight.grad, 174 | dim=sharded_dim, 175 | pg=tp_group, 176 | ) 177 | else: 178 | unsharded_param = module.weight 179 | unsharded_grad = module.weight.grad 180 | 181 | params[f"{name}.weight"] = unsharded_param 182 | grads[f"{name}.weight.grad"] = unsharded_grad 183 | 184 | return params, grads 185 | 186 | def copy_state_from_unsharded( 187 | self, other: nn.Module, other_tp_world_size: int = 1 188 | ): 189 | """Copy the parameter states from an unsharded version of this model.""" 190 | tp_rank = mpu.get_model_parallel_rank() 191 | tp_world_size = mpu.get_model_parallel_world_size() 192 | 193 | # Guarantee local slice of param exists. 194 | assert other_tp_world_size < tp_world_size 195 | assert tp_world_size % other_tp_world_size == 0 196 | 197 | # Helpers 198 | def _resharded_param_dim(param, dim=-1): 199 | full_dim = param.shape[dim] * other_tp_world_size 200 | return full_dim // tp_world_size 201 | 202 | def _make_param_slice(start, end, shape, dim): 203 | slices = [] 204 | for i in range(len(shape)): 205 | if i == dim: 206 | slices.append(slice(start, end)) 207 | else: 208 | slices.append(slice(0, shape[i])) 209 | return tuple(slices) 210 | 211 | # Reshard each parameter. 
212 | for (name, module), (other_name, other_module) in zip( 213 | self.named_modules(), other.named_modules() 214 | ): 215 | if isinstance(module, TestFairscaleModel): 216 | continue 217 | assert name == other_name 218 | assert type(module) == type(other_module) 219 | 220 | param = module.weight 221 | other_param = other_module.weight 222 | 223 | dim = ( 224 | 0 225 | if isinstance( 226 | module, (VocabParallelEmbedding, ColumnParallelLinear) 227 | ) 228 | else 1 229 | ) 230 | resharded_param_dim = _resharded_param_dim(other_param, dim=dim) 231 | start = ( 232 | tp_rank % other_tp_world_size + tp_rank 233 | ) * resharded_param_dim 234 | end = start + resharded_param_dim 235 | param_slice = _make_param_slice(start, end, param.shape, dim=dim) 236 | 237 | with torch.no_grad(): 238 | param.copy_(other_param[param_slice]) 239 | assert param.is_contiguous() 240 | 241 | def forward(self, inputs): 242 | embed_1 = self.vocab_parallel_embedding(inputs) 243 | embed_2 = self.parallel_embedding(inputs) 244 | embed = embed_1 + embed_2 245 | 246 | out_1 = self.column_parallel_linear(embed) 247 | out_2 = self.row_parallel_linear(out_1) 248 | 249 | return out_2 250 | 251 | 252 | def assert_same_state( 253 | self_states: Dict[str, Tensor], other_states: Dict[str, Tensor] 254 | ) -> None: 255 | """Check if self and other have the same parameter and gradient states.""" 256 | self_params, self_grads = self_states 257 | other_params, other_grads = other_states 258 | 259 | for (self_name, self_param), (other_name, other_param) in zip( 260 | self_params.items(), other_params.items() 261 | ): 262 | assert self_name == other_name 263 | torch.testing.assert_close(self_param, other_param) 264 | 265 | for (self_name, self_grad), (other_name, other_grad) in zip( 266 | self_grads.items(), other_grads.items() 267 | ): 268 | assert self_name == other_name 269 | torch.testing.assert_close(self_grad, other_grad) 270 | 271 | 272 | def wrap_ddp(base_model, pg=None): 273 | return DDP( 274 | base_model, 275 | process_group=pg, 276 | ) 277 | 278 | 279 | def wrap_fsdp(base_model, layer_to_wrap, pg=None): 280 | """Standard FSDP wrap in full-shard mode, CPU RAM efficient.""" 281 | # Initialize fsdp options. 282 | backward_prefetch = BackwardPrefetch.BACKWARD_PRE 283 | 284 | # Shard model parameters, optimizer, grads over all GPUs. 285 | sharding_strategy = ShardingStrategy.FULL_SHARD 286 | 287 | # Test everying in fp32 default. 288 | mixed_precision = MixedPrecision( 289 | param_dtype=None, 290 | reduce_dtype=None, 291 | buffer_dtype=None, 292 | cast_root_forward_inputs=True, 293 | ) 294 | 295 | # Don't offload to CPU. 296 | cpu_offload = CPUOffload(offload_params=False) 297 | 298 | transformer_auto_wrapper_policy = functools.partial( 299 | transformer_auto_wrap_policy, 300 | transformer_layer_cls={layer_to_wrap}, 301 | ) 302 | 303 | # Wrap model. 304 | model = FSDP( 305 | base_model, 306 | process_group=pg, # default pg. 
307 | sharding_strategy=sharding_strategy, 308 | cpu_offload=cpu_offload, 309 | auto_wrap_policy=transformer_auto_wrapper_policy, 310 | backward_prefetch=backward_prefetch, 311 | mixed_precision=mixed_precision, 312 | ignored_modules=None, 313 | param_init_fn=None, 314 | device_id=torch.cuda.current_device(), 315 | sync_module_states=True, 316 | forward_prefetch=True, 317 | limit_all_gathers=True, 318 | use_orig_params=False, 319 | ) 320 | return model 321 | 322 | 323 | def register_hook_functions( 324 | model, editing_function, hook_type, module_name_to_shape_map, module_prefix 325 | ): 326 | for name, expected_shape in module_name_to_shape_map.items(): 327 | register_fn = getattr(model, hook_type, None) 328 | assert register_fn is not None, "Reg. fn {hook_type} couldn't be used." 329 | 330 | full_name = module_prefix + name 331 | register_fn( 332 | HookFunction( 333 | full_name, 334 | expected_shape=expected_shape, 335 | editing_function=editing_function, 336 | ) 337 | ) 338 | -------------------------------------------------------------------------------- /_test/traverse/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/_test/traverse/__init__.py -------------------------------------------------------------------------------- /_test/traverse/test_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.modeling_outputs import BaseModelOutputWithPast 3 | 4 | from flex_model.traverse.nodes import BaseModelOutputWithPastNode 5 | 6 | 7 | def test_BaseModelOutputWithPastNode(): 8 | node = BaseModelOutputWithPastNode() 9 | obj = BaseModelOutputWithPast( 10 | last_hidden_state=torch.ones((1)), 11 | past_key_values=torch.ones((1)) * 2, 12 | hidden_states=torch.ones((1)) * 3, 13 | attentions=torch.ones((1)) * 4, 14 | ) 15 | 16 | contents = node.flatten(obj) 17 | for i, c in enumerate(contents): 18 | assert torch.equal(c, torch.ones((1)) * (i + 1)) 19 | 20 | new_obj = node.unflatten(contents) 21 | assert torch.equal(new_obj.last_hidden_state, obj.last_hidden_state) 22 | assert torch.equal(new_obj.past_key_values, obj.past_key_values) 23 | assert torch.equal(new_obj.hidden_states, obj.hidden_states) 24 | assert torch.equal(new_obj.attentions, obj.attentions) 25 | -------------------------------------------------------------------------------- /_test/traverse/test_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from flex_model.traverse.ops import flatten, unflatten 4 | 5 | 6 | def test_flatten_and_unflatten(): 7 | layer_output = [ 8 | 1, 9 | 2, 10 | torch.ones((1)).cuda(), 11 | "zzz", 12 | (torch.ones((1)).cuda() * 2, torch.ones((1)).cuda() * 3), 13 | ] 14 | treedef, leaves = flatten(layer_output) 15 | for i, leaf_ten in enumerate(leaves): 16 | assert torch.equal(leaf_ten, torch.ones((1)).cuda() * (i + 1)) 17 | 18 | edited_leaves = [leaf_ten * 2 for leaf_ten in leaves] 19 | 20 | result = unflatten(treedef, edited_leaves) 21 | new_treedef, new_leaves = flatten(result) 22 | assert new_treedef == treedef 23 | assert new_leaves == edited_leaves 24 | -------------------------------------------------------------------------------- /demos/README.md: -------------------------------------------------------------------------------- 1 | # FlexModel Demos 2 | **Note**: If your system does not have InfiniBand, please be sure to run 
`export NCCL_IB_DISABLE=1` prior to running the experiments. 3 | ## Induction Heads 4 | 5 | The file `induction_heads_multigpu.py` takes in a Llama-2 HuggingFace model path and outputs two plots which visualize the per token loss and a heatmap of the induction scores across all attention heads. 6 | 7 | Before running the file, you need to insert `DummyModule` from `flex_model.core` in order to store the activation maps. `DummyModule` is just an `nn.Module` that is an identity function and its purpose is to allow us to attach a `HookFunction` at any arbitrary point during a forward pass. An example of this is shown below: 8 | 9 | ```python 10 | from flex_model.core import DummyModule 11 | ... 12 | 13 | class LlamaAttention(nn.Module): 14 | def __init__(...): 15 | ... 16 | self.dummy = DummyModule() 17 | ... 18 | 19 | def forward(...): 20 | ... 21 | # attention maps are created once softmax is applied 22 | attn_weights = nn.functional.softmax(...) 23 | 24 | # pass through DummyModule so we can attach forward hooks 25 | attn_weights = self.dummy(attn_weights) 26 | ... 27 | ``` 28 | You may also need to change the module string literal under the `get_module_names` function. If you do not have safe tensors downloaded as part of your HF model, you will need to pass in `use_safetensors=False` as part of the model loading. 29 | 30 | To run the file, you can use `torchrun`. We have tested this demo by running Llama-2-70b-hf on 4x A100-80G. 31 | ``` 32 | torchrun --nnodes=1 --nproc-per-node=4 induction_heads_multigpu.py --model_path /path/to/llama 33 | ``` 34 | Example outputs for Llama-2-70b-hf: 35 | 36 | ![Token Loss](./example_plots/induction_loss.png) 37 | ![Token Loss](./example_plots/induction_score_by_head.png) -------------------------------------------------------------------------------- /demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/demos/__init__.py -------------------------------------------------------------------------------- /demos/example_plots/induction_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/demos/example_plots/induction_loss.png -------------------------------------------------------------------------------- /demos/example_plots/induction_score_by_head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/demos/example_plots/induction_score_by_head.png -------------------------------------------------------------------------------- /demos/induction_heads_multigpu.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import os 5 | from argparse import Namespace 6 | from functools import partial 7 | from typing import Any, Callable 8 | 9 | import einops 10 | import matplotlib.pyplot as plt 11 | import torch 12 | import torch.distributed as dist 13 | import torch.nn.functional as F 14 | from flex_model.core import FlexModel, HookFunction 15 | from torch import nn 16 | from torch.distributed.fsdp import ( 17 | FullyShardedDataParallel as FSDP, 18 | ) 19 | from torch.distributed.fsdp import ( 20 | MixedPrecision, 21 | ShardingStrategy, 22 | ) 23 | from 
torch.distributed.fsdp.wrap import ( 24 | transformer_auto_wrap_policy, 25 | ) 26 | from transformers import ( 27 | LlamaConfig, 28 | LlamaForCausalLM, 29 | ) 30 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 31 | 32 | 33 | def setup() -> None: 34 | """Instantiate process group.""" 35 | dist.init_process_group("nccl") 36 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 37 | torch.cuda.set_device(local_rank) 38 | return local_rank 39 | 40 | 41 | def cleanup() -> None: 42 | """Destroy process group.""" 43 | dist.destroy_process_group() 44 | 45 | 46 | def args() -> Namespace: 47 | """Parse command-line arguments.""" 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--model_path", type=str, required=True) 50 | parser.add_argument("--seq_length", default=50, type=int, required=False) 51 | return parser.parse_args() 52 | 53 | 54 | def setup_model(model_path: str, local_rank: int) -> \ 55 | tuple[nn.Module, LlamaConfig]: 56 | """Instantiate model, tokenizer, and config. 57 | 58 | Args: 59 | ---- 60 | model_path: A path to the model being instantiated 61 | local_rank: The local rank of the worker 62 | 63 | Returns: 64 | ------- 65 | A tuple of length two containing the model and the config. 66 | """ 67 | config = LlamaConfig.from_pretrained(model_path) 68 | if local_rank == 0: 69 | model = LlamaForCausalLM.from_pretrained( 70 | model_path, 71 | torch_dtype=torch.bfloat16, 72 | ) 73 | else: 74 | with torch.device("meta"): 75 | model = LlamaForCausalLM.from_pretrained( 76 | model_path, 77 | torch_dtype=torch.bfloat16, 78 | ) 79 | return model, config 80 | 81 | 82 | def fsdp_config(local_rank: int) -> dict[str:Any]: 83 | """Return the config to be used by FSDP. 84 | 85 | Args: 86 | ---- 87 | local_rank: The local rank of the worker 88 | 89 | Returns: 90 | ------- 91 | A dictionary containing keyword -> respective configuration. 92 | """ 93 | 94 | def _module_init_fn(module: nn.Module) -> Callable: 95 | """Return the function used for initializing modules on FSDP workers.""" 96 | return module.to_empty( 97 | device=torch.cuda.current_device(), 98 | recurse=False, 99 | ) 100 | 101 | auto_wrap_policy = partial( 102 | transformer_auto_wrap_policy, 103 | transformer_layer_cls={ 104 | LlamaDecoderLayer, 105 | }, 106 | ) 107 | sharding_strategy = ShardingStrategy.FULL_SHARD 108 | device_id = torch.cuda.current_device() 109 | sync_module_states = True 110 | param_init_fn = _module_init_fn if local_rank != 0 else None 111 | mp_policy = MixedPrecision( 112 | param_dtype=torch.bfloat16, 113 | buffer_dtype=torch.bfloat16, 114 | reduce_dtype=torch.bfloat16, 115 | ) 116 | 117 | config = { 118 | "auto_wrap_policy": auto_wrap_policy, 119 | "sharding_strategy": sharding_strategy, 120 | "device_id": device_id, 121 | "sync_module_states": sync_module_states, 122 | "param_init_fn": param_init_fn, 123 | "mixed_precision": mp_policy, 124 | } 125 | return config 126 | 127 | 128 | def calculate_induction_score( 129 | num_hidden_layers: int, 130 | num_attention_heads: int, 131 | activation_dict: dict[str, torch.Tensor], 132 | module_names: list[str], 133 | sequence_length: int, 134 | ) -> None: 135 | """Calculate and save a heatmap of the induction scores for each attention 136 | head. 
137 | 138 | Args: 139 | ---- 140 | num_hidden_layers: The number of transformer blocks in the model 141 | num_attention_heads: The number of attention heads in the model 142 | activation_dict: Dictionary containing the activations retrieved using 143 | FlexModel 144 | module_names: A list of the module names to which we have attached 145 | hooks 146 | sequence_length: The sequence length of the prompt passed into the 147 | model 148 | """ 149 | # Create the matrix to store the induction scores for each head across 150 | # all layers 151 | induction_score_store = torch.zeros( 152 | ( 153 | num_hidden_layers, 154 | num_attention_heads, 155 | ), 156 | device=torch.cuda.current_device(), 157 | ) 158 | 159 | for i, module_name in enumerate(module_names): 160 | # Retrieve the gathered activation maps for a given module 161 | attn_maps = ( 162 | activation_dict[module_name][0] 163 | .detach() 164 | .to( 165 | torch.cuda.current_device(), 166 | ) 167 | ) 168 | 169 | # Attention maps are of shape [batch, head, seq, seq] 170 | 171 | # We take the diagonal over the last two dims i.e. the query/key dims 172 | 173 | # We offset by 1-sequence_length because we want to see how much 174 | # attention is paid from the *current* token to the token that occurred 175 | # right after the *previous occurrence* of the *current* token (which 176 | # is 1-sequence_length tokens back). A better visualization can be 177 | # found on Anthropic's In-context Learning and Induction Heads paper 178 | induction_stripe = attn_maps.diagonal( 179 | dim1=-2, 180 | dim2=-1, 181 | offset=1 - sequence_length, 182 | ) 183 | 184 | # We average across the diagonal and the batch dims to get the final 185 | # induction scores 186 | induction_score = einops.reduce( 187 | induction_stripe, 188 | "batch head_index position -> head_index", 189 | "mean", 190 | ) 191 | induction_score_store[i, :] = induction_score 192 | 193 | plt.imshow(induction_score_store.detach().cpu().numpy(), origin="lower") 194 | plt.xlabel("Head") 195 | plt.ylabel("Layer") 196 | plt.title("Induction Score by Head") 197 | plt.colorbar() 198 | plt.savefig("induction_score_by_head.png", bbox_inches="tight") 199 | 200 | 201 | def get_module_names(num_hidden_layers: int) -> list[str]: 202 | """Return the list of module names to apply hooks onto. 203 | 204 | Args: 205 | ---- 206 | num_hidden_layers: The number of transformer blocks in the model 207 | 208 | Returns: 209 | ------- 210 | A list of model names that we're applying HookFunctions to 211 | """ 212 | prefix = "_fsdp_wrapped_module.model.layers." 213 | postfix = "._fsdp_wrapped_module.self_attn.dummy" 214 | module_names = [f"{prefix}{i}{postfix}" for i in range(num_hidden_layers)] 215 | return module_names 216 | 217 | 218 | def calculate_per_token_loss( 219 | logits: torch.Tensor, 220 | prompt: torch.Tensor, 221 | ) -> None: 222 | """Calculate and plot the cross-entropy loss per token. 223 | 224 | Args: 225 | ---- 226 | logits: The model's output logits 227 | prompt: The input prompt sequence 228 | """ 229 | # Calculate per token loss 230 | 231 | # First take log softmax across the vocab dim to get log probabilities 232 | log_probs = F.log_softmax(logits, dim=-1) 233 | 234 | # log_probs[..., :-1, :] takes the log probs up to the final token while 235 | # keeping the shape the same. 236 | 237 | # .gather(...) 
collects the correct log probs across the vocab dim given 238 | # the prompt 239 | 240 | # The reason we need prompt[..., 1:, None] is to ensure that the index 241 | # argument has the same rank as log_probs 242 | 243 | # Finally, we need [..., 0] at the end so that we get rid of the extra 244 | # trailing rank we created (we also could've done a .squeeze()) 245 | predicted_log_probs = -log_probs[..., :-1, :].gather( 246 | dim=-1, 247 | index=prompt[..., 1:, None], 248 | )[..., 0] 249 | 250 | # Average loss across the batch dimension 251 | loss_by_position = einops.reduce( 252 | predicted_log_probs, 253 | "batch position -> position", 254 | "mean", 255 | ) 256 | 257 | plt.plot( 258 | list(range(len(loss_by_position))), 259 | loss_by_position.detach().cpu().numpy(), 260 | ) 261 | plt.xlabel("Token Index") 262 | plt.ylabel("Loss") 263 | plt.title("Loss by position on random repeated tokens") 264 | plt.savefig("induction_loss.png", bbox_inches="tight") 265 | 266 | 267 | def main(args: Namespace) -> None: 268 | """Execute main demo. 269 | 270 | Args: 271 | ---- 272 | args: Command-line arguments 273 | """ 274 | local_rank = setup() 275 | 276 | seq_len = args.seq_length 277 | batch_size = 4 278 | min_vocab_idx, max_vocab_idx = 500, 15000 279 | 280 | prompt = torch.randint( 281 | min_vocab_idx, max_vocab_idx, (batch_size, seq_len), 282 | ).to( 283 | torch.cuda.current_device(), 284 | ) 285 | repeated_tokens = einops.repeat( 286 | prompt, 287 | "batch seq_len -> batch (2 seq_len)", 288 | ) 289 | 290 | model, config = setup_model(args.model_path, local_rank) 291 | fsdp_cfg = fsdp_config(local_rank) 292 | 293 | model = FSDP( 294 | model, 295 | **fsdp_cfg, 296 | ) 297 | 298 | # Wrap the model 299 | output_dict = {} 300 | model = FlexModel( 301 | model, 302 | output_dict, 303 | data_parallel_size=dist.get_world_size(), 304 | ) 305 | 306 | # Register hooks for activations 307 | module_names = get_module_names(config.num_hidden_layers) 308 | for module_name in module_names: 309 | model.register_forward_hook( 310 | HookFunction( 311 | module_name, 312 | (None, None, None, None), 313 | ), 314 | ) 315 | 316 | out = model(repeated_tokens).logits 317 | 318 | # Do plotting on main rank 319 | if dist.get_rank() == 0: 320 | calculate_induction_score( 321 | config.num_hidden_layers, 322 | config.num_attention_heads, 323 | output_dict, 324 | module_names, 325 | seq_len, 326 | ) 327 | plt.clf() 328 | 329 | # Note: we are only calculating this over the main rank's output 330 | # for the purpose of demonstration 331 | calculate_per_token_loss(out, repeated_tokens) 332 | cleanup() 333 | 334 | 335 | if __name__ == "__main__": 336 | parsed_args = args() 337 | main(parsed_args) 338 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import os 7 | import sys 8 | 9 | sys.path.insert(0, os.path.abspath("../../flex_model")) 10 | 11 | # -- Project information ----------------------------------------------------- 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 13 | 14 | project = "FlexModel" 15 | copyright = "2023, Matthew Choi" 16 | author = "Matthew Choi" 17 | release = "0.0.1" 18 | 19 | # -- General configuration --------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 21 | 22 | extensions = [ 23 | "sphinx.ext.autosummary", 24 | "sphinx.ext.autodoc", 25 | ] 26 | autosummary_generate = True 27 | autodoc_member_order = "bysource" 28 | 29 | templates_path = ["_templates"] 30 | exclude_patterns = [] 31 | 32 | 33 | # -- Options for HTML output ------------------------------------------------- 34 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 35 | 36 | html_theme = "sphinx_rtd_theme" 37 | html_static_path = ["_static"] 38 | -------------------------------------------------------------------------------- /docs/source/demo_links/induction.rst: -------------------------------------------------------------------------------- 1 | Induction Heads Demo 2 | -------------------- 3 | 4 | .. include:: ../../../demos/induction_heads_multigpu.py 5 | :code: python 6 | -------------------------------------------------------------------------------- /docs/source/demos.rst: -------------------------------------------------------------------------------- 1 | Demos 2 | ===== 3 | 4 | 1. :doc:`demo_links/induction` 5 | -------------------------------------------------------------------------------- /docs/source/example_links/fsdp_example.rst: -------------------------------------------------------------------------------- 1 | FSDP Example 2 | ---------------- 3 | 4 | .. 
include:: ../../../examples/fsdp_example.py 5 | :code: python 6 | -------------------------------------------------------------------------------- /docs/source/example_links/megatron_example.rst: -------------------------------------------------------------------------------- 1 | Megatron Example 2 | ---------------- 3 | 4 | .. include:: ../../../examples/megatron_example.py 5 | :code: python 6 | -------------------------------------------------------------------------------- /docs/source/example_links/single_gpu_example.rst: -------------------------------------------------------------------------------- 1 | Single-GPU Example 2 | ------------------ 3 | 4 | .. include:: ../../../examples/single_gpu_example.py 5 | :code: python 6 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | 1. :doc:`example_links/single_gpu_example` 5 | 6 | 2. :doc:`example_links/megatron_example` 7 | 8 | 3. :doc:`example_links/fsdp_example` 9 | -------------------------------------------------------------------------------- /docs/source/flex_model.core.rst: -------------------------------------------------------------------------------- 1 | Core: Wrapper and HookFunction 2 | ============================== 3 | 4 | .. currentmodule:: flex_model.core 5 | 6 | 7 | Modules 8 | ------- 9 | 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | 14 | FlexModel 15 | HookFunction 16 | 17 | .. autoclass:: FlexModel 18 | :members: 19 | 20 | .. autoclass:: HookFunction 21 | :members: 22 | 23 | 24 | Miscellaneous 25 | ------------- 26 | 27 | .. autoclass:: DummyModule 28 | -------------------------------------------------------------------------------- /docs/source/flex_model.distributed.rst: -------------------------------------------------------------------------------- 1 | Distributed: Backend, mappings and strategies 2 | ============================================= 3 | 4 | .. currentmodule:: flex_model.distributed 5 | 6 | 7 | Distributed API 8 | --------------- 9 | 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | 14 | initialize_distributed_state 15 | 16 | 17 | Mappings 18 | -------- 19 | 20 | .. autosummary:: 21 | :toctree: generated 22 | :nosignatures: 23 | 24 | broadcast_tensor_parallel 25 | broadcast_data_parallel 26 | all_gather_tensor_parallel 27 | all_gather_data_parallel 28 | scatter_tensor_parallel 29 | scatter_data_parallel 30 | gather_pipeline_parallel_tensor_dicts 31 | 32 | 33 | Strategies 34 | ---------- 35 | 36 | .. autosummary:: 37 | :toctree: generated 38 | :nosignatures: 39 | 40 | BaseRoutingStrategy 41 | ParameterTensorParallelRoutingStrategy 42 | ActivationTensorAllToAllRoutingStrategy 43 | BaseOffloadStrategy 44 | NullMemoryOffloadStrategy 45 | CPUPinnedMemoryOffloadStrategy 46 | CPUPagedMemoryOffloadStrategy 47 | GPUMemoryOffloadStrategy 48 | BaseFunctionStrategy 49 | NonValidatedFunctionStrategy 50 | -------------------------------------------------------------------------------- /docs/source/flex_model.rst: -------------------------------------------------------------------------------- 1 | flex\_model package 2 | =================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | :maxdepth: 4 9 | 10 | flex_model.core 11 | flex_model.distributed 12 | flex_model.traverse 13 | 14 | Submodules 15 | ---------- 16 | 17 | flex\_model.utils module 18 | ------------------------ 19 | 20 | .. automodule:: flex_model.utils 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | Module contents 26 | --------------- 27 | 28 | .. automodule:: flex_model 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | -------------------------------------------------------------------------------- /docs/source/flex_model.traverse.rst: -------------------------------------------------------------------------------- 1 | Traverse: Pytree traversal utility 2 | ================================== 3 | 4 | .. currentmodule:: flex_model.traverse 5 | 6 | 7 | Ops 8 | --- 9 | 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | 14 | flatten 15 | unflatten 16 | 17 | 18 | Nodes 19 | ----- 20 | 21 | .. autosummary:: 22 | :toctree: generated 23 | :nosignatures: 24 | 25 | InternalObject 26 | LeafObject 27 | ScalarObject 28 | InternalNode 29 | LeafNode 30 | ScalarNode 31 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. FlexModel documentation master file, created by 2 | sphinx-quickstart on Wed Aug 30 14:56:13 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to FlexModel's documentation! 7 | ===================================== 8 | 9 | FlexModel is a wrapper for Pytorch models which exposes powerful primitives 10 | for model surgery and introspection. 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: Getting Started 15 | 16 | intro 17 | examples 18 | demos 19 | 20 | Check-out our examples for single and multi-gpu, which use both megatron-style 21 | layers and PyTorch's FSDP wrapper. Additionally, we have a demo which includes 22 | code for induction head identification in Llama-2-70b. 23 | 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Key Python API: 28 | 29 | flex_model.core 30 | flex_model.distributed 31 | flex_model.traverse 32 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | ``FlexModel`` is a simple-to-use and robust wrapper for model surgery and 4 | introspection. It provides a few powerful primitives for activation retrieval, 5 | activation editing and auxillary module training. Users can define their own 6 | ``HookFunction`` instances, which provide a single-threaded runtime for 7 | manipulating layer activations. The entire wrapper and hook runtime can be 8 | used under arbitrary distributed model topologies. 9 | 10 | 11 | Motivation 12 | ********** 13 | Repositories for mechanistic interpretability have very built-out feature sets, 14 | however they do not provide many utilities for using models that are 15 | distributed in potentially complicated ways. This framework is intended to 16 | provide these same utilities in a way which is robust to potentially 17 | complicated distributed strategies. 18 | 19 | 20 | Limitations 21 | *********** 22 | Being a new framework, we currently lack many ease-of-use and high-level 23 | features that more mature frameworks include. 
However, we hope that the 24 | well-tested primitives currently exposed are powerful enough where these 25 | features will be simple to implement. 26 | 27 | 28 | Additionally, there is only support for up to 3-D distributed parallelism. 29 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | flex_model 2 | ========== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | flex_model 8 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # `FlexModel` Usage Examples 2 | This folder contains simple use-cases of `FlexModel` across a variety of distributed backends. 3 | 4 | # How to Run 5 | Each example script has instructions on how to run it at the top of the script. 6 | Most examples can typically be run using PyTorch's `torchrun` command. These scripts also require access to model and tokenizer checkpoint files. 7 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/examples/__init__.py -------------------------------------------------------------------------------- /examples/fsdp_example.py: -------------------------------------------------------------------------------- 1 | """Runs Llama-2-13B on 2 GPUs using PyTorch's FSDP wrapper. This script 2 | demonstrates basic usage of the `FlexModel` wrapper with a generic 3 | `HookFunction`. 4 | 5 | Running: 6 | 7 | torchrun --nodes 1 --nproc_per_node 2 fsdp_example.py 8 | """ 9 | import argparse 10 | import functools 11 | import os 12 | from typing import Dict, List 13 | 14 | import torch 15 | import torch.distributed as dist 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from torch.distributed.fsdp import BackwardPrefetch, CPUOffload 19 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 20 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 21 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 22 | from transformers import LlamaForCausalLM, LlamaTokenizerFast 23 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 24 | 25 | from flex_model.core import FlexModel, HookFunction 26 | from flex_model.utils import setup_logger 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--log_level", type=str, default="debug") 32 | parser.add_argument( 33 | "--checkpoint_dir", type=str, default="/model-weights/Llama-2-13b-hf" 34 | ) 35 | parser.add_argument( 36 | "--tokenizer_dir", type=str, default="/model-weights/Llama-2-13b-hf" 37 | ) 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def get_llama2_tokenizer(tokenizer_dir): 43 | tokenizer = LlamaTokenizerFast.from_pretrained( 44 | tokenizer_dir, 45 | local_files_only=True, 46 | ) 47 | tokenizer.model_max_length = 512 48 | 49 | # Llama-2 has no PAD token, substitute the EOS token. 
50 | tokenizer.pad_token = tokenizer.eos_token 51 | 52 | return tokenizer 53 | 54 | 55 | def make_llama2_fsdp(checkpoint_dir): 56 | # Load llama-2 model and prepare it for FSDP (CPU RAM-efficient) 57 | if dist.get_rank() == 0: 58 | base_model = LlamaForCausalLM.from_pretrained( 59 | checkpoint_dir, 60 | local_files_only=True, 61 | torch_dtype=torch.bfloat16, 62 | ) 63 | param_init_fn = None 64 | else: 65 | with torch.device("meta"): 66 | base_model = LlamaForCausalLM.from_pretrained( 67 | checkpoint_dir, 68 | local_files_only=True, 69 | torch_dtype=torch.bfloat16, 70 | ) 71 | 72 | def _param_init_fn(module: nn.Module): 73 | module = module.to_empty( 74 | device=torch.cuda.current_device(), recurse=False 75 | ) 76 | return module 77 | 78 | param_init_fn = _param_init_fn 79 | 80 | # Initialize fsdp options. 81 | backward_prefetch = BackwardPrefetch.BACKWARD_PRE 82 | 83 | # Shard model parameters, optimizer, grads over all GPUs. 84 | sharding_strategy = ShardingStrategy.FULL_SHARD 85 | 86 | mixed_precision = MixedPrecision( 87 | param_dtype=torch.bfloat16, 88 | reduce_dtype=torch.bfloat16, 89 | buffer_dtype=torch.bfloat16, 90 | cast_root_forward_inputs=True, 91 | ) 92 | 93 | # Don't offload to CPU. 94 | cpu_offload = CPUOffload(offload_params=False) 95 | 96 | transformer_auto_wrapper_policy = functools.partial( 97 | transformer_auto_wrap_policy, 98 | transformer_layer_cls={LlamaDecoderLayer}, 99 | ) 100 | 101 | # Wrap model. 102 | model = FSDP( 103 | base_model, 104 | process_group=None, # default pg. 105 | sharding_strategy=sharding_strategy, 106 | cpu_offload=cpu_offload, 107 | auto_wrap_policy=transformer_auto_wrapper_policy, 108 | backward_prefetch=backward_prefetch, 109 | mixed_precision=mixed_precision, 110 | ignored_modules=None, 111 | param_init_fn=param_init_fn, 112 | device_id=torch.cuda.current_device(), 113 | sync_module_states=True, 114 | forward_prefetch=True, 115 | limit_all_gathers=True, 116 | use_orig_params=False, 117 | ) 118 | 119 | return model 120 | 121 | 122 | def init_dist(): 123 | dist.init_process_group("nccl") 124 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 125 | torch.cuda.set_device(local_rank) 126 | 127 | 128 | def main(args): 129 | """Forward pass of FSDP-wrapped llama-2-13b-hf model retrieving activations. 130 | 131 | This script must be run via Huggingface Accelerate FSDP. Retrieves 132 | activations over all DP-workers by gathering them in the batch dimension. 
133 | """ 134 | setup_logger("debug") 135 | 136 | init_dist() 137 | 138 | rank = dist.get_rank() 139 | world_size = dist.get_world_size() 140 | 141 | prompts = [ 142 | "It's a nice day we're having", 143 | "The capital of Canada is", 144 | "What should I eat for dinner tonight?", 145 | "There's about three people going to", 146 | ] 147 | 148 | model = make_llama2_fsdp(args.checkpoint_dir) 149 | 150 | # Load tokenizer 151 | tokenizer = get_llama2_tokenizer(args.tokenizer_dir) 152 | 153 | # Define output to dump activations to 154 | activation_dict: Dict[str, List[Tensor]] = {} 155 | 156 | # Wrap model in FlexModel 157 | model = FlexModel( 158 | model, 159 | activation_dict, 160 | data_parallel_size=world_size, 161 | ) 162 | 163 | # Create a hook function 164 | module_name = ( 165 | "_fsdp_wrapped_module.model.layers.30._fsdp_wrapped_module.mlp" 166 | ) 167 | hook_function = HookFunction( 168 | module_name=module_name, 169 | expected_shape=(None, None, None), 170 | editing_function=None, 171 | ) 172 | 173 | # Register hook function with the model 174 | model.register_forward_hook(hook_function) 175 | 176 | # Tokenize a prompt 177 | inputs = tokenizer(prompts, padding="max_length", return_tensors="pt")[ 178 | "input_ids" 179 | ] 180 | 181 | # Split the batch across dp workers 182 | dp_worker_inputs = inputs.chunk(world_size, dim=0)[rank] 183 | 184 | # Run through model to generate logits and activations 185 | _outputs = model(dp_worker_inputs) 186 | 187 | # Activations are only dumped to main process 188 | if rank == 0: 189 | activation = activation_dict[module_name][0] 190 | print(f"Activation shape: {activation.shape}") 191 | print(activation) 192 | 193 | assert activation.shape[0] == 4 194 | assert activation.shape[-1] == 5120 195 | 196 | 197 | if __name__ == "__main__": 198 | args = parse_args() 199 | main(args) 200 | -------------------------------------------------------------------------------- /examples/megatron_example.py: -------------------------------------------------------------------------------- 1 | """Runs Llama-2-13B on 2 GPUs using Fairscale's implementation of Megatron-LM 2 | layers. This script demonstrates basic usage of `FlexModel` with a generic 3 | `HookFunction`. 
4 | 5 | Running: 6 | 7 | torchrun --nnodes 1 --nproc_per_node 2 megatron_example.py 8 | 9 | """ 10 | import argparse 11 | from typing import Dict, List 12 | 13 | import torch 14 | from llama import Llama 15 | from torch import Tensor 16 | 17 | from flex_model.core import FlexModel, HookFunction 18 | from flex_model.utils import setup_logger 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--log_level", type=str, default="debug") 24 | parser.add_argument( 25 | "--checkpoint_dir", type=str, default="/model-weights/Llama-2-13b" 26 | ) 27 | parser.add_argument( 28 | "--tokenizer_dir", 29 | type=str, 30 | default="/model-weights/Llama-2-13b/tokenizer.model", 31 | ) 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def main(args): 37 | """Forward pass through llama-2-13b which uses megatron for TP, PP, and DP.""" 38 | setup_logger(args.log_level) 39 | 40 | prompts = [ 41 | "It's a nice day we're having", 42 | "The capital of Canada is", 43 | "What should I eat for dinner tonight?", 44 | "There's about three people going to", 45 | ] 46 | # Load llama-2 using megatron layers 47 | generator = Llama.build( 48 | ckpt_dir=args.checkpoint_dir, 49 | tokenizer_path=args.tokenizer_dir, 50 | max_seq_len=512, 51 | max_batch_size=32, 52 | ) 53 | model = generator.model 54 | 55 | # Define tokenizer function 56 | def tokenize_fn(prompts): 57 | input_tokens = [ 58 | generator.tokenizer.encode(x, bos=True, eos=False) for x in prompts 59 | ] 60 | bsz = len(input_tokens) 61 | total_len = max(len(t) for t in input_tokens) 62 | pad_id = 0 63 | tokens = torch.full( 64 | (bsz, total_len), pad_id, dtype=torch.long, device="cuda" 65 | ) 66 | for k, t in enumerate(input_tokens): 67 | tokens[k, : len(t)] = torch.tensor( 68 | t, dtype=torch.long, device="cuda" 69 | ) 70 | return tokens 71 | 72 | # Define output to dump activations to 73 | activation_dict: Dict[str, List[Tensor]] = {} 74 | 75 | # Wrap model in FlexModel (llama-2-13b requires tensor parallel size 2) 76 | model = FlexModel( 77 | model, 78 | activation_dict, 79 | tensor_parallel_size=2, 80 | ) 81 | 82 | # Create a hook function 83 | module_name = "layers.28.feed_forward.w3" 84 | hook_function = HookFunction( 85 | module_name=module_name, 86 | expected_shape=(None, None, 13824), 87 | editing_function=None, 88 | ) 89 | 90 | # Register hook function with the model 91 | model.register_forward_hook(hook_function) 92 | 93 | # Tokenize a prompt 94 | inputs = tokenize_fn(prompts) 95 | 96 | # Run through model to generate logits and activations 97 | _outputs = model(inputs, start_pos=0) 98 | 99 | # Activations are only dumped to main process. Activations per-module key 100 | # are accumulated in a list. 101 | if torch.distributed.get_rank() == 0: 102 | activation = activation_dict[module_name][0] 103 | print(f"Activation shape: {activation.shape}") 104 | print(activation) 105 | 106 | assert activation.shape[0] == 4 107 | assert activation.shape[-1] == 13824 108 | 109 | 110 | if __name__ == "__main__": 111 | args = parse_args() 112 | main(args) 113 | -------------------------------------------------------------------------------- /examples/single_gpu_example.py: -------------------------------------------------------------------------------- 1 | """Runs Llama-2-13B on a single GPU using Huggingface Transformers. This script 2 | demonstrates basic usage of the `FlexModel` wrapper with a generic 3 | `HookFunction`. 
4 | 5 | Running: 6 | 7 | python single_gpu_example.py 8 | """ 9 | import argparse 10 | from typing import Dict, List 11 | 12 | from torch import Tensor 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | from flex_model.core import FlexModel, HookFunction 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--checkpoint_dir", type=str, default="/model-weights/Llama-2-7b-hf" 22 | ) 23 | parser.add_argument( 24 | "--tokenizer_dir", type=str, default="/model-weights/Llama-2-7b-hf" 25 | ) 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def main(args): 31 | """Single forward pass of llama-2-13b-hf model retrieving a single activation. 32 | 33 | This script runs on a single GPU only. 34 | """ 35 | # Load llama-2-13b-hf model 36 | model = AutoModelForCausalLM.from_pretrained( 37 | args.checkpoint_dir, 38 | local_files_only=True, 39 | torch_dtype="auto", 40 | device_map="auto", 41 | ) 42 | 43 | # Load tokenizer 44 | tokenizer = AutoTokenizer.from_pretrained( 45 | args.tokenizer_dir, 46 | local_files_only=True, 47 | ) 48 | 49 | ## NEW ## 50 | # Define output to dump activations to 51 | activation_dict: Dict[str, List[Tensor]] = {} 52 | 53 | # Wrap model in FlexModel 54 | model = FlexModel(model, activation_dict) 55 | 56 | # Create a hook function 57 | hook_function = HookFunction( 58 | module_name="model.layers.24", 59 | expected_shape=(None, None, None), # Not sharded, can pass None per dim 60 | editing_function=None, # Just doing retrieval 61 | ) 62 | 63 | # Register hook function with the model 64 | model.register_forward_hook(hook_function) 65 | ## NEW ## 66 | 67 | # Tokenize a prompt 68 | inputs = tokenizer( 69 | "Where is the best spot for lunch?", return_tensors="pt" 70 | )["input_ids"] 71 | 72 | # Run through model to generate logits and activations 73 | _outputs = model(inputs) 74 | 75 | print(activation_dict) 76 | 77 | 78 | if __name__ == "__main__": 79 | args = parse_args() 80 | main(args) 81 | -------------------------------------------------------------------------------- /flex_model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import core, distributed, traverse 2 | -------------------------------------------------------------------------------- /flex_model/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core_utils import DummyModule 2 | from .hook_function import HookFunction 3 | from .wrapper import FlexModel 4 | -------------------------------------------------------------------------------- /flex_model/core/core_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import torch.nn as nn 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class DummyModule(nn.Module): 10 | """Identity module used to expose activations. 11 | 12 | Can be placed in any :code:`nn.Module` to artificially create an activation to 13 | be hooked onto. For instance, explicitly calling a module's :code:`.forward()` 14 | method will not run forward hooks and therefore will not generate an 15 | activation. However, applying this module to the output of that will 16 | generate an activation which can be hooked onto. 
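    Example (editor's sketch; :code:`MyBlock` and its submodules are
    hypothetical):

    .. highlight:: python
    .. code-block:: python

        import torch.nn as nn

        class MyBlock(nn.Module):
            def __init__(self):
                super().__init__()
                self.proj = nn.Linear(16, 16)
                self.expose = DummyModule()  # Hookable identity.

            def forward(self, x):
                # Calling .forward() directly bypasses forward hooks, so pass
                # the result through DummyModule to expose it as an activation
                # that a HookFunction can attach to.
                h = self.proj.forward(x)
                return self.expose(h)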
17 | """ 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def forward(self, inputs: Any) -> Any: 23 | return inputs 24 | -------------------------------------------------------------------------------- /flex_model/core/hook_function.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import Namespace 3 | from functools import partial 4 | from typing import Any, Callable, Optional, Tuple, Union 5 | 6 | import torch.nn as nn 7 | from torch import Tensor 8 | 9 | import flex_model.distributed as fm_dist 10 | from flex_model.traverse import ( 11 | InternalObject, 12 | LeafObject, 13 | ScalarObject, 14 | flatten, 15 | unflatten, 16 | ) 17 | 18 | LayerOutputs = Union[InternalObject, LeafObject, ScalarObject] 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def default_editing_function( 23 | current_module: nn.Module, 24 | inputs: Tensor, 25 | save_ctx: Namespace, 26 | modules: nn.ModuleDict, 27 | ) -> Tensor: 28 | """No-op editing function for logging and debug purposes. 29 | 30 | :note: This editing function showcases the expected function signature for 31 | custom editing functions. 32 | 33 | :note: If no editing function is provided for a :code:`HookFunction`, then 34 | this is the default editing function. 35 | 36 | :param nn.Module current_module: Submodule instance hooked into. 37 | :param Tensor inputs: Activation tensor produced during the forward pass of 38 | the :code:`current_module`. 39 | :param Namespace save_ctx: Save context pointer where cached data can be 40 | accessed or stored. 41 | :param nn.ModuleDict: Pointer to trainable modules globally exposed to all 42 | :class:`HookFunction` instances. 43 | 44 | :returns: Edited (or not) activation tensor 45 | :rtype: Tensor 46 | """ 47 | logger.debug(f"Running default editing function on tensor: {inputs.shape}") 48 | return inputs 49 | 50 | 51 | class HookFunction: 52 | """Function which retrieves/edits activations in a Pytorch `nn.Module`. 53 | 54 | The user provides the :code:`module_name` of the target submodule. The user 55 | can optionally pass in an :code:`editing_function` containing arbitrarily complex 56 | python code, which will be used to edit the full submodule activation 57 | tensor. If certain dimensions of the activation tensor are expected to be 58 | sharded over distributed workers, the user must also provide an 59 | :code:`expected_shape` hint so the activation tensor can be assembled. 60 | 61 | :var str module_name: Name of the :code:`nn.Module` submodule to hook into. 62 | :var expected_shape: Shape of the full activation tensor. Only the 63 | dimensions which are sharded need to be provided. Other dimensions 64 | can be annotated as :code:`None` and will be auto-completed. 65 | :type expected_shape: Optional[Tuple[Optional[int], ...]] 66 | :var editing_function: Function which is run on the full activation tensor 67 | and returns some edited function. Global contexts like the 68 | save context and trainable modules are available for use in the 69 | editing function runtime. 70 | :type editing_function: Optional[Callable] 71 | :var save_ctx: Global save context that is exposed to the 72 | :code:`editing_function`. 73 | :type save_ctx: Optional[Namespace] 74 | :var modules: Global trainable modules that are exposed to the 75 | :code:`editing_function`. 
76 | :type modules: Optional[nn.ModuleDict] 77 | 78 | :note: :code:`save_ctx` and :code:`modules` are populated when the :class:`HookFunction` 79 | is registered with a :class:`FlexModel` instance. 80 | 81 | Example: 82 | 83 | .. highlight:: python 84 | .. code-block:: python 85 | 86 | # Define editing function to be run on an activation tensor. 87 | def my_editing_function(current_module, 88 | inputs, 89 | save_ctx, 90 | modules) -> Tensor: 91 | 92 | # Cache data for later. 93 | _, s, _ = torch.svd(inputs) 94 | save_ctx.activation_singular_values = s 95 | 96 | # Edit the activation tensor. 97 | inputs = torch.where(inputs > 1.0, inputs, 0.0) 98 | 99 | # Apply a torch layer to the activation tensor. 100 | outputs = modules["linear_projection"](inputs) 101 | 102 | # Pass edited activation tensor to next layer. 103 | return outputs 104 | 105 | # Instantiate registration-ready hook function. 106 | my_hook_function = HookFunction( 107 | "my_model.layers.16.self_attention", 108 | expected_shape=(4, 512, 5120), 109 | editing_function=my_editing_function, 110 | ) 111 | """ 112 | 113 | def __init__( 114 | self, 115 | module_name: str, 116 | expected_shape: Optional[Tuple[Optional[int], ...]] = None, 117 | editing_function: Optional[Callable] = None, 118 | unpack_idx: int = 0, 119 | ) -> None: 120 | """Initializes the instance by wrapping the :code:`editing_function`. 121 | 122 | :param str module_name: Name of the :code:`nn.Module` submodule to hook 123 | into. 124 | :param expected_shape: Shape of the full activation tensor. 125 | :type expected_shape: Optional[Tuple[Optional[int], ...]] 126 | :param editing_function: Function which edits the activation 127 | tensor. 128 | :type editing_function: Optional[Callable] 129 | :param str hook_type: Type of hook to register, eg. forward, backward, 130 | etc. 131 | :param int unpack_idx: Index of the tensor in the unpacked layer output 132 | list. When layer outputs are pre-processed before editing function 133 | execution, valid `torch.Tensor` objects are extracted into a list 134 | by recursive unpacking. Hence the `unpack_idx` parameter allows 135 | for specification of which tensor to consider the activation 136 | tensor for downstream processing in the `HookFunction`. 137 | """ 138 | # User-provided state. 139 | self.module_name = module_name 140 | self.expected_shape = expected_shape 141 | # TODO (mchoi): If editing function not passed (ie. just doing 142 | # retrieval), then we can fire async collectives instead 143 | # since there's no data dependency. 144 | if editing_function is None: 145 | self.editing_function = default_editing_function 146 | else: 147 | self.editing_function = editing_function 148 | self.unpack_idx = unpack_idx 149 | 150 | # FM instance registry-provided state and other runtime state. 151 | self._shared_state = None 152 | self._hook_type = None 153 | self.module = None # Safe to cache this state, never changes. 154 | 155 | # Default strategies, initialized once at first runtime. 156 | # TODO: These should be bound when self is registered to a 157 | # FlexModel instance. 158 | self.routing_strategy = fm_dist.ActivationTensorAllToAllRoutingStrategy 159 | self.offload_strategy = fm_dist.CPUPinnedMemoryOffloadStrategy 160 | self.function_strategy = fm_dist.NonValidatedFunctionStrategy 161 | 162 | # Valid hook function implementations. 
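        # (Editor's note.) Each key below names a PyTorch hook registration
        # point; the mapped method adapts that hook's signature onto the
        # shared _apply() path. The active key lives in self._hook_type,
        # which is expected to be set by the owning FlexModel instance when
        # this HookFunction is registered.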
163 | self.hook_type_to_impl_fn = { 164 | "forward": self._forward_hook_impl, 165 | "full_backward": self._full_backward_hook_impl, 166 | "tensor": self._tensor_hook_impl, 167 | "forward_pre": self._forward_pre_hook_impl, 168 | "full_backward_pre": self._full_backward_pre_hook_impl, 169 | } 170 | 171 | def _forward_hook_impl( 172 | self, 173 | module: nn.Module, 174 | _inputs: Union[LayerOutputs, Tensor], 175 | outputs: Union[LayerOutputs, Tensor], 176 | ) -> Union[LayerOutputs, Tensor]: 177 | """Runs a hook function for editing forward module outputs.""" 178 | if self.module is None: 179 | self.module = module 180 | outputs = self._peel_and_apply(outputs) 181 | return outputs 182 | 183 | def _full_backward_hook_impl( 184 | self, 185 | module: nn.Module, 186 | grad_inputs: Union[LayerOutputs, Tensor], 187 | _grad_outputs: Union[LayerOutputs, Tensor], 188 | ) -> Union[LayerOutputs, Tensor]: 189 | """Runs a hook function for editing backward module input gradients.""" 190 | if self.module is None: 191 | self.module = module 192 | outputs = self._peel_and_apply(grad_inputs) 193 | return outputs 194 | 195 | def _tensor_hook_impl( 196 | self, 197 | grad: Tensor, 198 | ) -> Tensor: 199 | """Runs a hook function for editing tensor gradients.""" 200 | # No module since this is tensor-level. 201 | outputs = self._apply(grad) 202 | return outputs 203 | 204 | def _forward_pre_hook_impl( 205 | self, 206 | module: nn.Module, 207 | args: Union[LayerOutputs, Tensor], 208 | ) -> Union[LayerOutputs, Tensor]: 209 | """Runs a hook function for editing forward module inputs.""" 210 | if self.module is None: 211 | self.module = module 212 | outputs = self._peel_and_apply(args) 213 | return outputs 214 | 215 | def _full_backward_pre_hook_impl( 216 | self, 217 | module: nn.Module, 218 | grad_outputs: Union[LayerOutputs, Tensor], 219 | ) -> Union[LayerOutputs, Tensor]: 220 | """Runs a hook function for editing backward module output gradients.""" 221 | if self.module is None: 222 | self.module = module 223 | outputs = self._peel_and_apply(grad_outputs) 224 | return outputs 225 | 226 | def _apply(self, tensor: Optional[Tensor]) -> Tensor: 227 | """Template function for editing a sharded activation tensor. 228 | 229 | This function is used alone in cases where hook functions operate 230 | directly on a tensor, and not an entire module. 231 | """ 232 | # Runtime initialization of strategies. 233 | # TODO: Only routing strategies need to be init at first iteration. 234 | if not isinstance(self.routing_strategy, fm_dist.BaseRoutingStrategy): 235 | self.routing_strategy = self.routing_strategy.initialize( 236 | self._shared_state.fmps, 237 | tensor, 238 | self.expected_shape, 239 | ) 240 | self.offload_strategy = self.offload_strategy.initialize( 241 | self.module_name, self._shared_state.output_ptr 242 | ) 243 | self.function_strategy = self.function_strategy.initialize( 244 | self.editing_function 245 | ) 246 | 247 | if tensor is None: 248 | return 249 | 250 | start_shape = tensor.shape 251 | tensor = self.routing_strategy.execute_prologue(tensor) 252 | 253 | # Need to pre-divide activation grads by dp world size, see: 254 | # https://yi-wang-2005.medium.com/pytorch-distributeddataparallel-internals-c01c30a41192. 
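        # (Editor's note.) The division below is mirrored by the
        # multiplication after the editing function runs, so with the default
        # no-op editing function the gradient handed back to autograd is
        # numerically unchanged; only the offloaded copy (and the editing
        # function's view of the tensor) sees the 1/world_size scaling.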
255 | if self._hook_type in ["full_backward", "full_backward_pre"]: 256 | tensor = ( 257 | tensor / self._shared_state.fmps.get_data_parallel_world_size() 258 | ) 259 | 260 | self.offload_strategy.execute(tensor) 261 | tensor = self.function_strategy.execute( 262 | self.module, 263 | tensor, 264 | self._shared_state.save_ctx, 265 | self._shared_state.modules, 266 | ) 267 | 268 | if self._hook_type in ["full_backward", "full_backward_pre"]: 269 | tensor = ( 270 | tensor * self._shared_state.fmps.get_data_parallel_world_size() 271 | ) 272 | 273 | tensor = self.routing_strategy.execute_epilogue(tensor) 274 | end_shape = tensor.shape 275 | 276 | assert start_shape == end_shape, ( 277 | f"Input tensor and output tensor shape mismatch: {start_shape} -> " 278 | f"{end_shape}. The tensor returned by the editing function must " 279 | f"not change in shape at the output." 280 | ) 281 | 282 | return tensor 283 | 284 | def _unpack_layer_outputs( 285 | self, 286 | outputs: Union[LayerOutputs, Tensor], 287 | ) -> Tuple[Tensor, partial]: 288 | """Converts layer output object into an activation tensor and def. 289 | 290 | The output of model layers can be arbitrary python objects, so this 291 | function unpacks this object and separates out Pytorch tensors using 292 | the :code:`FlexModel.traverse` library. Outputs are sorted into :code:`treedef` 293 | and :code:`leaves`. The :code:`treedef` define the structure of the object, and 294 | the :code:`leaves` correspond to a list of the found tensors. When the 295 | activation tensor needs to be sent to the next layer at the end of 296 | the :class:`HookFunction` execution, the returned :code:`_repack` function 297 | reconstructs the layer output. 298 | 299 | :param outputs: The current module's layer outputs. 300 | :type outputs: Union[LayerOutputs, Tensor] 301 | 302 | :returns: The (potentially sharded) activation 303 | tensor and a function to undo the unpacking operation. 304 | :rtype: Tuple[Tensor, partial] 305 | 306 | :raises AssertionError: Occurs if no tensor is found at all in the 307 | layer outputs. 308 | """ 309 | treedef, leaves = flatten(outputs) 310 | 311 | if len(leaves) == 0: 312 | logger.debug( 313 | "Unpacked tensor is None, nothing to operate on " 314 | "(input activation grad is likely None)" 315 | ) 316 | left_leaves, tensor, right_leaves = [], None, [] 317 | else: 318 | left_leaves, tensor, right_leaves = ( 319 | leaves[: self.unpack_idx], 320 | leaves[self.unpack_idx], 321 | leaves[self.unpack_idx + 1 :], 322 | ) 323 | 324 | def _repack(_edited_tensor) -> LayerOutputs: 325 | """Pack activation tensor back into layer output container.""" 326 | layer_outputs = unflatten( 327 | treedef, 328 | left_leaves + [_edited_tensor] + right_leaves, 329 | ) 330 | return layer_outputs 331 | 332 | return tensor, _repack 333 | 334 | def _peel_and_apply( 335 | self, 336 | inputs_or_outputs: Union[LayerOutputs, Tensor], 337 | ) -> Union[LayerOutputs, Tensor]: 338 | """Template function for editing layer input or output activation tensors. 339 | 340 | Given arbitary layer outputs, this function does unpacking of layer 341 | outputs and repacking of the potentially edited layer outputs. 342 | 343 | :param nn.Module module: Module which was hooked into. 344 | :param inputs_or_outputs: Layer inputs or outputs, depending on if 345 | it's hooked into a backward or forward hook respectively. 346 | :type inputs_or_outputs: Union[LayerOutputs, Tensor] 347 | 348 | :returns: The edited layer outputs. 
349 | :rtype: Union[LayerOutputs, Tensor] 350 | """ 351 | tensor, repack_fn = self._unpack_layer_outputs(inputs_or_outputs) 352 | 353 | tensor = self._apply(tensor) 354 | 355 | edited_inputs_or_outputs = repack_fn(tensor) 356 | 357 | return edited_inputs_or_outputs 358 | 359 | def _dispatch_hook_function( 360 | self, 361 | hook_function_args: Tuple[Any, ...], 362 | ) -> Union[LayerOutputs, Tensor]: 363 | """Dispatches the correct handling function depending on the hook type. 364 | 365 | There are many different types of Pytorch hooks with varying function 366 | signatures. This function unpacks the Pytorch hook function input 367 | arguments depending on the hook type and dispatches the corresponding 368 | handling function. 369 | 370 | :note: The unpacking here is in constrast to the unpacking of layer 371 | outputs, which is done in the next step if needed. 372 | 373 | :returns: Potentially edited layer outputs. 374 | These outputs are sent as input to the next layer. 375 | :rtype: Union[LayerOutputs, Tensor] 376 | 377 | :raise NotImplementedError: The requested hook type isn't yet 378 | supported. 379 | """ 380 | logger.debug(f"*{self.module_name}: Hook function activated*") 381 | 382 | handle_fn = self.hook_type_to_impl_fn[self._hook_type] 383 | retval = handle_fn(*hook_function_args) 384 | 385 | return retval 386 | 387 | def __call__(self, *args, **kwargs) -> LayerOutputs: 388 | """Entrypoint called by Pytorch hook logic. 389 | 390 | Allows us to bind the entire :class:`HookFunction` to an :code:`nn.Module` 391 | using Pytorch hook registration. 392 | 393 | :note: Doesn't currently support accepting keyword argments passed into 394 | :code:`nn.Module`s. 395 | """ 396 | if len(kwargs) != 0: 397 | raise NotImplementedError("HookFunction doesn't support kwargs.") 398 | 399 | outputs = self._dispatch_hook_function(args) 400 | return outputs 401 | -------------------------------------------------------------------------------- /flex_model/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_state import ( 2 | initialize_distributed_state, 3 | ) 4 | from .mappings import ( 5 | all_gather_data_parallel, 6 | all_gather_tensor_parallel, 7 | batch_isend_irecv_pipeline_parallel, 8 | broadcast_data_parallel, 9 | broadcast_tensor_parallel, 10 | gather_pipeline_parallel_tensor_dicts, 11 | scatter_data_parallel, 12 | scatter_tensor_parallel, 13 | unity, 14 | ) 15 | from .strategies import ( 16 | BaseRoutingStrategy, 17 | ParameterTensorParallelRoutingStrategy, 18 | ActivationTensorAllToAllRoutingStrategy, 19 | BaseOffloadStrategy, 20 | NullMemoryOffloadStrategy, 21 | CPUPinnedMemoryOffloadStrategy, 22 | CPUPagedMemoryOffloadStrategy, 23 | GPUMemoryOffloadStrategy, 24 | BaseFunctionStrategy, 25 | NonValidatedFunctionStrategy, 26 | ) 27 | -------------------------------------------------------------------------------- /flex_model/distributed/mappings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | from torch.distributed import ProcessGroup 6 | from torch import Tensor 7 | 8 | from flex_model.distributed.distributed_state import _ParallelStateAPI 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def _log_shape(rank, fn_name, in_shape, out_shape): 14 | logger.debug( 15 | f"Local rank{rank} - {fn_name} | Input: {in_shape} -> {out_shape}" 16 | ) 17 | 18 | 19 | def unity(tensor: 
Tensor, fmps: _ParallelStateAPI) -> Tensor: 20 | """No-op function. 21 | 22 | :param Tensor tensor: Activation tensor. 23 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 24 | 25 | :returns: Input tensor unmodified. 26 | :rtype: Tensor 27 | """ 28 | rank = fmps.get_local_rank() 29 | _log_shape(rank, "unity", tensor.shape, tensor.shape) 30 | 31 | return tensor 32 | 33 | 34 | def _core_broadcast( 35 | tensor: Tensor, pg: ProcessGroup, rank: int, world_size: int 36 | ) -> Tensor: 37 | in_shape = tensor.shape 38 | 39 | if world_size == 1: 40 | _log_shape(rank, "_core_broadcast", in_shape, tensor.shape) 41 | return tensor 42 | 43 | # We only interact among tensor parallel group to bcast 44 | torch.distributed.broadcast( 45 | tensor=tensor, 46 | src=0, 47 | group=pg, 48 | async_op=False, 49 | ) 50 | _log_shape(rank, "_core_broadcast", in_shape, tensor.shape) 51 | return tensor 52 | 53 | 54 | def broadcast_tensor_parallel( 55 | tensor: Tensor, fmps: _ParallelStateAPI 56 | ) -> Tensor: 57 | """Broadcast tensor to all ranks in the tensor parallel group. 58 | 59 | :param Tensor tensor: Activation tensor. 60 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 61 | 62 | :returns: Input tensor unmodified. 63 | :rtype: Tensor 64 | """ 65 | group = fmps.get_tensor_parallel_group() 66 | rank = fmps.get_tensor_parallel_rank() 67 | world_size = fmps.get_tensor_parallel_world_size() 68 | return _core_broadcast(tensor, group, rank, world_size) 69 | 70 | 71 | def broadcast_data_parallel(tensor: Tensor, fmps: _ParallelStateAPI) -> Tensor: 72 | """Broadcast tensor to all ranks in the data parallel group. 73 | 74 | :param Tensor tensor: Activation tensor. 75 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 76 | 77 | :returns: Input tensor unmodified. 78 | :rtype: Tensor 79 | """ 80 | group = fmps.get_data_parallel_group() 81 | rank = fmps.get_data_parallel_rank() 82 | world_size = fmps.get_data_parallel_world_size() 83 | return _core_broadcast(tensor, group, rank, world_size) 84 | 85 | 86 | def _core_all_gather( 87 | tensor: Tensor, dim: int, group: ProcessGroup, rank: int, world_size: int 88 | ): 89 | in_shape = tensor.shape 90 | 91 | # Clone here to prevent in-place exceptions. 92 | tensor = tensor.clone() 93 | 94 | if world_size == 1: 95 | _log_shape(rank, "_core_all_gather", in_shape, tensor.shape) 96 | return tensor 97 | 98 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] 99 | tensor_list[rank] = tensor 100 | 101 | torch.distributed.all_gather( 102 | tensor_list, 103 | tensor, 104 | group=group, 105 | async_op=False, 106 | ) 107 | 108 | output_tensor = torch.cat(tensor_list, dim=dim) 109 | 110 | _log_shape(rank, "_core_all_gather", in_shape, output_tensor.shape) 111 | 112 | return output_tensor 113 | 114 | 115 | def all_gather_tensor_parallel( 116 | tensor: Tensor, dim: int, fmps: _ParallelStateAPI 117 | ) -> Tensor: 118 | """All-to-all gather tensors from ranks in the tensor parallel group. 119 | 120 | :param Tensor tensor: Activation tensor. 121 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 122 | 123 | :returns: Input tensor unmodified. 
124 | :rtype: Tensor 125 | """ 126 | group = fmps.get_tensor_parallel_group() 127 | rank = fmps.get_tensor_parallel_rank() 128 | world_size = fmps.get_tensor_parallel_world_size() 129 | return _core_all_gather(tensor, dim, group, rank, world_size) 130 | 131 | 132 | def all_gather_data_parallel( 133 | tensor: Tensor, dim: int, fmps: _ParallelStateAPI 134 | ) -> Tensor: 135 | """All-to-all gather tensors from ranks in the data parallel group. 136 | 137 | :param Tensor tensor: Activation tensor. 138 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 139 | 140 | :returns: Input tensor unmodified. 141 | :rtype: Tensor 142 | """ 143 | group = fmps.get_data_parallel_group() 144 | rank = fmps.get_data_parallel_rank() 145 | world_size = fmps.get_data_parallel_world_size() 146 | return _core_all_gather(tensor, dim, group, rank, world_size) 147 | 148 | 149 | def _core_scatter( 150 | tensor: Tensor, dim: int, group: ProcessGroup, rank: int, world_size: int 151 | ): 152 | in_shape = tensor.shape 153 | 154 | if world_size == 1: 155 | _log_shape(rank, "_core_scatter", in_shape, tensor.shape) 156 | return tensor 157 | 158 | input_list = torch.chunk(tensor, world_size, dim=dim) 159 | output_tensor = input_list[rank].contiguous() 160 | 161 | _log_shape(rank, "_core_scatter", in_shape, output_tensor.shape) 162 | 163 | return output_tensor 164 | 165 | 166 | def scatter_tensor_parallel( 167 | tensor: Tensor, dim: int, fmps: _ParallelStateAPI 168 | ) -> Tensor: 169 | """Scatter tensors to ranks in the tensor parallel group. 170 | 171 | :param Tensor tensor: Activation tensor. 172 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 173 | 174 | :returns: Input tensor unmodified. 175 | :rtype: Tensor 176 | """ 177 | group = fmps.get_tensor_parallel_group() 178 | rank = fmps.get_tensor_parallel_rank() 179 | world_size = fmps.get_tensor_parallel_world_size() 180 | return _core_scatter(tensor, dim, group, rank, world_size) 181 | 182 | 183 | def scatter_data_parallel( 184 | tensor: Tensor, dim: int, fmps: _ParallelStateAPI 185 | ) -> Tensor: 186 | """Scatter tensors to ranks in the data parallel group. 187 | 188 | :param Tensor tensor: Activation tensor. 189 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 190 | 191 | :returns: Input tensor unmodified. 192 | :rtype: Tensor 193 | """ 194 | group = fmps.get_data_parallel_group() 195 | rank = fmps.get_data_parallel_rank() 196 | world_size = fmps.get_data_parallel_world_size() 197 | return _core_scatter(tensor, dim, group, rank, world_size) 198 | 199 | 200 | def _group_by_dtype( 201 | tensor_dict: Dict[str, Tensor] 202 | ) -> Dict[torch.dtype, Dict[str, Tensor]]: 203 | dtypes = [torch.float32, torch.float16, torch.bfloat16] 204 | dtype_groups: Dict[torch.dtype, Dict[str, Tensor]] = { 205 | dtype: {} for dtype in dtypes 206 | } 207 | 208 | for name, tensor in tensor_dict.items(): 209 | assert tensor.dtype in dtype_groups, ( 210 | f"Tensor with dtype: {tensor.dtype} is not supported for " 211 | f"gathering across PP ranks." 212 | ) 213 | dtype_groups[tensor.dtype][name] = tensor 214 | 215 | return dtype_groups 216 | 217 | 218 | # Tensor buffer metadata type. 
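# (Editor's comment.) For example, _make_flat_buffer called with two fp32
# tensors named "a" (shape (2, 4)) and "b" (shape (4,)) on pipeline rank 1
# would produce metadata roughly like:
#   {
#       "buffer_rank": 1,
#       "buffer_size": 12,
#       "buffer_dtype": torch.float32,
#       "name_to_index_map": {"a": (0, 8), "b": (8, 12)},
#       "name_to_shape_map": {"a": torch.Size([2, 4]), "b": torch.Size([4])},
#   }
# The names "a" and "b" here are illustrative only.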
219 | _TBUF_META = Dict[ 220 | str, 221 | Union[ 222 | int, 223 | torch.dtype, 224 | Dict[str, Tuple[int, int]], 225 | Dict[str, torch.Size], 226 | ], 227 | ] 228 | 229 | 230 | def _make_flat_buffer( 231 | fmps: _ParallelStateAPI, 232 | tensor_dict: Dict[str, Tensor], 233 | ) -> Tuple[Optional[Tensor], Optional[_TBUF_META]]: 234 | tensors = [] 235 | name_to_index_map = {} 236 | name_to_shape_map = {} 237 | curr_idx = 0 238 | for name, tensor in tensor_dict.items(): 239 | shape = tensor.shape 240 | numel = tensor.numel() 241 | tensors.append(tensor.flatten()) 242 | 243 | name_to_index_map[name] = (curr_idx, curr_idx + numel) 244 | name_to_shape_map[name] = shape 245 | 246 | curr_idx += numel 247 | 248 | if len(tensors) == 0: 249 | return None, None 250 | 251 | tensor_buffer = torch.cat(tensors) 252 | 253 | meta: _TBUF_META = { 254 | "buffer_rank": fmps.get_pipeline_parallel_rank(), 255 | "buffer_size": tensor_buffer.numel(), 256 | "buffer_dtype": tensor_buffer.dtype, 257 | "name_to_index_map": name_to_index_map, 258 | "name_to_shape_map": name_to_shape_map, 259 | } 260 | 261 | return tensor_buffer, meta 262 | 263 | 264 | def _gather_pipeline_parallel( 265 | fmps: _ParallelStateAPI, 266 | tbuf_groups: Dict[torch.dtype, Optional[Tensor]], 267 | all_metadata_groups: List[Optional[Dict[torch.dtype, _TBUF_META]]], 268 | ) -> Dict[str, Tensor]: 269 | rank = fmps.get_pipeline_parallel_rank() 270 | 271 | # Setup collections for communication 272 | def _empty_groups() -> Dict[torch.dtype, List[Union[Tensor, int]]]: 273 | return {dtype: [] for dtype in tbuf_groups.keys()} 274 | 275 | recv_tbuf_groups = _empty_groups() 276 | recv_rank_groups = _empty_groups() 277 | send_tbuf_groups = _empty_groups() 278 | send_rank_groups = _empty_groups() 279 | 280 | # Construct recv tensors and src ranks. 281 | # NOTE: Only rank0 participates in recvs. 282 | if rank == 0: 283 | for metadata_groups in all_metadata_groups: 284 | # Skip if the rank has no tbufs to recv for any dtype. 285 | if metadata_groups is None: 286 | continue 287 | 288 | for dtype, metadata in metadata_groups.items(): 289 | # Skip if there's no tbuf to recv for the dtype or the source 290 | # rank is 0 (rank0 never sends). 291 | if metadata is None or metadata["buffer_rank"] == 0: 292 | continue 293 | 294 | buffer_rank = metadata["buffer_rank"] 295 | buffer_size = metadata["buffer_size"] 296 | buffer_dtype = metadata["buffer_dtype"] 297 | assert ( 298 | buffer_dtype == dtype 299 | ), f"Dtype mismatch: {buffer_dtype} and {dtype}" 300 | 301 | tbuf = torch.empty((buffer_size,), dtype=buffer_dtype) 302 | src_rank = buffer_rank 303 | recv_tbuf_groups[dtype].append(tbuf) 304 | recv_rank_groups[dtype].append(src_rank) 305 | 306 | logger.debug( 307 | f"Rank{rank}: Constructed recv - " 308 | f"({tbuf.numel()}) [{src_rank}] -> [0]" 309 | ) 310 | 311 | # Construct send tensors and dst ranks. 312 | # NOTE: Only non-rank0 participate in sends. 313 | else: 314 | for dtype, tbuf in tbuf_groups.items(): 315 | # Skip if there's no tbuf to send for the dtype. 316 | if tbuf is None: 317 | continue 318 | 319 | # Send dst always rank0. 320 | send_tbuf_groups[dtype].append(tbuf) 321 | send_rank_groups[dtype].append(0) 322 | 323 | logger.debug( 324 | f"Rank{rank}: Constructed send - " 325 | f"({tbuf.numel()}) [{rank}] -> [0]" 326 | ) 327 | 328 | def _set_device(_buffer_list, device): 329 | return [_buffer.to(device) for _buffer in _buffer_list] 330 | 331 | # Batched communication across all dtype groups. 
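    # (Editor's note.) The per-dtype recv/send lists built above are merged
    # into flat lists so that a single batch_isend_irecv call covers every
    # dtype group; buffers are moved to the current CUDA device for the
    # collective and copied back to CPU once the transfer completes.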
332 | all_recv_tbufs = [] 333 | all_recv_ranks = [] 334 | all_send_tbufs = [] 335 | all_send_ranks = [] 336 | for dtype in tbuf_groups.keys(): 337 | recv_tbufs = _set_device( 338 | recv_tbuf_groups[dtype], device=torch.cuda.current_device() 339 | ) 340 | send_tbufs = _set_device( 341 | send_tbuf_groups[dtype], device=torch.cuda.current_device() 342 | ) 343 | all_recv_tbufs.extend(recv_tbufs) 344 | all_recv_ranks.extend(recv_rank_groups[dtype]) 345 | all_send_tbufs.extend(send_tbufs) 346 | all_send_ranks.extend(send_rank_groups[dtype]) 347 | 348 | batch_isend_irecv_pipeline_parallel( 349 | fmps, 350 | all_recv_tbufs, 351 | all_recv_ranks, 352 | all_send_tbufs, 353 | all_send_ranks, 354 | ) 355 | all_recv_tbufs = _set_device(all_recv_tbufs, device="cpu") 356 | all_send_tbufs = _set_device(all_send_tbufs, device="cpu") 357 | 358 | # Unshard each tbuf into individual tensors. 359 | output_tensor_dict: Dict[str, Tensor] = {} 360 | if rank == 0: 361 | 362 | def _reshard_tbuf(meta, tbuf): 363 | for name, (start, end) in meta["name_to_index_map"].items(): 364 | shape = meta["name_to_shape_map"][name] 365 | output_tensor_dict[name] = tbuf[start:end].reshape(shape) 366 | 367 | # Add rank0 local tbufs. 368 | for dtype, tbuf in tbuf_groups.items(): 369 | meta = all_metadata_groups[0][dtype] 370 | if meta is not None: 371 | _reshard_tbuf(meta, tbuf) 372 | 373 | # Add gathered tbufs. 374 | for recv_tbuf, recv_r in zip(all_recv_tbufs, all_recv_ranks): 375 | dtype = recv_tbuf.dtype 376 | meta = all_metadata_groups[recv_r][dtype] 377 | 378 | buf_rank = meta["buffer_rank"] 379 | buf_dtype = meta["buffer_dtype"] 380 | assert ( 381 | buf_dtype == dtype 382 | ), f"Dtype mismatch: {buf_dtype} and {dtype}" 383 | assert buf_rank == recv_r, f"Rank mismatch: {buf_rank} and {recv_r}" 384 | 385 | _reshard_tbuf(meta, recv_tbuf) 386 | 387 | return output_tensor_dict 388 | 389 | 390 | def batch_isend_irecv_pipeline_parallel( 391 | fmps: _ParallelStateAPI, 392 | recv_tensors: List[Tensor], 393 | recv_from_ranks: List[int], 394 | send_tensors: List[Tensor], 395 | send_to_ranks: List[int], 396 | ) -> None: 397 | """Run batched peer-to-peer communications. 398 | 399 | :param List[Tensor] recv_tensors: Tensors to receive. 400 | :param List[int] recv_from_ranks: Ranks to receive from. 401 | :param List[Tensor] send_tensors: Tensors to send. 402 | :param List[int] send_to_ranks: Ranks to send to. 
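    :param _ParallelStateAPI fmps: FlexModel parallel state handle.

    :note: (Editor's addition.) Receives and sends are launched together via
        :code:`torch.distributed.batch_isend_irecv`, and every request is
        waited on before this function returns.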
403 | """ 404 | rank = fmps.get_pipeline_parallel_rank() 405 | group = fmps.get_pipeline_parallel_group() 406 | 407 | assert len(recv_tensors) == len(recv_from_ranks), ( 408 | f"Mistmatch in recv tensors({len(recv_tensors)}) and " 409 | f"recv ranks({len(recv_from_ranks)})" 410 | ) 411 | assert len(send_tensors) == len(send_to_ranks), ( 412 | f"Mistmatch in send tensors({len(send_tensors)}) and " 413 | f"send ranks({len(send_to_ranks)})" 414 | ) 415 | 416 | p2p_ops = [] 417 | for recv_t, recv_r in zip(recv_tensors, recv_from_ranks): 418 | op = torch.distributed.P2POp( 419 | torch.distributed.irecv, 420 | recv_t, 421 | peer=recv_r, 422 | group=group, 423 | ) 424 | p2p_ops.append(op) 425 | 426 | logger.debug(f"Rank{rank}: P2POp (irecv) [{rank}] <- [{recv_r}]") 427 | 428 | for send_t, send_r in zip(send_tensors, send_to_ranks): 429 | op = torch.distributed.P2POp( 430 | torch.distributed.isend, 431 | send_t, 432 | peer=send_r, 433 | group=group, 434 | ) 435 | p2p_ops.append(op) 436 | 437 | logger.debug(f"Rank{rank}: P2POp (isend) [{rank}] -> [{send_r}]") 438 | 439 | if len(p2p_ops) == 0: 440 | return 441 | 442 | logger.debug(f"Rank{rank}: Launching P2POps") 443 | 444 | reqs = torch.distributed.batch_isend_irecv(p2p_ops) 445 | for req in reqs: 446 | req.wait() 447 | 448 | def _gen_debug_msg(t_list): 449 | return ", ".join([f"({t.numel()}, {t.dtype})" for t in t_list]) 450 | 451 | logger.debug( 452 | f"Rank{rank}: Received buffers - [{_gen_debug_msg(recv_tensors)}]" 453 | ) 454 | logger.debug(f"Rank{rank}: Sent buffers - [{_gen_debug_msg(send_tensors)}]") 455 | 456 | # TODO: Remove after verification that no race cond. occurs. 457 | torch.cuda.synchronize() 458 | 459 | 460 | def gather_pipeline_parallel_tensor_dicts( 461 | fmps: _ParallelStateAPI, 462 | tensor_dict: Dict[str, Tensor], 463 | ) -> Dict[str, Tensor]: 464 | """Gather groups of tensors from ranks of the pipeline group to pipeline rank0. 465 | 466 | Note: Assumes input tensors are on CPU and placed output tensors on CPU. 467 | - This behaviour is subject to change depending on various optimizations. 468 | 469 | :param _ParallelStateAPI fmps: FlexModel parallel state handle. 470 | :param tensor_dict: Some python object that can be pickled. May contain tensors. 471 | :type tensor_dict Dict[str, Tensor]: 472 | 473 | :returns: A collection of the objects sent from all pipeline paralel group ranks. 474 | :rtype: Dict[str, Tensor] 475 | """ 476 | in_shapes = [] 477 | for tensor in tensor_dict.values(): 478 | in_shapes.append(tensor.shape) 479 | 480 | world_size = fmps.get_pipeline_parallel_world_size() 481 | rank = fmps.get_pipeline_parallel_rank() 482 | group = fmps.get_pipeline_parallel_group() 483 | 484 | tensor_dict_groups = _group_by_dtype(tensor_dict) 485 | 486 | # Convert tensor dicts into flattened buffers with metadata. 487 | tbuf_groups = {} 488 | metadata_groups = {} 489 | for dtype, tensor_dict in tensor_dict_groups.items(): 490 | tbuf, meta = _make_flat_buffer(fmps, tensor_dict) 491 | 492 | tbuf_groups[dtype] = tbuf 493 | metadata_groups[dtype] = meta 494 | 495 | # Gather metadata on rank 0 to setup recv tensors. 496 | all_metadata_groups: List[Optional[Dict[torch.dtype, _TBUF_META]]] = [ 497 | None for _ in range(world_size) 498 | ] 499 | torch.distributed.gather_object( 500 | metadata_groups, 501 | all_metadata_groups if rank == 0 else None, 502 | dst=0, 503 | group=group, 504 | ) 505 | 506 | # Communicate. 
507 | output_tensor_dict = _gather_pipeline_parallel( 508 | fmps, tbuf_groups, all_metadata_groups 509 | ) 510 | 511 | for in_shape, out_tensor in zip(in_shapes, output_tensor_dict.values()): 512 | _log_shape( 513 | rank, 514 | "gather_pipeline_parallel_tensor_dicts", 515 | in_shape, 516 | out_tensor.shape, 517 | ) 518 | 519 | return output_tensor_dict 520 | -------------------------------------------------------------------------------- /flex_model/distributed/strategies.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional 2 | 3 | from torch import Tensor 4 | 5 | import flex_model.distributed as fm_dist 6 | from flex_model.distributed.distributed_state import _ParallelStateAPI 7 | 8 | """ 9 | We define strategies for: 10 | 1. Routing: Where/how activations are communicated in a 3D mesh. 11 | 2. Offload: Which devices offload tensors to CPU. 12 | 3. Function: Which devices run user-provided editing functions. 13 | 14 | A strategy is a function which defines some operation on a single tensor. 15 | Instantiation of a strategy may require arguments dependent on the specific 16 | strategy, but can be reapplied as long as the model definition and sharding 17 | strategy do not change. 18 | """ 19 | 20 | 21 | class BaseRoutingStrategy: 22 | """ 23 | Defines a routing strategy, which every device participates in. Moves 24 | corresponding tensors via collective communication. 25 | """ 26 | 27 | def __init__(self, prologue_fn, epilogue_fn): 28 | self.prologue_fn = prologue_fn 29 | self.epilogue_fn = epilogue_fn 30 | 31 | @classmethod 32 | def initialize(cls, tensor, expected_shape) -> None: 33 | raise NotImplementedError("Routing Strategy must implement this") 34 | 35 | def execute_prologue(self, tensor: Tensor) -> Tensor: 36 | tensor = self.prologue_fn(tensor) 37 | return tensor 38 | 39 | def execute_epilogue(self, tensor: Tensor) -> Tensor: 40 | tensor = self.epilogue_fn(tensor) 41 | return tensor 42 | 43 | 44 | class ParameterTensorParallelRoutingStrategy(BaseRoutingStrategy): 45 | """Defines a routing strategy for parameter tensors supporting TP sharding.""" 46 | 47 | @classmethod 48 | def initialize( 49 | cls, fmps: _ParallelStateAPI, tensor: Tensor, expected_shape 50 | ) -> None: 51 | # Handle unspecified dimensions. 52 | input_shape = tensor.shape 53 | if expected_shape is None: 54 | expected_shape = tuple(None for _ in range(len(input_shape))) 55 | 56 | full_tensor_shape = tuple( 57 | d1 if d2 is None else d2 58 | for d1, d2 in zip(input_shape, expected_shape) 59 | ) 60 | 61 | # Determine which, if any, dimensions need to be gathered over TP. 62 | gather_dims = [] 63 | for i, (in_dim, full_dim) in enumerate( 64 | zip(input_shape, full_tensor_shape) 65 | ): 66 | if in_dim != full_dim: 67 | gather_dims.append(i) 68 | 69 | if len(gather_dims) < 1: 70 | gather_tp = False 71 | elif len(gather_dims) == 1: 72 | gather_tp = True 73 | else: 74 | # TODO: Multi-dim TP gathering. 
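            # (Editor's note.) A parameter sharded along more than one
            # dimension would require nested gathers/scatters; only a single
            # TP-sharded dimension is handled here.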
75 | raise NotImplementedError( 76 | f"Tensor-parallel routing only supports one dimension, found {len(gather_dims)}" 77 | ) 78 | sharded_dim = gather_dims[0] if len(gather_dims) > 0 else None 79 | 80 | def _gather_only_tp(t): 81 | return fm_dist.all_gather_tensor_parallel(t, sharded_dim, fmps) 82 | 83 | def _scatter_only_tp(t): 84 | return fm_dist.scatter_tensor_parallel(t, sharded_dim, fmps) 85 | 86 | def _unity(t): 87 | return fm_dist.unity(t, fmps) 88 | 89 | if gather_tp: 90 | prologue_fn = _gather_only_tp 91 | epilogue_fn = _scatter_only_tp 92 | else: 93 | prologue_fn = _unity 94 | epilogue_fn = _unity 95 | 96 | return cls(prologue_fn, epilogue_fn) 97 | 98 | 99 | class ActivationTensorAllToAllRoutingStrategy(BaseRoutingStrategy): 100 | """ 101 | Defines a routing strategy which materializes the activation tensor on all 102 | TP and DP ranks via all-gather collectives. 103 | """ 104 | 105 | @classmethod 106 | def initialize( 107 | cls, 108 | fmps: _ParallelStateAPI, 109 | tensor: Optional[Tensor], 110 | expected_shape, 111 | ) -> None: 112 | def _unity(t): 113 | return fm_dist.unity(t, fmps) 114 | 115 | if tensor is None: 116 | return cls(_unity, _unity) 117 | 118 | dp_world_size = fmps.get_data_parallel_world_size() 119 | 120 | # Handle unspecified dimensions. 121 | input_shape = tensor.shape 122 | if expected_shape is None: 123 | expected_shape = tuple(None for _ in range(len(input_shape))) 124 | 125 | full_tensor_shape = tuple( 126 | d1 if d2 is None else d2 127 | for d1, d2 in zip(input_shape, expected_shape) 128 | ) 129 | 130 | # Determine which, if any, dimensions need to be gathered over TP. 131 | gather_dims = [] 132 | for i, (in_dim, full_dim) in enumerate( 133 | zip(input_shape, full_tensor_shape) 134 | ): 135 | if in_dim != full_dim: 136 | gather_dims.append(i) 137 | 138 | if len(gather_dims) < 1: 139 | gather_tp = False 140 | elif len(gather_dims) == 1: 141 | gather_tp = True 142 | else: 143 | # TODO: Multi-dim TP gathering. 144 | raise NotImplementedError( 145 | f"Tensor-parallel routing only supports one dimension, found {len(gather_dims)}" 146 | ) 147 | 148 | # Only relevant if we need to gather_tp. 149 | sharded_dim = gather_dims[0] if len(gather_dims) > 0 else None 150 | 151 | gather_dp = True if dp_world_size > 1 else False 152 | 153 | # Define helper functions for collection/dispersion. 
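        # (Editor's note.) Exactly one of four prologue/epilogue pairs is
        # selected below: no-op, DP-only gather/scatter along dim 0, TP-only
        # gather/scatter along the sharded dim, or a TP-then-DP gather paired
        # with the reverse DP-then-TP scatter.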
154 |         def _gather_only_tp(t):
155 |             return fm_dist.all_gather_tensor_parallel(t, sharded_dim, fmps)
156 | 
157 |         def _scatter_only_tp(t):
158 |             return fm_dist.scatter_tensor_parallel(t, sharded_dim, fmps)
159 | 
160 |         def _gather_only_dp(t):
161 |             return fm_dist.all_gather_data_parallel(t, 0, fmps)
162 | 
163 |         def _scatter_only_dp(t):
164 |             return fm_dist.scatter_data_parallel(t, 0, fmps)
165 | 
166 |         def _gather_tp_then_dp(t):
167 |             return fm_dist.all_gather_data_parallel(
168 |                 fm_dist.all_gather_tensor_parallel(t, sharded_dim, fmps),
169 |                 0,
170 |                 fmps,
171 |             )
172 | 
173 |         def _scatter_dp_then_tp(t):
174 |             return fm_dist.scatter_tensor_parallel(
175 |                 fm_dist.scatter_data_parallel(t, 0, fmps),
176 |                 sharded_dim,
177 |                 fmps,
178 |             )
179 | 
180 |         if not gather_tp and not gather_dp:
181 |             prologue_fn = _unity
182 |             epilogue_fn = _unity
183 | 
184 |         elif not gather_tp and gather_dp:
185 |             prologue_fn = _gather_only_dp
186 |             epilogue_fn = _scatter_only_dp
187 | 
188 |         elif gather_tp and not gather_dp:
189 |             prologue_fn = _gather_only_tp
190 |             epilogue_fn = _scatter_only_tp
191 | 
192 |         else:
193 |             prologue_fn = _gather_tp_then_dp
194 |             epilogue_fn = _scatter_dp_then_tp
195 | 
196 |         return cls(prologue_fn, epilogue_fn)
197 | 
198 | 
199 | class BaseOffloadStrategy:
200 |     """
201 |     Defines an offload strategy, which each device may or may not participate
202 |     in. Offloading means taking the tensor and disconnecting it from any
203 |     computation graph for separate downstream processing.
204 |     """
205 | 
206 |     def __init__(self, offload_fn):
207 |         self.offload_fn = offload_fn
208 | 
209 |     def execute(self, tensor: Tensor) -> None:
210 |         self.offload_fn(tensor)
211 | 
212 | 
213 | class NullMemoryOffloadStrategy(BaseOffloadStrategy):
214 |     @classmethod
215 |     def initialize(cls, name: str, output_ptr: Dict[str, List[Tensor]]) -> None:
216 |         return cls(lambda x: x)
217 | 
218 | 
219 | class CPUPinnedMemoryOffloadStrategy(BaseOffloadStrategy):
220 |     @classmethod
221 |     def initialize(cls, name: str, output_ptr: Dict[str, List[Tensor]]) -> None:
222 |         def _offload(t):
223 |             if name not in output_ptr:
224 |                 output_ptr[name] = []
225 |             output_ptr[name].append(t.detach().to("cpu", non_blocking=True))
226 | 
227 |         return cls(_offload)
228 | 
229 | 
230 | class CPUPagedMemoryOffloadStrategy(BaseOffloadStrategy):
231 |     @classmethod
232 |     def initialize(cls, name: str, output_ptr: Dict[str, List[Tensor]]) -> None:
233 |         def _offload(t):
234 |             if name not in output_ptr:
235 |                 output_ptr[name] = []
236 | 
237 |             output_ptr[name].append(t.detach().cpu())
238 | 
239 |         return cls(_offload)
240 | 
241 | 
242 | class GPUMemoryOffloadStrategy(BaseOffloadStrategy):
243 |     @classmethod
244 |     def initialize(cls, name: str, output_ptr: Dict[str, List[Tensor]]) -> None:
245 |         def _offload(t):
246 |             if name not in output_ptr:
247 |                 output_ptr[name] = []
248 | 
249 |             output_ptr[name].append(t.detach().clone())
250 | 
251 |         return cls(_offload)
252 | 
253 | 
254 | class BaseFunctionStrategy:
255 |     """
256 |     Defines an editing function execution strategy. Can validate an editing
257 |     function (i.e. check breakpoints, etc.). Determines which tensors have the
258 |     editing function applied.
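    A subclass sketch (editor's addition; the source-inspection check is
    illustrative only and not an existing strategy):

    .. highlight:: python
    .. code-block:: python

        import inspect

        class NoBreakpointFunctionStrategy(BaseFunctionStrategy):
            @classmethod
            def initialize(cls, user_func):
                # Reject editing functions that would pause the forward pass.
                if "breakpoint(" in inspect.getsource(user_func):
                    raise ValueError(
                        "Editing function may not call breakpoint()."
                    )
                return cls(user_func)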
259 | """ 260 | 261 | def __init__(self, user_func: Callable): 262 | self.user_func = user_func 263 | 264 | @classmethod 265 | def initialize(cls, func: Callable): 266 | raise NotImplementedError 267 | 268 | def execute(self, *args, **kwargs) -> Any: 269 | return self.user_func(*args, **kwargs) 270 | 271 | 272 | class NonValidatedFunctionStrategy(BaseFunctionStrategy): 273 | @classmethod 274 | def initialize(cls, user_func: Callable): 275 | def is_valid(fn) -> bool: 276 | return True 277 | 278 | if is_valid(user_func): 279 | return cls(user_func) 280 | else: 281 | raise Exception("Provided editing function is not valid") 282 | -------------------------------------------------------------------------------- /flex_model/package_info.py: -------------------------------------------------------------------------------- 1 | MAJOR = 0 2 | MINOR = 1 3 | PATCH = 0 4 | 5 | VERSION = (MAJOR, MINOR, PATCH) 6 | 7 | __version__ = ".".join(map(str, VERSION)) 8 | 9 | __package_name__ = "flex_model" 10 | __contact_names__ = "Matthew Choi" 11 | __contact_emails__ = "matthew.choi@vectorinstitute.ai" 12 | __homepage__ = "https://flexmodel.readthedocs.io/en/latest/index.html" 13 | __repository_url__ = "https://github.com/VectorInstitute/flex_model" 14 | __description__ = "FlexModel - A Framework for Interpretability of Distributed Large Language Models" 15 | __license__ = "mit" 16 | -------------------------------------------------------------------------------- /flex_model/traverse/__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import ( 2 | InternalNode, 3 | InternalObject, 4 | LeafNode, 5 | LeafObject, 6 | ScalarNode, 7 | ScalarObject, 8 | ) 9 | from .ops import flatten, unflatten 10 | -------------------------------------------------------------------------------- /flex_model/traverse/nodes.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 4 | 5 | from torch import Tensor 6 | from transformers.modeling_outputs import BaseModelOutputWithPast 7 | 8 | # Types which do not have classes 9 | # TODO: Add high-level type for root node tree definitions 10 | InternalObject = Any 11 | LeafObject = Any 12 | ScalarObject = Any 13 | ScalarNode = Any 14 | 15 | 16 | _INTERNAL_NODE_TYPE_REGISTRY: Dict[type, InternalNode] = {} 17 | _LEAF_NODE_TYPE_REGISTRY: Dict[type, LeafNode] = {} 18 | 19 | 20 | def register_internal_node_type(internal_node_type: type) -> Callable: 21 | """Decorator for registering :class:`InternalNode` classes with a 22 | corresponding :code:`type`. 23 | 24 | :param type internal_node_type: The :code:`type` associated with the node. 25 | 26 | :returns: Inner function which registers the :class:`InternalNode` child 27 | class with the :code:`type`. 28 | """ 29 | 30 | def _inner(_internal_node_cls: InternalNode) -> InternalNode: 31 | _INTERNAL_NODE_TYPE_REGISTRY[internal_node_type] = _internal_node_cls 32 | return _internal_node_cls 33 | 34 | return _inner 35 | 36 | 37 | def register_leaf_node_type(leaf_node_type: type) -> Callable: 38 | """Decorator for registering :class:`LeafNode` classes with a 39 | corresponding :code:`type`. 40 | 41 | :param type internal_node_type: The :code:`type` associated with the node. 42 | 43 | :returns: Inner function which registers the :class:`LeafNode` child 44 | class with the :code:`type`. 
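
    For example, a hypothetical custom leaf type (here called
    :code:`MyTensorLike`, a stand-in name) could be registered the same way
    :class:`TensorNode` is registered further below::

        @register_leaf_node_type(MyTensorLike)
        class MyTensorLikeNode(LeafNode):
            ...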
45 | """ 46 | 47 | def _inner(_leaf_node_cls: LeafNode) -> LeafNode: 48 | _LEAF_NODE_TYPE_REGISTRY[leaf_node_type] = _leaf_node_cls 49 | return _leaf_node_cls 50 | 51 | return _inner 52 | 53 | 54 | class InternalNode: 55 | """Node correponding to unpackable container. These are nodes which we can 56 | always unpack other python objects from to continue traversal. 57 | 58 | :var children: A list of children nodes. 59 | :type children: Optional[List[Union[InternalNode, LeafNode, ScalarNode]]] 60 | """ 61 | 62 | def __init__( 63 | self, 64 | children: Optional[ 65 | List[Union[InternalNode, LeafNode, ScalarNode]] 66 | ] = None, 67 | ) -> None: 68 | self.children = children if children is not None else [] 69 | 70 | def __eq__(self, other: Any) -> bool: 71 | """Traverse subtree checking for node equality recursively. 72 | 73 | :param Any other: Other node defining a subtree to check equality 74 | against. 75 | 76 | :returns: True if the subtrees match, else false. 77 | :rtype: bool 78 | """ 79 | 80 | def _dfs( 81 | node1: Union[InternalNode, LeafNode, ScalarNode], 82 | node2: Union[InternalNode, LeafNode, ScalarNode], 83 | ) -> bool: 84 | # Mismatched types 85 | if type(node1) != type(node2): 86 | return False 87 | 88 | # Leaf node case 89 | if is_leaf_node(node1) and is_leaf_node(node2): 90 | return True 91 | 92 | # Internal node case 93 | elif is_internal_node(node1) and is_internal_node(node2): 94 | if len(node1.children) != len(node2.children): 95 | return False 96 | 97 | subtrees_equal = [] 98 | for n1, n2 in zip(node1.children, node2.children): 99 | subtree = _dfs(n1, n2) 100 | subtrees_equal.append(subtree) 101 | return all(subtrees_equal) 102 | 103 | # Scalar node case 104 | else: 105 | return node1 == node2 106 | 107 | return _dfs(self, other) 108 | 109 | def __repr__(self) -> str: 110 | return f"Node({self.children})" 111 | 112 | def __str__(self) -> str: 113 | return self.__repr__() 114 | 115 | def flatten(self, instance): 116 | """Flatten the associated instance by returning its contents. 117 | :note: Child classes must implement this. 118 | 119 | :note: All flattening functions return a tuple of the unpacked elements. 120 | """ 121 | raise NotImplementedError 122 | 123 | def unflatten(self, children): 124 | """Pack the contents (children) back into the associated container. 125 | :note: Child classes must implement this. 126 | """ 127 | raise NotImplementedError 128 | 129 | 130 | @register_internal_node_type(tuple) 131 | class TupleNode(InternalNode): 132 | """Unpackable node corresponding to tuples.""" 133 | 134 | def __repr__(self) -> str: 135 | return f"TupleNode({self.children})" 136 | 137 | def __str__(self) -> str: 138 | return self.__repr__() 139 | 140 | def flatten(self, instance: Tuple[Any, ...]) -> Tuple[Any, ...]: 141 | """Unpack the tuple. Flattening functions always return a tuple of the 142 | unpacked elements, so this function does nothing. 143 | 144 | :param Tuple[Any, ...] instance: Tuple instance to unpack. 145 | 146 | :returns: A tuple of the unpacked elements. 147 | :rtype: Tuple[Any, ...] 148 | """ 149 | return instance 150 | 151 | def unflatten(self, children: List[Any]) -> Tuple[Any, ...]: 152 | """Re-assemble the tuple. 153 | 154 | :param List[Any] children: List of elements to repack the tuple with. 155 | 156 | :returns: A tuple of the repacked elements. 157 | :rtype: Tuple[Any, ...] 
158 | """ 159 | return tuple(child for child in children) 160 | 161 | 162 | @register_internal_node_type(list) 163 | class ListNode(InternalNode): 164 | """Unpackable node corresponding to lists.""" 165 | 166 | def __repr__(self) -> str: 167 | return f"ListNode({self.children})" 168 | 169 | def __str__(self) -> str: 170 | return self.__repr__() 171 | 172 | def flatten(self, instance: List[Any]) -> Tuple[Any, ...]: 173 | """Unpack the list. 174 | 175 | :param List[Any] instance: List instance to unpack. 176 | 177 | :returns: A tuple of the unpacked elements. 178 | :rtype: Tuple[Any, ...] 179 | """ 180 | return tuple(instance) 181 | 182 | def unflatten(self, children: List[Any]) -> List[Any]: 183 | """Re-assemble the list. 184 | 185 | :param List[Any] children: List of elemnets to repack the list with. 186 | 187 | :returns: A list of the repacked elements. 188 | :rtype: List[Any] 189 | """ 190 | return list(child for child in children) 191 | 192 | 193 | @register_internal_node_type(BaseModelOutputWithPast) 194 | class BaseModelOutputWithPastNode(InternalNode): 195 | """Node corresponding to Huggingface BaseModelOutputWithPast object.""" 196 | 197 | def __repr__(self) -> str: 198 | return f"BaseModelOutputWithPastNode({self.children})" 199 | 200 | def flatten( 201 | self, instance: BaseModelOutputWithPast 202 | ) -> Tuple[Any, Any, Any, Any]: 203 | """Unpack the :code:`BaseModelOutputWithPast`. 204 | 205 | :param BaseModelOutputWithPast instance: :code:`BaseModelOutputWithPast` to 206 | unpack. 207 | 208 | :returns: A 4-tuple containing hidden state and other cached values. 209 | :rtype: Tuple[Any, Any, Any, Any, Any] 210 | """ 211 | contents = ( 212 | instance.last_hidden_state, 213 | instance.past_key_values, 214 | instance.hidden_states, 215 | instance.attentions, 216 | ) 217 | return contents 218 | 219 | def unflatten(self, children: List[Any]) -> BaseModelOutputWithPast: 220 | """Re-assemble the :code:`BaseModelOutputWithPast`. 221 | 222 | :param List[Any] children: List of elements to repack the :code:`BaseModelOutputWithPast` 223 | with. 224 | 225 | :returns: A :code:`BaseModelOutputWithPast`. 226 | :rtype: BaseModelOutputWithPast 227 | """ 228 | return BaseModelOutputWithPast(*children) 229 | 230 | 231 | class LeafNode: 232 | """Leaf node, typically corresponding to a tensor. 233 | 234 | :note: Leaf nodes should not hold a ref to the underlying data, only some 235 | metadata. 236 | """ 237 | 238 | def __init__(self, val: Any = None) -> None: 239 | self.val = val 240 | 241 | def __eq__(self, other: Any) -> bool: 242 | raise NotImplementedError 243 | 244 | def __repr__(self) -> str: 245 | return "LeafNode" 246 | 247 | def __str__(self) -> str: 248 | return self.__repr__() 249 | 250 | 251 | @register_leaf_node_type(Tensor) 252 | class TensorNode(LeafNode): 253 | """Leaf node corresponding to a Pytorch tensor.""" 254 | 255 | def __eq__(self, other: Any) -> bool: 256 | if not isinstance(other, TensorNode): 257 | return False 258 | 259 | return self.val == other.val 260 | 261 | def __repr__(self) -> str: 262 | return f"TensorNode<{self.val}>" 263 | 264 | 265 | def get_internal_node(internal_obj: InternalObject) -> InternalNode: 266 | """Retrieve the corresponding :class:`InternalNode` representation of an 267 | :class:`InternalObject`. 268 | 269 | :param InternalObject internal_obj: Target object. 270 | 271 | :returns: The corresponding node representation. 
272 | :rtype: InternalNode 273 | """ 274 | return _INTERNAL_NODE_TYPE_REGISTRY[type(internal_obj)] 275 | 276 | 277 | def get_leaf_node(leaf_obj: LeafObject) -> LeafNode: 278 | """Retrieve the corresponding :class:`LeafNode` representation of an 279 | :class:`LeafObject`. 280 | 281 | :param LeafObject leaf_obj: Target object. 282 | 283 | :returns: The corresponding node representation. 284 | :rtype: LeafNode 285 | """ 286 | return _LEAF_NODE_TYPE_REGISTRY[type(leaf_obj)] 287 | 288 | 289 | # TODO: Deprecate in favour of _flatten 290 | def _recursively_find_first_tensor( 291 | obj: Union[InternalObject, LeafObject, ScalarObject] 292 | ) -> Optional[Tensor]: 293 | if is_leaf_obj(obj): 294 | return obj 295 | 296 | if not is_internal_obj(obj): 297 | return 298 | 299 | for ele in obj: 300 | res = _recursively_find_first_tensor(ele) 301 | if res is not None: 302 | return res 303 | 304 | 305 | def is_leaf_obj(obj: Any) -> bool: 306 | """Return true if the object corresponds to a leaf object. 307 | 308 | :param Any obj: Object to query against. 309 | 310 | :returns: True if the object has a corresponding leaf object representation. 311 | :rtype: bool 312 | """ 313 | return type(obj) in _LEAF_NODE_TYPE_REGISTRY 314 | 315 | 316 | def is_internal_obj(obj: Any) -> bool: 317 | """Return true if the object corrsponds to an internal node. 318 | :param Any obj: Object to query against. 319 | 320 | :returns: True if the object has a corresponding internal object representation. 321 | :rtype: bool 322 | """ 323 | return type(obj) in _INTERNAL_NODE_TYPE_REGISTRY 324 | 325 | 326 | def is_leaf_node(node: Any) -> bool: 327 | """Return true if the object is a leaf node. 328 | :param Any obj: Object to query against. 329 | 330 | :returns: True if the object has a corresponding leaf node representation. 331 | :rtype: bool 332 | """ 333 | return isinstance(node, LeafNode) 334 | 335 | 336 | def is_internal_node(node: Any) -> bool: 337 | """Return true if the object is an internal node. 338 | :param Any obj: Object to query against. 339 | 340 | :returns: True if the object has a corresponding internal node representation. 341 | :rtype: bool 342 | """ 343 | return isinstance(node, InternalNode) 344 | -------------------------------------------------------------------------------- /flex_model/traverse/ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, List, Optional, Tuple, Union 4 | 5 | from torch import Tensor 6 | 7 | from .nodes import ( 8 | InternalNode, 9 | LeafNode, 10 | ScalarNode, 11 | get_internal_node, 12 | get_leaf_node, 13 | is_internal_node, 14 | is_internal_obj, 15 | is_leaf_node, 16 | is_leaf_obj, 17 | ) 18 | 19 | 20 | def flatten( 21 | root_obj: Any, 22 | ) -> Tuple[Union[InternalNode, LeafNode, ScalarNode], List[Optional[Tensor]]]: 23 | """Flatten an arbitrary python object into a tree definition and a 24 | collection of leaves. These can then be repacked by :code:`unflatten` to 25 | perfectly reconstruct the original python object. The python object is 26 | recursively unpacked using node representations, which each locally know 27 | how to unpack themselves. 28 | 29 | :note: The traversal is done in a depth-first way to bias us towards 30 | finding the left-most leaf node first. 31 | 32 | :param Any root_obj: The python object to flatten. 33 | 34 | :returns: A tree definition of the python object and a list of leaf 35 | objects (typically Pytorch tensors). 
36 |     :rtype: Tuple[Union[InternalNode, LeafNode, ScalarNode], List[Optional[Tensor]]]
37 |     """
38 |     order = []
39 |     leaves = []
40 | 
41 |     def _dfs(obj):
42 |         # Leaf obj case
43 |         if is_leaf_obj(obj):
44 |             leaf_node = get_leaf_node(obj)(val=obj.shape)
45 |             order.append(leaf_node)
46 |             leaves.append(obj)
47 |             return leaf_node
48 | 
49 |         # Internal obj recursive case
50 |         elif is_internal_obj(obj):
51 |             # NOTE: Each node needs to know how to flatten its associated type
52 |             #       instance. I.e. BaseModelOutputWithPast needs to be able to
53 |             #       return its attributes in a tuple. They should also be able
54 |             #       to perfectly recreate instances of themselves using a list of
55 |             #       children.
56 |             internal_node = get_internal_node(obj)()
57 |             order.append(internal_node)
58 | 
59 |             # Internal node knows how to unpack its equivalent internal object
60 |             unvisited_children = internal_node.flatten(obj)
61 | 
62 |             # Recurse into internal object's children
63 |             for child in unvisited_children:
64 |                 internal_node.children.append(_dfs(child))
65 |             return internal_node
66 | 
67 |         # Scalar obj case
68 |         else:
69 |             # Scalar nodes are just objects
70 |             scalar_node = obj
71 |             order.append(scalar_node)
72 |             return scalar_node
73 | 
74 |     _dfs(root_obj)
75 |     return order[0], leaves
76 | 
77 | 
78 | def unflatten(
79 |     root_node: Union[InternalNode, LeafNode, ScalarNode],
80 |     leaves: List[Optional[Tensor]],
81 | ) -> Any:
82 |     """Repack a tree definition and list of leaves into the original python
83 |     object.
84 | 
85 |     :param root_node: Root node which defines the tree definition of the python
86 |         object.
87 |     :type root_node: Union[InternalNode, LeafNode, ScalarNode]
88 |     :param leaves: List of leaf nodes.
89 |     :type leaves: List[Optional[Tensor]]
90 | 
91 |     :returns: The reconstructed python object.
92 |     :rtype: Any
93 |     """
94 |     leaves = list(reversed(leaves))
95 | 
96 |     def _dfs(node):
97 |         # Leaf node case
98 |         if is_leaf_node(node):
99 |             return leaves.pop()
100 | 
101 |         # Internal node case
102 |         elif is_internal_node(node):
103 |             # Node knows how to pack itself up again into its corresponding obj
104 |             obj = node.unflatten(_dfs(child) for child in node.children)
105 |             return obj
106 | 
107 |         # Scalar node case
108 |         else:
109 |             return node
110 | 
111 |     return _dfs(root_node)
112 | 
--------------------------------------------------------------------------------
/flex_model/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | 
4 | def setup_logger(level):
5 |     logging.basicConfig(
6 |         format="%(asctime)s | %(name)s | %(funcName)s | %(message)s",
7 |         level=level.upper(),
8 |     )
9 | 
--------------------------------------------------------------------------------
/profiling/README.md:
--------------------------------------------------------------------------------
1 | # FlexModel Performance Testing
2 | To optimize performance of models wrapped with `FlexModel`, this folder is
3 | used to generate PyTorch Kineto profiles using `torch.profiler.profile()`. To
4 | generate these reports for a variety of single/multi-GPU experiments, run the
5 | command:
6 | ```
7 | torchrun --nnodes <num_nodes> --nproc_per_node <gpus_per_node> --rdzv_id 6969 \
8 |     profile_hooks.py --dtypes bf16 \
9 |     --model_dim 4096 \
10 |     --profile_show
11 | ```
12 | 
13 | This will run the profiler on a test model (see `utils.TestNetwork`) which
14 | uses Megatron core `ColumnParallelLinear` and `RowParallelLinear` layers.
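
For example (the node/GPU counts and the output directory below are
illustrative), a single-node run across 4 GPUs that also saves Kineto traces
for TensorBoard:
```
torchrun --nnodes 1 --nproc_per_node 4 --rdzv_id 6969 \
    profile_hooks.py --dtypes bf16 \
    --model_dim 4096 \
    --profile_save_profile \
    --profile_dir ./profiles/example_run
```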
15 | Additional script options can be found inside of the `profile_hooks.py` script. 16 | 17 | # Visualizing Profiles 18 | Visualizations can be created by running the Jupyter Notebook within this 19 | folder. 20 | -------------------------------------------------------------------------------- /profiling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/flex_model/94e3cb434d26bc35c8503b4f6f2dd0b500ae90e8/profiling/__init__.py -------------------------------------------------------------------------------- /profiling/profile_hooks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pprint 4 | 5 | import torch 6 | from torch.profiler import ProfilerActivity 7 | 8 | from flex_model.core import FlexModel 9 | from flex_model.utils import setup_logger 10 | from profiling.utils import ( 11 | ExperimentNetworkManager, 12 | init_megatron_dist, 13 | spoof_megatron_config, 14 | ) 15 | 16 | DTYPES = { 17 | "fp32": torch.float32, 18 | "fp16": torch.float16, 19 | "bf16": torch.bfloat16, 20 | } 21 | 22 | 23 | def _add_profile_args(parser): 24 | group = parser.add_argument_group("profile") 25 | group.add_argument("--profile", action="store_true") 26 | group.add_argument("--profile_show", action="store_true") 27 | group.add_argument("--profile_save_profile", action="store_true") 28 | group.add_argument("--profile_warmup_steps", type=int, default=2) 29 | group.add_argument("--profile_active_steps", type=int, default=10) 30 | group.add_argument("--profile_wait_steps", type=int, default=1) 31 | group.add_argument("--profile_dir", type=str) 32 | group.add_argument("--profile_row_limit", type=int, default=5) 33 | group.add_argument("--profile_force_exp", type=str) 34 | return parser 35 | 36 | 37 | def _add_distributed_args(parser): 38 | group = parser.add_argument_group("distributed") 39 | group.add_argument("--tp_size", type=int) 40 | # TODO: Full Megatron-LM port. 41 | # parser.add_argument("--dp_size", type=int) 42 | # parser.add_argument("--pp_size", type=int) 43 | return parser 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--batch_size", type=int, default=4096) 49 | parser.add_argument("--dtypes", type=str, help="Comma-separated dtypes") 50 | parser.add_argument("--model_dim", type=int) 51 | parser.add_argument("--log_level", type=str, default="warning") 52 | parser.add_argument("--single_gpu_only", action="store_true") 53 | parser.add_argument("--multi_gpu_only", action="store_true") 54 | parser.add_argument("--n_layers", type=int, default=32) 55 | parser.add_argument("--debug", action="store_true") 56 | 57 | parser = _add_profile_args(parser) 58 | parser = _add_distributed_args(parser) 59 | 60 | args = parser.parse_args() 61 | 62 | args = validate_args(args) 63 | 64 | return args 65 | 66 | 67 | def validate_args(args): 68 | # Manual override for one dtype. 69 | if args.dtypes is not None and args.dtypes in DTYPES: 70 | dtypes = [] 71 | for d in args.dtypes.split(","): 72 | if d in DTYPES: 73 | dtypes.append(DTYPES[d]) 74 | else: 75 | raise Exception(f"Unsupported dtype provided: {d}") 76 | args.dtype_sweep = dtypes 77 | else: 78 | args.dtype_sweep = list(DTYPES.values()) 79 | 80 | # Manual override for one model dim. 
81 | if args.model_dim is not None: 82 | args.model_dim_sweep = [args.model_dim] 83 | else: 84 | args.model_dim_sweep = [2**i for i in range(9, 15)] 85 | 86 | # Debug overrides both dtype and model dim to test all configurations. 87 | if args.debug: 88 | args.dtype_sweep = [v for v in DTYPES.values()] 89 | args.model_dim_sweep = [16, 32] 90 | 91 | # Determine which experiments to run. 92 | assert not ( 93 | args.single_gpu_only and args.multi_gpu_only 94 | ), "Cannot have both single gpu and multi gpu only flags both True." 95 | 96 | # TP is across all gpus by default. 97 | if args.tp_size is None: 98 | args.tp_size = int(os.environ["WORLD_SIZE"]) 99 | 100 | # Make folder for profiling. 101 | if args.profile_save_profile: 102 | if args.profile_dir is None: 103 | args.profile_dir = ( 104 | f"{os.getcwd()}/profiles/profile_{os.environ['SLURM_JOB_ID']}" 105 | ) 106 | if not os.path.isdir(args.profile_dir): 107 | os.makedirs(args.profile_dir, exist_ok=True) 108 | 109 | # Select a subset of experiments to run. 110 | args.exp_prefix = "" 111 | if args.single_gpu_only: 112 | args.exp_prefix = "single_gpu" 113 | elif args.multi_gpu_only: 114 | args.exp_prefix = "multi_gpu" 115 | 116 | return args 117 | 118 | 119 | def print_args(args): 120 | pp = pprint.PrettyPrinter(width=80) 121 | pp.pprint(vars(args)) 122 | 123 | 124 | def main(args): 125 | setup_logger(args.log_level) 126 | 127 | # Silence kineto warnings. 128 | os.environ["KINETO_LOG_LEVEL"] = "5" 129 | 130 | # Initialize distributed and megatron-lm parallel state. 131 | init_megatron_dist(args) 132 | rank = torch.distributed.get_rank() 133 | torch.manual_seed(rank) 134 | if rank == 0: 135 | print_args(args) 136 | 137 | # Construct experiment setup functions and create profile folders. 138 | manager = ExperimentNetworkManager() 139 | experiments = manager.get_experiment_handles(args.exp_prefix) 140 | 141 | for exp in experiments: 142 | os.makedirs(f"{args.profile_dir}/{exp.__name__}", exist_ok=True) 143 | 144 | num_steps = ( 145 | args.profile_wait_steps 146 | * args.profile_warmup_steps 147 | * args.profile_active_steps 148 | ) 149 | 150 | # Profiler setup. 151 | schedule = torch.profiler.schedule( 152 | wait=args.profile_wait_steps, 153 | warmup=args.profile_warmup_steps, 154 | active=args.profile_active_steps, 155 | repeat=1, 156 | ) 157 | 158 | # Run benchmarks for each experiment, sweeping over parameters. 159 | for model_dim in args.model_dim_sweep: 160 | for dtype in args.dtype_sweep: 161 | # Setup inputs and spoof config. 162 | inputs = torch.randn( 163 | num_steps, args.batch_size, model_dim, dtype=dtype 164 | ).cuda() 165 | 166 | # Need to spoof Megatron-LM config so we can use the Col and Row 167 | # parallel layers. 168 | spoof_config = spoof_megatron_config(dtype) 169 | 170 | for exp in experiments: 171 | exp_name = exp.__name__ 172 | if ( 173 | args.profile_force_exp 174 | and exp_name != args.profile_force_exp 175 | ): 176 | continue 177 | 178 | if args.profile_save_profile: 179 | trace_handler = torch.profiler.tensorboard_trace_handler( 180 | f"{args.profile_dir}/{exp.__name__}" 181 | ) 182 | else: 183 | trace_handler = None 184 | 185 | # Setup network. 186 | network = exp( 187 | model_dim=model_dim, 188 | n_layers=args.n_layers, 189 | config=spoof_config, 190 | ).cuda() 191 | 192 | # Run benchmark. 
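                # The profiler follows `schedule`: each prof.step() advances
                # the wait -> warmup -> active window, and only the active
                # steps are recorded and handed to the trace handler (when
                # saving is enabled).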
193 | with torch.profiler.profile( 194 | schedule=schedule, 195 | activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 196 | record_shapes=True, 197 | with_flops=True, 198 | profile_memory=True, 199 | on_trace_ready=trace_handler, 200 | ) as prof: 201 | for i in range(num_steps): 202 | network(inputs[i]) 203 | prof.step() 204 | 205 | # Print profiles. 206 | if args.profile_show and rank == 0: 207 | sort_variables = [ 208 | "self_cpu_time_total", 209 | "self_cuda_time_total", 210 | "self_cpu_memory_usage", 211 | "self_cuda_memory_usage", 212 | ] 213 | key_avgs = prof.key_averages(group_by_input_shape=True) 214 | print("=" * 160) 215 | print(f"{exp_name}_{model_dim}_{dtype}") 216 | print(" -> ".join(sort_variables)) 217 | for sort_var in sort_variables: 218 | print( 219 | key_avgs.table( 220 | sort_by=sort_var, 221 | row_limit=args.profile_row_limit, 222 | ) 223 | ) 224 | 225 | # Cleanup. 226 | if isinstance(network, FlexModel): 227 | network.restore() 228 | else: 229 | manager.cleanup() 230 | 231 | 232 | if __name__ == "__main__": 233 | args = parse_args() 234 | main(args) 235 | -------------------------------------------------------------------------------- /profiling/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import os 4 | import random 5 | from collections import defaultdict 6 | from itertools import chain 7 | from typing import Callable, List 8 | 9 | import megatron.core.parallel_state as mpu 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from megatron.core.parallel_state import ( 14 | initialize_model_parallel as initialize_megatron_model_parallel, 15 | ) 16 | from megatron.core.tensor_parallel import ( 17 | ColumnParallelLinear, 18 | RowParallelLinear, 19 | model_parallel_cuda_manual_seed, 20 | ) 21 | 22 | from flex_model.core import FlexModel, HookFunction 23 | 24 | 25 | class TestNetwork(nn.Module): 26 | """Network for running profiling experiments against. 27 | 28 | Layers are either `nn.Linear` layers, or alternating `ColumnParallelLinear` 29 | and `RowParallelLinear` layers. ReLU activation functions are placed every 30 | two layers too. Will be deprecated in future releases in-favour of 31 | canonical Megatron-LM models. 32 | """ 33 | 34 | def __init__(self, model_dim, n_layers, is_distributed, config): 35 | super().__init__() 36 | self.model_dim = model_dim 37 | self.n_layers = n_layers 38 | self.is_distributed = is_distributed 39 | self.config = config 40 | 41 | assert self.n_layers % 2 == 0 42 | 43 | if self.is_distributed: 44 | # Alternating column parallel - row parallel layers. 45 | init_method = nn.init.xavier_normal_ 46 | layers = list( 47 | chain.from_iterable( 48 | ( 49 | ColumnParallelLinear( 50 | model_dim, 51 | model_dim, 52 | config=config, 53 | init_method=init_method, 54 | ), 55 | RowParallelLinear( 56 | model_dim, 57 | model_dim, 58 | input_is_parallel=True, 59 | config=config, 60 | init_method=init_method, 61 | ), 62 | ) 63 | for _ in range(self.n_layers // 2) 64 | ) 65 | ) 66 | else: 67 | layers = [ 68 | nn.Linear(model_dim, model_dim, dtype=self.config.params_dtype) 69 | for _ in range(self.n_layers) 70 | ] 71 | self.layers = nn.ModuleList(layers) 72 | 73 | self.act_fn = nn.ReLU() 74 | 75 | def forward(self, inputs): 76 | rep = inputs 77 | for i in range(len(self.layers)): 78 | rep = self.layers[i](rep) 79 | 80 | # Column and row parallel layers return tuple(output, output_bias). 
81 | if isinstance(rep, tuple): 82 | rep = rep[0] 83 | 84 | if i % 2 == 0: 85 | rep = self.act_fn(rep) 86 | 87 | return rep 88 | 89 | 90 | def hook_function_identity(self, inputs, outputs, name): 91 | """Hook function that does nothing.""" 92 | return outputs 93 | 94 | 95 | def hook_function_cpu(self, inputs, outputs, acc, name): 96 | """Hook function that dumps to cpu.""" 97 | rank = mpu.get_tensor_model_parallel_rank() 98 | 99 | # NOTE: See note about output of col and row parallel. 100 | _outputs = outputs[0] if isinstance(outputs, tuple) else outputs 101 | 102 | if rank == 0: 103 | acc[name] = _outputs.detach().cpu() 104 | return outputs 105 | 106 | 107 | def hook_function_gpu(self, inputs, outputs, acc, name): 108 | """Hook function that dumps to GPU.""" 109 | rank = mpu.get_tensor_model_parallel_rank() 110 | 111 | _outputs = outputs[0] if isinstance(outputs, tuple) else outputs 112 | 113 | if rank == 0: 114 | acc[name] = _outputs.detach() 115 | 116 | return outputs 117 | 118 | 119 | def hook_function_cpu_gather_scatter(self, inputs, outputs, acc, name): 120 | """Hardcoded minimal hook function with gather/scatter.""" 121 | rank = mpu.get_tensor_model_parallel_rank() 122 | 123 | _outputs = outputs[0] if isinstance(outputs, tuple) else outputs 124 | 125 | def all_gather(tensor, dim=-1): 126 | output_list = [ 127 | torch.empty_like(tensor) 128 | for _ in range(mpu.get_tensor_model_parallel_world_size()) 129 | ] 130 | torch.distributed.all_gather( 131 | output_list, 132 | tensor, 133 | group=mpu.get_tensor_model_parallel_group(), 134 | ) 135 | return torch.cat(output_list, dim=dim) 136 | 137 | def scatter(tensor, dim=-1): 138 | scatter_list = list( 139 | torch.chunk( 140 | tensor, mpu.get_tensor_model_parallel_world_size(), dim=dim 141 | ) 142 | ) 143 | return scatter_list[rank] 144 | 145 | # Determine correct gather/scatter functions. 146 | def _no_op(x): 147 | return x 148 | 149 | if isinstance(self, ColumnParallelLinear): 150 | gather_fn = all_gather 151 | scatter_fn = scatter 152 | 153 | elif isinstance(self, RowParallelLinear): 154 | gather_fn = _no_op 155 | scatter_fn = _no_op 156 | 157 | elif isinstance(self, nn.Linear): 158 | gather_fn = _no_op 159 | scatter_fn = _no_op 160 | 161 | else: 162 | raise NotImplementedError 163 | 164 | _outputs = gather_fn(_outputs) 165 | 166 | if rank == 0: 167 | acc[name] = _outputs.detach().cpu() 168 | 169 | _outputs = scatter_fn(_outputs) 170 | 171 | outputs = (_outputs, *outputs[1:]) 172 | 173 | return outputs 174 | 175 | 176 | class ExperimentNetworkManager: 177 | """Contains functions for creating the experiment networks. 178 | 179 | Also tries to cache networks to reduce initialization latency. 
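
    :note: Experiment constructors are discovered by method name: any method
        whose name starts with :code:`multi_gpu_` or :code:`single_gpu_` is
        collected into :code:`named_experiments`, and a subset can be selected
        via :code:`get_experiment_handles`, e.g.
        :code:`manager.get_experiment_handles("multi_gpu")`.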
180 | """ 181 | 182 | def __init__(self): 183 | self.cpu_acc = {} 184 | self.experiment_prefixes = ["multi_gpu_", "single_gpu_"] 185 | self.named_experiments = list( 186 | filter( 187 | lambda attr: any( 188 | attr.startswith(prefix) 189 | for prefix in self.experiment_prefixes 190 | ), 191 | dir(self), 192 | ) 193 | ) 194 | 195 | self.hook_handles = defaultdict(list) 196 | self.network_cache = None 197 | 198 | def get_experiment_handles(self, prefix: str) -> List[Callable]: 199 | experiments = list( 200 | filter( 201 | lambda name: name.startswith(prefix), 202 | self.named_experiments, 203 | ) 204 | ) 205 | experiment_handles = [getattr(self, e) for e in experiments] 206 | return experiment_handles 207 | 208 | def cleanup(self): 209 | self.cpu_acc.clear() 210 | 211 | self.remove_hooks(self.network_cache[-1]) 212 | 213 | for m in self.network_cache[-1].modules(): 214 | assert len(m._forward_hooks) == 0 215 | 216 | def remove_hooks(self, network): 217 | handles = self.hook_handles.get(network, []) 218 | for handle in handles: 219 | handle.remove() 220 | 221 | def _check_network_cache(self, *args, **kwargs): 222 | if self.network_cache is None: 223 | return False 224 | 225 | cached_args, cached_kwargs = ( 226 | self.network_cache[:-2], 227 | self.network_cache[-2], 228 | ) 229 | for arg, c_arg in zip(args, cached_args): 230 | if arg != c_arg: 231 | return False 232 | 233 | if cached_kwargs != kwargs: 234 | return False 235 | 236 | return True 237 | 238 | def make_network(self, *args, **kwargs): 239 | if self._check_network_cache(*args, **kwargs): 240 | network = self.network_cache[-1] 241 | else: 242 | network = TestNetwork(*args, **kwargs) 243 | self.network_cache = [*args, kwargs, network] 244 | return network 245 | 246 | def _hook_every_layer(self, network, hook_fn): 247 | module_names_to_hook = set( 248 | f"layers.{i}" for i in range(len(network.layers)) 249 | ) 250 | for n, m in network.named_modules(): 251 | if n in module_names_to_hook: 252 | hook_fn = functools.partial(hook_fn, name=n) 253 | handle = m.register_forward_hook(hook_fn) 254 | self.hook_handles[network].append(handle) 255 | module_names_to_hook.remove(n) 256 | 257 | assert ( 258 | len(module_names_to_hook) == 0 259 | ), f"Have left over modules to hook: {module_names_to_hook}" 260 | 261 | def single_gpu_no_hooks(self, model_dim, n_layers, config): 262 | network = self.make_network( 263 | model_dim, n_layers, is_distributed=False, config=config 264 | ) 265 | return network 266 | 267 | def single_gpu_unity_hooks(self, model_dim, n_layers, config): 268 | network = self.make_network( 269 | model_dim, n_layers, is_distributed=False, config=config 270 | ) 271 | 272 | hook_fn = hook_function_identity 273 | 274 | self._hook_every_layer(network, hook_fn) 275 | 276 | return network 277 | 278 | def single_gpu_cpu_hooks(self, model_dim, n_layers, config): 279 | network = self.make_network( 280 | model_dim, n_layers, is_distributed=False, config=config 281 | ) 282 | 283 | hook_fn = functools.partial(hook_function_cpu, acc=self.cpu_acc) 284 | 285 | self._hook_every_layer(network, hook_fn) 286 | 287 | return network 288 | 289 | def multi_gpu_no_hooks(self, model_dim, n_layers, config): 290 | network = self.make_network( 291 | model_dim, n_layers, is_distributed=True, config=config 292 | ) 293 | return network 294 | 295 | def multi_gpu_unity_hooks(self, model_dim, n_layers, config): 296 | network = self.make_network( 297 | model_dim, n_layers, is_distributed=True, config=config 298 | ) 299 | 300 | hook_fn = hook_function_identity 301 | 
302 | self._hook_every_layer(network, hook_fn) 303 | 304 | return network 305 | 306 | def multi_gpu_cpu_hooks(self, model_dim, n_layers, config): 307 | network = self.make_network( 308 | model_dim, n_layers, is_distributed=True, config=config 309 | ) 310 | 311 | # NOTE: This will not accumulate full tensors since we don't gather. 312 | hook_fn = functools.partial(hook_function_cpu, acc=self.cpu_acc) 313 | 314 | self._hook_every_layer(network, hook_fn) 315 | 316 | return network 317 | 318 | def multi_gpu_gpu_hooks(self, model_dim, n_layers, config): 319 | network = self.make_network( 320 | model_dim, n_layers, is_distributed=True, config=config 321 | ) 322 | 323 | # NOTE: This will not accumulate full tensors since we don't gather. 324 | hook_fn = functools.partial(hook_function_gpu, acc=self.cpu_acc) 325 | 326 | self._hook_every_layer(network, hook_fn) 327 | 328 | return network 329 | 330 | def multi_gpu_cpu_hooks_with_gather_scatter( 331 | self, model_dim, n_layers, config 332 | ): 333 | network = self.make_network( 334 | model_dim, n_layers, is_distributed=True, config=config 335 | ) 336 | hook_fn = functools.partial( 337 | hook_function_cpu_gather_scatter, acc=self.cpu_acc 338 | ) 339 | 340 | self._hook_every_layer(network, hook_fn) 341 | 342 | return network 343 | 344 | def multi_gpu_flex_model(self, model_dim, n_layers, config): 345 | base_network = self.make_network( 346 | model_dim, n_layers, is_distributed=True, config=config 347 | ) 348 | 349 | network = FlexModel( 350 | base_network, 351 | self.cpu_acc, 352 | tensor_parallel_size=mpu.get_tensor_model_parallel_world_size(), 353 | pipeline_parallel_size=mpu.get_pipeline_model_parallel_world_size(), 354 | data_parallel_size=mpu.get_data_parallel_world_size(), 355 | ) 356 | 357 | module_names_to_hook = set( 358 | f"layers.{i}" for i in range(len(base_network.layers)) 359 | ) 360 | for n in module_names_to_hook: 361 | network.register_forward_hook( 362 | HookFunction( 363 | module_name=n, 364 | expected_shape=(None, model_dim), 365 | ) 366 | ) 367 | return network 368 | 369 | 370 | def init_megatron_dist(args): 371 | """Initialize Megatron-LM parallel state.""" 372 | os.environ["NCCL_IB_DISABLE"] = "1" 373 | initialize_torch_distributed() 374 | initialize_megatron_model_parallel(args.tp_size) 375 | 376 | # Taken from: https://github.com/NVIDIA/Megatron-LM/blob/feac76a79148622d8f2a45d46c08a972a24784a3/megatron/initialize.py#L236 377 | seed = 0 378 | seed += 100 * mpu.get_pipeline_model_parallel_rank() 379 | random.seed(seed) 380 | np.random.seed(seed) 381 | torch.manual_seed(seed) 382 | if torch.cuda.device_count() > 0: 383 | model_parallel_cuda_manual_seed(0) 384 | 385 | 386 | def initialize_torch_distributed(): 387 | """Initialize torch distributed state.""" 388 | device = int(os.environ["LOCAL_RANK"]) 389 | torch.cuda.set_device(device) 390 | 391 | torch.distributed.init_process_group(backend="nccl") 392 | 393 | 394 | def spoof_megatron_config(dtype): 395 | """Spoof the megatron config to initialize megatron core layers.""" 396 | config = argparse.Namespace() 397 | config.perform_initialization = True 398 | config.params_dtype = dtype 399 | config.async_tensor_model_parallel_allreduce = False 400 | config.sequence_parallel = False 401 | config.gradient_accumulation_fusion = False 402 | config.expert_model_parallel_size = 1 403 | config.use_cpu_initialization = False 404 | 405 | return config 406 | -------------------------------------------------------------------------------- /pyproject.toml: 
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements/dev_requirements.txt:
--------------------------------------------------------------------------------
1 | coverage
2 | pre-commit
3 | ruff
4 | sphinx-rtd-theme
--------------------------------------------------------------------------------
/requirements/examples_requirements.txt:
--------------------------------------------------------------------------------
1 | fairscale
2 | sentencepiece
3 | tokenizers
4 | transformers
5 | matplotlib
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.1
--------------------------------------------------------------------------------
/requirements/test_requirements.txt:
--------------------------------------------------------------------------------
1 | fairscale
2 | pytest
3 | sentencepiece
4 | submitit
5 | tokenizers
6 | transformers
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 | import importlib.util
4 | from setuptools import find_packages, setup
5 | 
6 | 
7 | spec = importlib.util.spec_from_file_location(
8 |     "package_info", "flex_model/package_info.py"
9 | )
10 | package_info = importlib.util.module_from_spec(spec)
11 | spec.loader.exec_module(package_info)
12 | 
13 | 
14 | __contact_emails__ = package_info.__contact_emails__
15 | __contact_names__ = package_info.__contact_names__
16 | __description__ = package_info.__description__
17 | __homepage__ = package_info.__homepage__
18 | __license__ = package_info.__license__
19 | __package_name__ = package_info.__package_name__
20 | __repository_url__ = package_info.__repository_url__
21 | __version__ = package_info.__version__
22 | 
23 | 
24 | def req_file(filename, folder="requirements"):
25 |     with open(os.path.join(folder, filename), encoding="utf-8") as f:
26 |         content = f.readlines()
27 | 
28 |     return [x.strip() for x in content]
29 | 
30 | 
31 | install_requires = req_file("requirements.txt")
32 | extras_require = {
33 |     "core": req_file("requirements.txt"),
34 |     "test": req_file("test_requirements.txt"),
35 |     "dev": req_file("dev_requirements.txt"),
36 |     "examples": req_file("examples_requirements.txt"),
37 | }
38 | 
39 | 
40 | extras_require["all"] = list(chain(*extras_require.values()))
41 | extras_require["test"] = list(
42 |     chain(extras_require["core"], extras_require["test"])
43 | )
44 | extras_require["dev"] = list(
45 |     chain(
46 |         extras_require["core"], extras_require["test"], extras_require["dev"]
47 |     )
48 | )
49 | extras_require["examples"] = list(
50 |     chain(extras_require["core"], extras_require["examples"])
51 | )
52 | 
53 | setup(
54 |     name=__package_name__,
55 |     version=__version__,
56 |     description=__description__,
57 |     url=__repository_url__,
58 |     author=__contact_names__,
59 |     author_email=__contact_emails__,
60 |     maintainer=__contact_names__,
61 |     maintainer_email=__contact_emails__,
62 |     license=__license__,
63 |     packages=find_packages(),
64 |     install_requires=install_requires,
65 |     extras_require=extras_require,
66 | )
67 | 
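# Example (illustrative; not used by the build): the extras above resolve to
# installs from a source checkout such as
#   pip install .            # core requirements only
#   pip install .[test]      # core + test dependencies
#   pip install .[dev]       # core + test + dev tooling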
--------------------------------------------------------------------------------