├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── assets └── arch.svg ├── benchmark ├── README.md ├── bench_case.py ├── jax │ ├── bench_jax.py │ ├── bench_jax_alpa.py │ ├── bench_jax_dp.py │ └── model │ │ ├── __init__.py │ │ ├── gat.py │ │ ├── gpt.py │ │ ├── resnet.py │ │ └── wresnet.py └── torch │ ├── bench_torch.py │ ├── bench_torch_tp.py │ ├── model │ ├── __init__.py │ ├── gat.py │ ├── gpt.py │ ├── gpt_tp.py │ └── wresnet.py │ └── pp │ ├── gpt │ └── speed │ │ ├── batch.sh │ │ ├── easydist_pipeline.py │ │ ├── torchgpipe_pipeline.py │ │ └── vanilla_torch.py │ └── resnet101 │ ├── accuracy │ ├── pipeline.py │ └── vanilla_torch.py │ └── speed │ ├── batch.sh │ ├── easydist_pipeline.py │ ├── resnet │ ├── __init__.py │ ├── bottleneck.py │ └── flatten_sequential.py │ ├── torchgpipe_pipeline.py │ └── vanila_torch.py ├── easydist ├── __init__.py ├── autoflow │ ├── __init__.py │ └── solver.py ├── config.py ├── jax │ ├── __init__.py │ ├── api.py │ ├── bridge.py │ ├── device_mesh.py │ ├── sharding_interpreter.py │ └── utils.py ├── metashard │ ├── __init__.py │ ├── annotation.py │ ├── combination.py │ ├── halo.py │ ├── metair.py │ ├── metaop.py │ └── view_propagation.py ├── platform │ ├── __init__.py │ ├── jax.py │ ├── torch.py │ └── tvm.py ├── torch │ ├── __init__.py │ ├── api.py │ ├── bridge.py │ ├── compile.py │ ├── compile_auto.py │ ├── compile_dp.py │ ├── cuda │ │ ├── __init__.py │ │ ├── mem_allocator.py │ │ └── scheduled_graph_drawer.py │ ├── decomp_utils.py │ ├── device_mesh.py │ ├── experimental │ │ ├── __init__.py │ │ └── pp │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── compile_pipeline.py │ │ │ ├── ed_split_module.py │ │ │ ├── microbatch.py │ │ │ ├── runtime.py │ │ │ ├── split_utils.py │ │ │ └── utils.py │ ├── graph_profile_db.py │ ├── init_helper.py │ ├── mem_allocation_info.py │ ├── mem_anaylize.py │ ├── meta_allocator.py │ ├── passes │ │ ├── __init__.py │ │ ├── allocator_profiler.py │ │ ├── comm_optimize.py │ │ ├── edinfo_utils.py │ │ ├── eliminate_detach.py │ │ ├── fix_bias.py │ │ ├── fix_embedding.py │ │ ├── fix_meta_device.py │ │ ├── fix_view.py │ │ ├── pp_passes.py │ │ ├── process_tag.py │ │ ├── rule_override.py │ │ ├── runtime_prof.py │ │ ├── sharding.py │ │ └── tile_comm.py │ ├── preset_propagation.py │ ├── profiler │ │ ├── __init__.py │ │ ├── csrc │ │ │ ├── cupti_callback_api.cpp │ │ │ ├── cupti_callback_api.h │ │ │ ├── effective_cuda_allocator.cpp │ │ │ ├── effective_cuda_allocator.h │ │ │ ├── profiling_allocator.cpp │ │ │ ├── profiling_allocator.h │ │ │ ├── python_tracer_init.cpp │ │ │ ├── stream_tracer.cpp │ │ │ └── stream_tracer.h │ │ └── stream_tracer.py │ ├── reachability.py │ ├── schedule │ │ ├── __init__.py │ │ ├── efficient_memory_scheduler.py │ │ ├── graph_mem_plan.py │ │ ├── ilp_memory_scheduler.py │ │ ├── joint_learning.py │ │ ├── lifetime_info.py │ │ ├── memory_scheduler.py │ │ ├── rcpsp.py │ │ └── schedule_result.py │ ├── scope_auto │ │ ├── build_scope_modules.py │ │ └── scope_marker.py │ ├── sharding_interpreter.py │ ├── split_utils.py │ ├── spmd_prop_rule.py │ ├── symphonia │ │ ├── __init__.py │ │ └── torch_actor.py │ ├── tensorfield │ │ ├── README.md │ │ ├── __init__.py │ │ ├── csrc │ │ │ └── allocator_interface.cpp │ │ ├── helper.py │ │ ├── interface.py │ │ ├── mem_pool.py │ │ └── server.py │ └── utils.py ├── utils │ ├── __init__.py │ ├── testing │ │ ├── __init__.py │ │ ├── mock.py │ │ └── spawn.py │ └── timer.py └── version.py ├── examples ├── README.md ├── jax │ ├── 
simple_function.py │ └── simple_model.py ├── torch │ ├── bert_train.py │ ├── cifar10.py │ ├── gnn │ │ ├── data.py │ │ ├── gat.py │ │ └── train.py │ ├── gpt_train.py │ ├── resnet18.py │ ├── resnet_train.py │ ├── simple_ddp.py │ ├── simple_function.py │ ├── simple_model.py │ ├── stable_diffusion.py │ ├── tensorfeild │ │ ├── cifar10.py │ │ ├── cifar10_ray.py │ │ ├── matmul.py │ │ ├── param_group_ray.py │ │ └── param_group_test.py │ └── test_dynamo_export.py └── tvm │ └── test_simple.py ├── pytest.ini ├── requirements └── core-requirements.txt ├── setup.py ├── style.cfg └── tests ├── test_combination ├── test_gather.py ├── test_help_func.py ├── test_identity.py ├── test_reduce.py └── test_try_combination_single.py ├── test_scope_auto ├── mark_scope1.py ├── mark_scope2.py └── mult_mesh1.py ├── test_strategy ├── jax │ ├── simple_function1.py │ └── test_simple_function1.sh └── torch │ └── test_simple_model.py ├── test_torch ├── test_hybrid.py ├── test_pp │ ├── test_reslink.py │ ├── test_runtime.py │ └── test_split.py ├── test_simple.py ├── test_spmd.py └── test_utils.py └── test_unfiyshard ├── test_unifyop.py └── test_view_propagation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # log file of MetaDist 2 | *.metair 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # vscode 135 | .vscode/ 136 | 137 | # setup 138 | dist/ 139 | build/ 140 | 141 | # temporary files 142 | md_compiled 143 | tmp 144 | 145 | # graphviz output files 146 | *.dot 147 | 148 | # cached datasets 149 | data 150 | 151 | # pt files 152 | *.pt 153 | 154 | # txt files 155 | *.txt 156 | 157 | # svg files 158 | *.svg 159 | 160 | # log folder 161 | log -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/PyCQA/autoflake 4 | rev: v2.3.1 5 | hooks: 6 | - id: autoflake 7 | name: remove unused variables and imports 8 | args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] 9 | 10 | - repo: https://github.com/pre-commit/mirrors-yapf 11 | rev: v0.32.0 12 | hooks: 13 | - id: yapf 14 | name: format python code 15 | entry: yapf 16 | language: python 17 | files: \.py$ 18 | types: [python] 19 | args: ["--style=style.cfg"] 20 | 21 | - repo: https://github.com/pycqa/isort 22 | rev: 5.13.2 23 | hooks: 24 | - id: isort 25 | name: sort python imports 26 | args: ["--profile", "black"] 27 | 28 | - repo: https://github.com/pre-commit/pre-commit-hooks 29 | rev: v4.6.0 30 | hooks: 31 | - id: check-yaml 32 | - id: check-merge-conflict 33 | - id: check-case-conflict 34 | - id: trailing-whitespace 35 | - id: end-of-file-fixer 36 | - id: mixed-line-ending 37 | args: ['--fix=lf'] 38 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Alibaba has adopted a Code of Conduct that we expect project participants to adhere to. 4 | 5 | Please refer to [Alibaba Open Source Code of Conduct](https://github.com/AlibabaDR/community/blob/master/CODE_OF_CONDUCT.md). 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing guidelines 2 | 3 | Contributors are welcome to submit their code and ideas. In the long run, we hope this project can be managed by developers from both inside and outside Alibaba. 4 | 5 | ### Contributor License Agreements 6 | 7 | * Sign the EasyDist CLA: 8 | Please download the EasyDist [CLA](https://gist.github.com/alibaba-oss/151a13b0a72e44ba471119c7eb737d74) and follow the instructions to sign it. 9 | 10 | ### Pull Request Checklist 11 | 12 | Here is a checklist to prepare and submit your PR (pull request). 13 | 14 | * Create your own GitHub branch by forking EasyDist. 15 | * Read the [README](README.md). 16 | * Read the [contributing guidelines](CONTRIBUTING.md). 17 | * Read the [Code of Conduct](CODE_OF_CONDUCT.md). 
18 | * Ensure you have signed the 19 | [Contributor License Agreement (CLA)](https://gist.github.com/alibaba-oss/151a13b0a72e44ba471119c7eb737d74). 20 | * Push changes to your personal fork. 21 | * Create a PR with a detailed description if the commit messages do not speak for themselves. 22 | * Submit the PR for review and address all feedback. 23 | * Wait for merging (done by committers). 24 | 25 | Let's use an example to walk through the list. 26 | 27 | ## An Example of Submitting a Code Change to EasyDist 28 | 29 | ### Fork Your Own Branch 30 | 31 | On the GitHub page of [EasyDist](https://github.com/alibaba/easydist), click the **Fork** button to create your own easydist repository. 32 | 33 | ### Create Local Repository 34 | ```bash 35 | git clone --recursive https://github.com/your_github/easydist.git 36 | ``` 37 | ### Create a dev Branch (named your_github_id_feature_name) 38 | ```bash 39 | git branch your_github_id_feature_name 40 | ``` 41 | ### Make Changes and Commit Locally 42 | ```bash 43 | git status 44 | git add files-to-change 45 | git commit -m "messages for your modifications" 46 | ``` 47 | 48 | ### Rebase and Commit to Remote Repository 49 | ```bash 50 | git checkout main 51 | git pull 52 | git checkout your_github_id_feature_name 53 | git rebase main 54 | -- resolve conflicts, run tests -- 55 | git push --recurse-submodules=on-demand origin your_github_id_feature_name 56 | ``` 57 | 58 | ### Create a PR 59 | Click the **New pull request** or **Compare & pull request** button, choose to compare the branches easydist/main and your_github/your_github_id_feature_name, and write the PR description. 60 | 61 | ### Address Reviewers' Comments 62 | Resolve all problems raised by reviewers and update the PR. 63 | 64 | ### Merge 65 | Merging is done by EasyDist committers. 66 | ___ 67 | 68 | Copyright © Alibaba Group, Inc. 69 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | recursive-include easydist *.cpp *.h 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EasyDist 2 | 3 | EasyDist is an automated parallelization system and infrastructure designed for multiple ecosystems, offering the following key features: 4 | 5 | - **Usability**. With EasyDist, parallelizing your training or inference code to a larger scale becomes effortless with just a single-line change. 6 | 7 | - **Ecological Compatibility**. EasyDist serves as a centralized source of truth for SPMD rules at the operator level for various machine learning frameworks. It currently supports PyTorch and Jax natively, as well as the TVM Tensor Expression operator, for SPMD rules. 8 | 9 | - **Infrastructure**. EasyDist decouples auto-parallel algorithms from specific machine learning frameworks and IRs. This design choice allows different auto-parallel algorithms to be developed and benchmarked more flexibly, leveraging the capabilities and abstractions provided by EasyDist. 10 | 11 | ## One Line of Code for Parallelism 12 | 13 | To parallelize your training loop using EasyDist, you can use the `easydist_compile` decorator. 
Here's an example of how it can be used with PyTorch: 14 | 15 | ```python 16 | @easydist_compile() 17 | def train_step(net, optimizer, inputs, labels): 18 | 19 | outputs = net(inputs) 20 | loss = nn.CrossEntropyLoss()(outputs, labels) 21 | loss.backward() 22 | 23 | optimizer.step() 24 | optimizer.zero_grad() 25 | 26 | return loss 27 | ``` 28 | 29 | This one-line decorator parallelizes the training step. You can find more examples in the [`./examples/`](./examples/) directory. 30 | 31 | ## Overview 32 | 33 | EasyDist introduces the concept of MetaOp and MetaIR to decouple automatic parallelization methods from specific intermediate representations (IR) and frameworks. Additionally, it presents the ShardCombine Algorithm, which defines operator Single-Program, Multiple-Data (SPMD) sharding rules without requiring manual annotations. The architecture of EasyDist is as follows: 34 | 35 |
36 | ![EasyDist Architecture](./assets/arch.svg) 37 | 
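To show how the decorator above fits into a complete script, here is a minimal, hypothetical sketch. The toy model, data shapes, and the `from easydist.torch import easydist_compile` import path are illustrative assumptions rather than code taken from this repository; the `easydist_setup` call and the NCCL/`torchrun` setup mirror the scripts under `benchmark/torch/`.

```python
# Minimal sketch (not from the repository): drives the decorated train_step
# shown above. The import path of easydist_compile and the toy model/data are
# assumptions; the setup calls follow benchmark/torch/bench_torch_tp.py.
import os

import torch
import torch.nn as nn

from easydist import easydist_setup
from easydist.torch import easydist_compile  # assumed import path

# One-time setup, as in the benchmark scripts.
easydist_setup(backend="torch", device="cuda")
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


@easydist_compile()
def train_step(net, optimizer, inputs, labels):
    outputs = net(inputs)
    loss = nn.CrossEntropyLoss()(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


# Toy model and data, purely for illustration.
net = nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
inputs = torch.randn(64, 128).cuda()
labels = torch.randint(0, 10, (64,)).cuda()

loss = train_step(net, optimizer, inputs, labels)
```

As with the benchmarks, such a script would be launched once per GPU, for example `torchrun --nproc_per_node 2 train_sketch.py`, where `train_sketch.py` is a hypothetical file name.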
38 | 39 | ## Installation 40 | 41 | To install EasyDist, you can use pip and install from PyPI: 42 | 43 | ```shell 44 | # For PyTorch users 45 | pip install pai-easydist[torch] 46 | 47 | # For Jax users 48 | pip install pai-easydist[jax] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 49 | ``` 50 | 51 | If you prefer to install EasyDist from source, you can clone the GitHub repository and then install it with the appropriate extras: 52 | 53 | ```shell 54 | git clone https://github.com/alibaba/easydist.git && cd easydist 55 | 56 | # EasyDist with PyTorch installation 57 | pip install -e '.[torch]' 58 | 59 | # EasyDist with Jax installation 60 | pip install -e '.[jax]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 61 | ``` 62 | 63 | ## Contributing 64 | 65 | See CONTRIBUTING.md for details. 66 | 67 | ## Contributors 68 | 69 | EasyDist is developed by Alibaba Group and NUS HPC-AI Lab. This work is supported by [Alibaba Innovative Research(AIR)](https://damo.alibaba.com/air/). 70 | 71 | ## License 72 | 73 | EasyDist is licensed under the Apache License (Version 2.0). See LICENSE file. 74 | This product contains some third-party testcases under other open source licenses. 75 | See the NOTICE file for more information. 76 | 77 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | ## Benchmark 2 | 3 | ```shell 4 | torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/torch/bench_torch.py 5 | torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/torch/bench_torch_tp.py 6 | ``` 7 | -------------------------------------------------------------------------------- /benchmark/bench_case.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class GPTCase: 6 | batch_size: int = 4 7 | seq_size: int = 1024 8 | num_layers: int = 1 9 | hidden_dim: int = 12288 10 | num_heads: int = 48 11 | dropout_rate: float = 0.0 12 | use_bias: bool = True 13 | dtype = "float32" 14 | 15 | 16 | @dataclass 17 | class ResNetCase: 18 | batch_size: int = 128 19 | 20 | 21 | @dataclass 22 | class GATCase: 23 | num_node: int = 4096 24 | in_feature: int = 12288 25 | out_feature: int = 12288 26 | -------------------------------------------------------------------------------- /benchmark/jax/bench_jax_alpa.py: -------------------------------------------------------------------------------- 1 | # python ./benchmark/bench_jax_alpa.py 2 | 3 | import logging 4 | import os 5 | import sys 6 | 7 | import alpa 8 | import jax 9 | 10 | os.environ["EASYDIST_DEVICE"] = "cuda" 11 | os.environ["EASYDIST_BACKEND"] = "jax" 12 | 13 | from easydist.utils.timer import EDTimer 14 | 15 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 16 | from benchmark.jax.model.gpt import GPTSimple 17 | from benchmark.jax.model.wresnet import resnet18 18 | from benchmark.jax.model.gat import GATLayer 19 | from benchmark.bench_case import GPTCase, ResNetCase, GATCase 20 | 21 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 22 | datefmt='%m/%d %H:%M:%S', 23 | level=logging.DEBUG) 24 | 25 | 26 | def get_gpt_case(): 27 | case = GPTCase() 28 | model = GPTSimple(case) 29 | 30 | root_key = jax.random.PRNGKey(seed=0) 31 | main_key, params_key = jax.random.split(key=root_key) 32 | input_ = jax.random.normal( 33 | main_key, (case.batch_size, case.seq_size, 
case.hidden_dim)) # Dummy input data 34 | variables = model.init(params_key, input_, deterministic=True) 35 | params = variables['params'] 36 | 37 | # DataParallel() Zero3Parallel() Zero2Parallel() 38 | @alpa.parallelize(method=alpa.ShardParallel()) 39 | def train_step(params, input_): 40 | lr = 0.0001 41 | 42 | def loss_fn(params): 43 | dropout_key = jax.random.PRNGKey(seed=0) 44 | return model.apply({ 45 | 'params': params 46 | }, 47 | input_, 48 | deterministic=False, 49 | rngs={ 50 | 'dropout': dropout_key 51 | }).mean() 52 | 53 | grad_fn = jax.grad(loss_fn) 54 | grads = grad_fn(params) 55 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads) 56 | return params 57 | 58 | return train_step, [params, input_] 59 | 60 | 61 | def get_resnet_case(): 62 | case = ResNetCase() 63 | model = resnet18() 64 | 65 | key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2) 66 | input_ = jax.random.normal(key1, (case.batch_size, 224, 224, 3)) # Dummy input data 67 | variables = model.init(key2, input_) # Initialization call 68 | params, batch_stats = variables['params'], variables['batch_stats'] 69 | 70 | @alpa.parallelize(method=alpa.ShardParallel()) 71 | def train_step(params, batch_stats, input_): 72 | lr = 0.0001 73 | 74 | def loss_fn(params, batch_stats): 75 | out_, batch_stats = model.apply({ 76 | 'params': params, 77 | 'batch_stats': batch_stats 78 | }, 79 | input_, 80 | mutable=['batch_stats']) 81 | return out_.mean() 82 | 83 | grad_fn = jax.grad(loss_fn) 84 | grads = grad_fn(params, batch_stats) 85 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads) 86 | return params 87 | 88 | return train_step, [params, batch_stats, input_] 89 | 90 | 91 | def get_gat_case(): 92 | 93 | case = GATCase() 94 | model = GATLayer(case.in_feature, case.out_feature) 95 | 96 | key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), num=3) 97 | h = jax.random.normal(key1, (case.num_node, case.in_feature)) # Dummy input data 98 | adj = jax.random.normal(key2, (case.num_node, case.num_node)) # Dummy input data 99 | variables = model.init(key3, h, adj) # Initialization call 100 | params = variables['params'] 101 | 102 | @alpa.parallelize(method=alpa.ShardParallel()) 103 | def train_step(params, h, adj): 104 | lr = 0.0001 105 | 106 | def loss_fn(params): 107 | return model.apply({'params': params}, h, adj).mean() 108 | 109 | grad_fn = jax.grad(loss_fn) 110 | grads = grad_fn(params) 111 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads) 112 | return params 113 | 114 | return train_step, [params, h, adj] 115 | 116 | 117 | def bench_alpa(func, args): 118 | 119 | def train_step(): 120 | func(*args) 121 | 122 | timer = EDTimer(train_step, in_ms=False) 123 | 124 | elaps_time = timer.time() 125 | 126 | print(f"Time: {elaps_time}") 127 | 128 | 129 | def main(): 130 | os.environ["NVIDIA_TF32_OVERRIDE"] = "0" 131 | print(jax.devices()) 132 | 133 | func, args = get_gat_case() 134 | 135 | bench_alpa(func, args) 136 | 137 | 138 | if __name__ == '__main__': 139 | main() 140 | -------------------------------------------------------------------------------- /benchmark/jax/bench_jax_dp.py: -------------------------------------------------------------------------------- 1 | # python ./benchmark/bench_jax_dp.py 2 | 3 | import os 4 | import sys 5 | import logging 6 | from functools import partial 7 | 8 | import jax 9 | 10 | from easydist import easydist_setup 11 | from easydist.utils.timer import EDTimer 12 | 13 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 14 | from 
benchmark.bench_case import GPTCase, ResNetCase 15 | from benchmark.jax.model.gpt import GPTSimple 16 | from benchmark.jax.model.wresnet import resnet18 17 | 18 | 19 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 20 | datefmt='%m/%d %H:%M:%S', 21 | level=logging.INFO) 22 | 23 | 24 | def get_gpt_case(): 25 | case = GPTCase() 26 | model = GPTSimple(case) 27 | 28 | root_key = jax.random.PRNGKey(seed=0) 29 | main_key, params_key = jax.random.split(key=root_key) 30 | input_ = jax.random.normal( 31 | main_key, (case.batch_size, case.seq_size, case.hidden_dim)) # Dummy input data 32 | variables = model.init(params_key, input_, deterministic=True) 33 | params = variables['params'] 34 | 35 | @partial(jax.pmap, axis_name="batch") 36 | def train_step(params, input_): 37 | lr = 0.0001 38 | 39 | def loss_fn(params): 40 | dropout_key = jax.random.PRNGKey(seed=0) 41 | return model.apply({ 42 | 'params': params 43 | }, 44 | input_, 45 | deterministic=False, 46 | rngs={ 47 | 'dropout': dropout_key 48 | }).mean() 49 | 50 | grad_fn = jax.grad(loss_fn) 51 | grads = grad_fn(params) 52 | grads = jax.lax.pmean(grads, axis_name="batch") 53 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads) 54 | return params 55 | 56 | devices = jax.local_devices() 57 | params = jax.device_put_replicated(params, devices) 58 | 59 | def shard_batch(x): 60 | x = x.reshape((len(devices), -1) + x.shape[1:]) 61 | return jax.device_put_sharded(list(x), devices) 62 | 63 | input_ = jax.tree_map(shard_batch, input_) 64 | 65 | return train_step, [params, input_] 66 | 67 | 68 | def get_resnet_case(): 69 | case = ResNetCase() 70 | model = resnet18() 71 | 72 | key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2) 73 | input_ = jax.random.normal(key1, (case.batch_size, 224, 224, 3)) # Dummy input data 74 | variables = model.init(key2, input_) # Initialization call 75 | params, batch_stats = variables['params'], variables['batch_stats'] 76 | 77 | @partial(jax.pmap, axis_name="batch") 78 | def train_step(params, batch_stats, input_): 79 | lr = 0.0001 80 | 81 | def loss_fn(params, batch_stats): 82 | out_, batch_stats = model.apply({ 83 | 'params': params, 84 | 'batch_stats': batch_stats 85 | }, 86 | input_, 87 | mutable=['batch_stats']) 88 | return out_.mean() 89 | 90 | grad_fn = jax.grad(loss_fn) 91 | grads = grad_fn(params, batch_stats) 92 | grads = jax.lax.pmean(grads, axis_name="batch") 93 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads) 94 | return params 95 | 96 | devices = jax.local_devices() 97 | params = jax.device_put_replicated(params, devices) 98 | batch_stats = jax.device_put_replicated(batch_stats, devices) 99 | 100 | def shard_batch(x): 101 | x = x.reshape((len(devices), -1) + x.shape[1:]) 102 | return jax.device_put_sharded(list(x), devices) 103 | 104 | input_ = jax.tree_map(shard_batch, input_) 105 | 106 | return train_step, [params, batch_stats, input_] 107 | 108 | 109 | def bench_pmap_dp(func, args): 110 | 111 | def train_step(): 112 | func(*args) 113 | 114 | timer = EDTimer(train_step, in_ms=False) 115 | 116 | elaps_time = timer.time() 117 | 118 | print(f"Time: {elaps_time}") 119 | 120 | 121 | def main(): 122 | # setup easydist 123 | easydist_setup(backend="jax", device="cuda") 124 | 125 | print(jax.devices()) 126 | 127 | func, args = get_gpt_case() 128 | 129 | bench_pmap_dp(func, args) 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /benchmark/jax/model/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .gpt import GPT, GPTSimple, GPTBlock, GPTConfig 2 | from .resnet import ResNet, ResNet18 3 | from .wresnet import resnet18, resnet34, resnet50, wresnet50, wresnet101 4 | from .gat import GATLayer 5 | 6 | __all__ = [ 7 | "GPT", "GPTSimple", "GPTBlock", "GPTConfig", "ResNet", "ResNet18", "resnet18", "resnet34", 8 | "resnet50", "wresnet50", "wresnet101", "GATLayer" 9 | ] 10 | -------------------------------------------------------------------------------- /benchmark/jax/model/gat.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flax import linen as nn 3 | 4 | 5 | class GATLayer(nn.Module): 6 | in_features: int 7 | out_features: int 8 | 9 | @nn.compact 10 | def __call__(self, h, adj): 11 | wh = nn.Dense(features=self.out_features)(h) 12 | wh1 = nn.Dense(features=1)(wh) 13 | wh2 = nn.Dense(features=1)(wh) 14 | e = nn.leaky_relu(wh1 + wh2.T) 15 | 16 | zero_vec = -10e10 * jax.numpy.ones_like(e) 17 | attention = jax.numpy.where(adj > 0, e, zero_vec) 18 | attention = nn.softmax(attention) 19 | 20 | h_new = jax.numpy.matmul(attention, wh) 21 | 22 | return nn.elu(h_new) 23 | -------------------------------------------------------------------------------- /benchmark/jax/model/resnet.py: -------------------------------------------------------------------------------- 1 | """Flax implementation of ResNet V1.""" 2 | 3 | # Copyright 2023 The Flax Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # See issue #620. 
18 | # pytype: disable=wrong-arg-count 19 | 20 | from functools import partial 21 | from typing import Any, Callable, Sequence, Tuple 22 | 23 | from jax import numpy as jnp 24 | from flax import linen as nn 25 | 26 | ModuleDef = Any 27 | 28 | 29 | class ResNetBlock(nn.Module): 30 | """ResNet block.""" 31 | filters: int 32 | conv: ModuleDef 33 | norm: ModuleDef 34 | act: Callable 35 | strides: Tuple[int, int] = (1, 1) 36 | 37 | @nn.compact 38 | def __call__( 39 | self, 40 | x, 41 | ): 42 | residual = x 43 | y = self.conv(self.filters, (3, 3), self.strides)(x) 44 | y = self.norm()(y) 45 | y = self.act(y) 46 | y = self.conv(self.filters, (3, 3))(y) 47 | y = self.norm(scale_init=nn.initializers.zeros_init())(y) 48 | 49 | if residual.shape != y.shape: 50 | residual = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')(residual) 51 | residual = self.norm(name='norm_proj')(residual) 52 | 53 | return self.act(residual + y) 54 | 55 | 56 | class BottleneckResNetBlock(nn.Module): 57 | """Bottleneck ResNet block.""" 58 | filters: int 59 | conv: ModuleDef 60 | norm: ModuleDef 61 | act: Callable 62 | strides: Tuple[int, int] = (1, 1) 63 | 64 | @nn.compact 65 | def __call__(self, x): 66 | residual = x 67 | y = self.conv(self.filters, (1, 1))(x) 68 | y = self.norm()(y) 69 | y = self.act(y) 70 | y = self.conv(self.filters, (3, 3), self.strides)(y) 71 | y = self.norm()(y) 72 | y = self.act(y) 73 | y = self.conv(self.filters * 4, (1, 1))(y) 74 | y = self.norm(scale_init=nn.initializers.zeros_init())(y) 75 | 76 | if residual.shape != y.shape: 77 | residual = self.conv(self.filters * 4, (1, 1), self.strides, 78 | name='conv_proj')(residual) 79 | residual = self.norm(name='norm_proj')(residual) 80 | 81 | return self.act(residual + y) 82 | 83 | 84 | class ResNet(nn.Module): 85 | """ResNetV1.""" 86 | stage_sizes: Sequence[int] 87 | block_cls: ModuleDef 88 | num_classes: int = 1000 89 | num_filters: int = 64 90 | dtype: Any = jnp.float32 91 | act: Callable = nn.relu 92 | conv: ModuleDef = nn.Conv 93 | 94 | @nn.compact 95 | def __call__(self, x, train: bool = True): 96 | conv = partial(self.conv, use_bias=False, dtype=self.dtype) 97 | norm = partial(nn.BatchNorm, 98 | use_running_average=not train, 99 | momentum=0.9, 100 | epsilon=1e-5, 101 | dtype=self.dtype) 102 | 103 | x = conv(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init')(x) 104 | x = norm(name='bn_init')(x) 105 | x = nn.relu(x) 106 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') 107 | for i, block_size in enumerate(self.stage_sizes): 108 | for j in range(block_size): 109 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 110 | x = self.block_cls(self.num_filters * 2**i, 111 | strides=strides, 112 | conv=conv, 113 | norm=norm, 114 | act=self.act)(x) 115 | x = jnp.mean(x, axis=(1, 2)) 116 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x) 117 | x = jnp.asarray(x, self.dtype) 118 | return x 119 | 120 | 121 | ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock) 122 | -------------------------------------------------------------------------------- /benchmark/torch/bench_torch_tp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 8 | from fairscale.nn.model_parallel import get_data_parallel_group 9 | 10 | from 
easydist.utils.timer import EDTimer 11 | from easydist import easydist_setup 12 | 13 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 14 | from benchmark.bench_case import GPTCase 15 | from benchmark.torch.model.gpt_tp import GPT 16 | 17 | 18 | def get_gpt_case(cuda=True): 19 | 20 | case = GPTCase() 21 | model = GPT(depth=case.num_layers, dim=case.hidden_dim, num_heads=case.num_heads) 22 | data_in = torch.ones(case.batch_size, case.seq_size, case.hidden_dim) 23 | 24 | if cuda: 25 | return model.cuda(), data_in.cuda() 26 | 27 | return model, data_in 28 | 29 | 30 | def bench_tp(model, data_in): 31 | 32 | ddp_model = DDP(model, process_group=get_data_parallel_group()) 33 | optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) 34 | 35 | def train_step(): 36 | optimizer.zero_grad() 37 | out = ddp_model(data_in) 38 | out_grad = torch.ones_like(out) 39 | out.backward(out_grad) 40 | optimizer.step() 41 | 42 | torch.cuda.reset_peak_memory_stats() 43 | 44 | timer = EDTimer(train_step, in_ms=False) 45 | 46 | elaps_time = timer.time() 47 | peak_memory = torch.cuda.max_memory_allocated() 48 | 49 | print(f"Memory: {peak_memory / 1024 / 1024 / 1024} GB") 50 | print(f"Time: {elaps_time}") 51 | 52 | 53 | def main(): 54 | easydist_setup(backend="torch", device="cuda") 55 | # setup distributed 56 | torch.distributed.init_process_group(backend="nccl") 57 | local_rank = int(os.environ["LOCAL_RANK"]) 58 | world_size = torch.distributed.get_world_size() 59 | initialize_model_parallel(world_size) 60 | torch.cuda.set_device(local_rank) 61 | 62 | model, data_in = get_gpt_case(cuda=True) 63 | 64 | bench_tp(model, data_in) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /benchmark/torch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .gat import GATLayer 2 | from .gpt import GPT, GPTLayer, FeedForward, SelfAttention 3 | from .wresnet import resnet18, resnet34, resnet50, wresnet50, wresnet101 4 | 5 | __all__ = [ 6 | 'GATLayer', 'GPT', 'GPTLayer', 'FeedForward', 'SelfAttention', 'wresnet50', 'wresnet101', 7 | 'resnet18', 'resnet34', 'resnet50', 'LLAMA', 'LLAMAConfig' 8 | ] 9 | -------------------------------------------------------------------------------- /benchmark/torch/model/gat.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/Diego999/pyGAT 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GATLayer(nn.Module): 9 | 10 | def __init__(self, in_features, out_features): 11 | super(GATLayer, self).__init__() 12 | self.in_features = in_features 13 | self.out_features = out_features 14 | 15 | self.linear_w = nn.Linear(in_features=in_features, out_features=out_features, bias=None) 16 | 17 | self.linear_a_1 = nn.Linear(in_features=out_features, out_features=1, bias=None) 18 | self.linear_a_2 = nn.Linear(in_features=out_features, out_features=1, bias=None) 19 | 20 | self.leakyrelu = nn.LeakyReLU() 21 | 22 | self.softmax = nn.Softmax(dim=-1) 23 | 24 | def forward(self, h, adj): 25 | wh = self.linear_w(h) 26 | wh1 = self.linear_a_1(wh) 27 | wh2 = self.linear_a_2(wh) 28 | 29 | e = self.leakyrelu(wh1 + wh2.T) 30 | zero_vec = -10e10 * torch.ones(*e.shape, device=e.device, dtype=e.dtype) 31 | attention = torch.where(adj > 0, e, zero_vec) 32 | attention = self.softmax(attention) 33 | 34 | h_new = torch.matmul(attention, wh) 35 | 36 | 
return F.elu(h_new) 37 | -------------------------------------------------------------------------------- /benchmark/torch/pp/gpt/speed/batch.sh: -------------------------------------------------------------------------------- 1 | # threshold microbatch size 2 | for ((i=1 ; i <= 16 ; i += 1)); do 3 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --micro-batch-size $i --num-chunks 1 --schedule gpipe 4 | done 5 | 6 | for ((i=1 ; i <= 16 ; i += 1)); do 7 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --micro-batch-size $i --num-chunks 1 --schedule dapple 8 | done 9 | 10 | for ((i=1 ; i <= 16 ; i += 1)); do 11 | python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --micro-batch-size $i --num-chunks 1 12 | done 13 | 14 | for ((i=1 ; i <= 16 ; i += 1)); do 15 | python benchmark/torch/pp/gpt/speed/vanilla_torch.py --micro-batch-size $i --num-chunks 1 16 | done 17 | 18 | 19 | # num chunks 20 | for ((i=1; i <= 32 ; i *= 2)); do 21 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i --schedule gpipe 22 | done 23 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks 34 --schedule gpipe 24 | 25 | for ((i=1; i <= 64 ; i *= 2)); do 26 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i --schedule dapple 27 | done 28 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks 98 --schedule dapple 29 | 30 | for ((i=1; i <= 32 ; i *= 2)); do 31 | python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i 32 | done 33 | python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks 34 34 | 35 | for ((i=1; i <= 256 ; i *= 2)); do 36 | python benchmark/torch/pp/gpt/speed/vanilla_torch.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i 37 | done 38 | 39 | 40 | -------------------------------------------------------------------------------- /benchmark/torch/pp/resnet101/accuracy/vanilla_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | # python benchmark/torch/pp/resnet101/accuracy/vanilla_torch.py 16 | import os 17 | import random 18 | import time 19 | 20 | import numpy as np 21 | 22 | import torch 23 | 24 | from torchvision import datasets, transforms 25 | from torchvision.models import resnet18 26 | from torch.profiler import profile, record_function, ProfilerActivity 27 | 28 | from tqdm import tqdm 29 | 30 | 31 | def seed(seed=42): 32 | # Set seed for PyTorch 33 | torch.manual_seed(seed) 34 | # torch.use_deterministic_algorithms(True) 35 | if torch.cuda.is_available(): 36 | torch.cuda.manual_seed(seed) 37 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 38 | # Set seed for numpy 39 | np.random.seed(seed) 40 | # Set seed for built-in Python 41 | random.seed(seed) 42 | # Set(seed) for each of the random number generators in python: 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | 47 | criterion = torch.nn.CrossEntropyLoss() 48 | 49 | 50 | def train_step(input, label, model, opt): 51 | opt.zero_grad() 52 | out = model(input) 53 | loss = criterion(out, label) 54 | loss.backward() 55 | opt.step() 56 | return out, loss 57 | 58 | 59 | def test_main(): 60 | seed(1) 61 | 62 | device = torch.device('cuda') 63 | 64 | module = resnet18().train().to(device) 65 | module.fc = torch.nn.Linear(module.fc.in_features, 10).to(device) 66 | batch_size = 1024 67 | 68 | opt = torch.optim.Adam(module.parameters(), foreach=True, capturable=True) 69 | 70 | transform = transforms.Compose([ 71 | transforms.ToTensor(), 72 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 73 | ]) 74 | train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform) 75 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size) 76 | x_batch, y_batch = next(iter(train_dataloader)) 77 | train_step(x_batch.to(device), y_batch.to(device), module, opt) 78 | epochs = 5 79 | for epoch in range(epochs): 80 | all_cnt, correct_cnt, loss_sum = 0, 0, 0 81 | time_start = time.time() 82 | for x_batch, y_batch in tqdm(train_dataloader, dynamic_ncols=True): 83 | x_batch = x_batch.to(device) 84 | y_batch = y_batch.to(device) 85 | if x_batch.size(0) != batch_size: # TODO need to solve this 86 | continue 87 | out, loss = train_step(x_batch, y_batch, module, opt) 88 | all_cnt += len(out) 89 | preds = out.argmax(-1) 90 | correct_cnt += (preds == y_batch).sum() 91 | loss_sum += loss.mean().item() 92 | print( 93 | f'epoch {epoch} train accuracy: {correct_cnt / all_cnt}, loss sum {loss_sum}, avg loss: {loss_sum / all_cnt} ' 94 | f'time: {time.time() - time_start}') 95 | 96 | 97 | if __name__ == '__main__': 98 | test_main() 99 | -------------------------------------------------------------------------------- /benchmark/torch/pp/resnet101/speed/batch.sh: -------------------------------------------------------------------------------- 1 | for ((i=1; i <= 256 ; i *= 2)); do 2 | torchrun --nproc_per_node 4 benchmark/torch/pp/resnet101/speed/easydist_pipeline.py --micro-batch-size 128 --num-chunks $i --schedule gpipe 3 | done 4 | 5 | for ((i=1; i <= 256 ; i *= 2)); do 6 | torchrun --nproc_per_node 4 benchmark/torch/pp/resnet101/speed/easydist_pipeline.py --micro-batch-size 128 --num-chunks $i --schedule dapple 7 | done 8 | 9 | 10 | -------------------------------------------------------------------------------- 
/benchmark/torch/pp/resnet101/speed/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | """A ResNet implementation but using :class:`nn.Sequential`. :func:`resnet101` 2 | returns a :class:`nn.Sequential` instead of ``ResNet``. 3 | 4 | This code is transformed :mod:`torchvision.models.resnet`. 5 | 6 | """ 7 | from collections import OrderedDict 8 | from typing import Any, List 9 | 10 | from torch import nn 11 | 12 | from resnet.bottleneck import bottleneck 13 | from resnet.flatten_sequential import flatten_sequential 14 | 15 | __all__ = ['resnet101'] 16 | 17 | 18 | def build_resnet(layers: List[int], 19 | num_classes: int = 1000, 20 | inplace: bool = False 21 | ) -> nn.Sequential: 22 | """Builds a ResNet as a simple sequential model. 23 | 24 | Note: 25 | The implementation is copied from :mod:`torchvision.models.resnet`. 26 | 27 | """ 28 | inplanes = 64 29 | 30 | def make_layer(planes: int, 31 | blocks: int, 32 | stride: int = 1, 33 | inplace: bool = False, 34 | ) -> nn.Sequential: 35 | nonlocal inplanes 36 | 37 | downsample = None 38 | if stride != 1 or inplanes != planes * 4: 39 | downsample = nn.Sequential( 40 | nn.Conv2d(inplanes, planes * 4, 41 | kernel_size=1, stride=stride, bias=False), 42 | nn.BatchNorm2d(planes * 4), 43 | ) 44 | 45 | layers = [] 46 | layers.append(bottleneck(inplanes, planes, stride, downsample, inplace)) 47 | inplanes = planes * 4 48 | for _ in range(1, blocks): 49 | layers.append(bottleneck(inplanes, planes, inplace=inplace)) 50 | 51 | return nn.Sequential(*layers) 52 | 53 | # Build ResNet as a sequential model. 54 | model = nn.Sequential(OrderedDict([ 55 | ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)), 56 | ('bn1', nn.BatchNorm2d(64)), 57 | ('relu', nn.ReLU()), 58 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 59 | 60 | ('layer1', make_layer(64, layers[0], inplace=inplace)), 61 | ('layer2', make_layer(128, layers[1], stride=2, inplace=inplace)), 62 | ('layer3', make_layer(256, layers[2], stride=2, inplace=inplace)), 63 | ('layer4', make_layer(512, layers[3], stride=2, inplace=inplace)), 64 | 65 | ('avgpool', nn.AdaptiveAvgPool2d((1, 1))), 66 | ('flat', nn.Flatten()), 67 | ('fc', nn.Linear(512 * 4, num_classes)), 68 | ])) 69 | 70 | # Flatten nested sequentials. 71 | model = flatten_sequential(model) 72 | 73 | # Initialize weights for Conv2d and BatchNorm2d layers. 74 | # Stolen from torchvision-0.4.0. 
75 | def init_weight(m: nn.Module) -> None: 76 | if isinstance(m, nn.Conv2d): 77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 78 | return 79 | 80 | if isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | return 84 | 85 | model.apply(init_weight) 86 | 87 | return model 88 | 89 | 90 | def resnet101(**kwargs: Any) -> nn.Sequential: 91 | """Constructs a ResNet-101 model.""" 92 | return build_resnet([3, 4, 23, 3], **kwargs) 93 | -------------------------------------------------------------------------------- /benchmark/torch/pp/resnet101/speed/resnet/bottleneck.py: -------------------------------------------------------------------------------- 1 | """A ResNet bottleneck implementation but using :class:`nn.Sequential`.""" 2 | from collections import OrderedDict 3 | from typing import TYPE_CHECKING, Optional, Tuple, Union 4 | 5 | from torch import Tensor, nn 6 | 7 | from torchgpipe.skip import Namespace, pop, skippable, stash 8 | 9 | __all__ = ['bottleneck'] 10 | 11 | Tensors = Tuple[Tensor, ...] 12 | TensorOrTensors = Union[Tensor, Tensors] 13 | 14 | if TYPE_CHECKING: 15 | NamedModules = OrderedDict[str, nn.Module] 16 | else: 17 | NamedModules = OrderedDict 18 | 19 | 20 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | @skippable(stash=['identity']) 32 | class Identity(nn.Module): 33 | def forward(self, tensor: Tensor) -> Tensor: # type: ignore 34 | yield stash('identity', tensor) 35 | return tensor 36 | 37 | 38 | @skippable(pop=['identity']) 39 | class Residual(nn.Module): 40 | """A residual block for ResNet.""" 41 | 42 | def __init__(self, downsample: Optional[nn.Module] = None): 43 | super().__init__() 44 | self.downsample = downsample 45 | 46 | def forward(self, input: Tensor) -> Tensor: # type: ignore 47 | identity = yield pop('identity') 48 | if self.downsample is not None: 49 | identity = self.downsample(identity) 50 | return input + identity 51 | 52 | 53 | def bottleneck(inplanes: int, 54 | planes: int, 55 | stride: int = 1, 56 | downsample: Optional[nn.Module] = None, 57 | inplace: bool = False, 58 | ) -> nn.Sequential: 59 | """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`.""" 60 | 61 | layers: NamedModules = OrderedDict() 62 | 63 | ns = Namespace() 64 | layers['identity'] = Identity().isolate(ns) # type: ignore 65 | 66 | layers['conv1'] = conv1x1(inplanes, planes) 67 | layers['bn1'] = nn.BatchNorm2d(planes) 68 | layers['relu1'] = nn.ReLU(inplace=inplace) 69 | 70 | layers['conv2'] = conv3x3(planes, planes, stride) 71 | layers['bn2'] = nn.BatchNorm2d(planes) 72 | layers['relu2'] = nn.ReLU(inplace=inplace) 73 | 74 | layers['conv3'] = conv1x1(planes, planes * 4) 75 | layers['bn3'] = nn.BatchNorm2d(planes * 4) 76 | layers['residual'] = Residual(downsample).isolate(ns) # type: ignore 77 | layers['relu3'] = nn.ReLU(inplace=inplace) 78 | 79 | return nn.Sequential(layers) 80 | -------------------------------------------------------------------------------- /benchmark/torch/pp/resnet101/speed/resnet/flatten_sequential.py: -------------------------------------------------------------------------------- 1 | from 
collections import OrderedDict 2 | from typing import Iterator, Tuple 3 | 4 | from torch import nn 5 | 6 | 7 | def flatten_sequential(module: nn.Sequential) -> nn.Sequential: 8 | """flatten_sequentials a nested sequential module.""" 9 | if not isinstance(module, nn.Sequential): 10 | raise TypeError('not sequential') 11 | 12 | return nn.Sequential(OrderedDict(_flatten_sequential(module))) 13 | 14 | 15 | def _flatten_sequential(module: nn.Sequential) -> Iterator[Tuple[str, nn.Module]]: 16 | for name, child in module.named_children(): 17 | # flatten_sequential child sequential layers only. 18 | if isinstance(child, nn.Sequential): 19 | for sub_name, sub_child in _flatten_sequential(child): 20 | yield (f'{name}_{sub_name}', sub_child) 21 | else: 22 | yield (name, child) 23 | -------------------------------------------------------------------------------- /benchmark/torch/pp/resnet101/speed/vanila_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import os 16 | import random 17 | import time 18 | 19 | import numpy as np 20 | 21 | import torch 22 | 23 | from torchvision import datasets, transforms 24 | from torchvision.models import resnet101 25 | from torch.profiler import profile, record_function, ProfilerActivity 26 | 27 | from tqdm import tqdm 28 | 29 | 30 | def seed(seed=42): 31 | # Set seed for PyTorch 32 | torch.manual_seed(seed) 33 | # torch.use_deterministic_algorithms(True) 34 | if torch.cuda.is_available(): 35 | torch.cuda.manual_seed(seed) 36 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
37 | # Set seed for numpy 38 | np.random.seed(seed) 39 | # Set seed for built-in Python 40 | random.seed(seed) 41 | # Set(seed) for each of the random number generators in python: 42 | torch.backends.cudnn.deterministic = True 43 | torch.backends.cudnn.benchmark = False 44 | 45 | 46 | criterion = torch.nn.CrossEntropyLoss() 47 | 48 | 49 | def train_step(input, label, model, opt): 50 | out = model(input) 51 | loss = criterion(out, label) 52 | loss.backward() 53 | opt.step() 54 | opt.zero_grad() 55 | return out, loss 56 | 57 | 58 | def test_main(): 59 | seed(42) 60 | 61 | device = torch.device('cuda') 62 | 63 | module = resnet101().train().to(device) 64 | module.fc = torch.nn.Linear(2048, 10).to(device) 65 | 66 | opt = torch.optim.Adam(module.parameters(), foreach=True, capturable=True) 67 | 68 | dataset_size = 10000 69 | batch_size = 128 70 | train_dataloader = [(torch.randn(batch_size, 3, 224, 224), torch.randint( 71 | 0, 10, (batch_size, )))] * (dataset_size // batch_size) 72 | 73 | x_batch, y_batch = next(iter(train_dataloader)) 74 | train_step(x_batch.to(device), y_batch.to(device), module, opt) 75 | epochs = 1 76 | for epoch in range(epochs): 77 | all_cnt, correct_cnt, loss_sum = 0, 0, 0 78 | time_start = time.time() 79 | for x_batch, y_batch in tqdm(train_dataloader, dynamic_ncols=True): 80 | x_batch = x_batch.to(device) 81 | y_batch = y_batch.to(device) 82 | if x_batch.size(0) != batch_size: # TODO need to solve this 83 | continue 84 | out, loss = train_step(x_batch, y_batch, module, opt) 85 | all_cnt += len(out) 86 | preds = out.argmax(-1) 87 | correct_cnt += (preds == y_batch).sum() 88 | loss_sum += loss.mean().item() 89 | print( 90 | f'epoch {epoch} train accuracy: {correct_cnt / all_cnt}, loss sum {loss_sum}, avg loss: {loss_sum / all_cnt} ' 91 | f'time: {time.time() - time_start}' 92 | f'max memory: {torch.cuda.max_memory_allocated() / 1024 / 1024}mb') 93 | 94 | 95 | if __name__ == '__main__': 96 | test_main() 97 | -------------------------------------------------------------------------------- /easydist/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import logging 16 | 17 | from . 
import platform 18 | import easydist.config as mdconfig 19 | 20 | 21 | def easydist_setup(backend, device="cpu", allow_tf32=True): 22 | mdconfig.easydist_device = device 23 | 24 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 25 | datefmt='%m/%d %H:%M:%S', 26 | level=mdconfig.log_level) 27 | 28 | if backend == "jax": 29 | from .jax import easydist_setup_jax 30 | easydist_setup_jax(device, allow_tf32) 31 | 32 | logging.getLogger("jax._src").setLevel(logging.INFO) 33 | elif backend == "torch": 34 | from .torch import easydist_setup_torch 35 | easydist_setup_torch(device, allow_tf32) 36 | 37 | logging.getLogger("torch._subclasses.fake_tensor").setLevel(logging.INFO) 38 | 39 | platform.init_backend(backend) 40 | -------------------------------------------------------------------------------- /easydist/autoflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | from .solver import AutoFlowSolver1D 16 | 17 | __all__ = ["AutoFlowSolver1D"] 18 | -------------------------------------------------------------------------------- /easydist/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import os 16 | import logging 17 | 18 | import jax 19 | from mpi4py import MPI 20 | 21 | from .api import easydist_shard, get_opt_strategy, set_device_mesh 22 | from .sharding_interpreter import EDJaxShardingAnn 23 | from .bridge import jax2md_bridge 24 | 25 | __all__ = [ 26 | "easydist_shard", "get_opt_strategy", "set_device_mesh", "EDJaxShardingAnn", "jax2md_bridge" 27 | ] 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def is_jax_distributed_initialized(): 33 | return jax._src.distributed.global_state.client is not None 34 | 35 | 36 | def easydist_setup_jax(device, allow_tf32): 37 | os.environ["NVIDIA_TF32_OVERRIDE"] = "1" if allow_tf32 else "0" 38 | jax.config.update('jax_platforms', device) 39 | 40 | # setup distributed 41 | comm = MPI.COMM_WORLD 42 | size, rank = comm.Get_size(), comm.Get_rank() 43 | 44 | if not is_jax_distributed_initialized(): 45 | 46 | jax.distributed.initialize(coordinator_address="localhost:19705", 47 | num_processes=size, 48 | process_id=rank, 49 | local_device_ids=rank) 50 | 51 | logging.info( 52 | f"[Rank {rank}], Global Devices: {jax.device_count()}, Local Devices: {jax.local_device_count()}" 53 | ) 54 | -------------------------------------------------------------------------------- /easydist/jax/device_mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import functools 16 | import operator 17 | 18 | from easydist.metashard import metair 19 | 20 | JAX_DEVICE_MESH = None 21 | 22 | def set_device_mesh(device_mesh): 23 | global JAX_DEVICE_MESH 24 | JAX_DEVICE_MESH = device_mesh 25 | 26 | mesh_shape = device_mesh.device_ids.shape 27 | 28 | if len(mesh_shape) > 1: 29 | raise ValueError("Only support 1D mesh now") 30 | 31 | 32 | def get_device_mesh(): 33 | global JAX_DEVICE_MESH 34 | return JAX_DEVICE_MESH 35 | 36 | 37 | def device_mesh_world_size(device_mesh=None): 38 | if device_mesh is None: 39 | device_mesh = get_device_mesh() 40 | 41 | if device_mesh is None: 42 | return None 43 | 44 | device_mesh_shape = device_mesh.device_ids.shape 45 | 46 | return functools.reduce(operator.mul, device_mesh_shape) 47 | -------------------------------------------------------------------------------- /easydist/jax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import os 16 | from contextlib import contextmanager 17 | 18 | 19 | @contextmanager 20 | def _sharding_ann_env(): 21 | 22 | ori_tf32_override = os.environ.get("NVIDIA_TF32_OVERRIDE", None) 23 | 24 | os.environ["NVIDIA_TF32_OVERRIDE"] = "1" 25 | 26 | try: 27 | yield 28 | finally: 29 | if ori_tf32_override is None: 30 | os.environ.pop("NVIDIA_TF32_OVERRIDE") 31 | else: 32 | os.environ["NVIDIA_TF32_OVERRIDE"] = ori_tf32_override 33 | -------------------------------------------------------------------------------- /easydist/metashard/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | from .metaop import MetaOp 16 | from .view_propagation import view_propagation 17 | from .annotation import ShardDim, ShardAnnotation 18 | 19 | __all__ = ["MetaOp", "view_propagation", "ShardDim", "ShardAnnotation"] 20 | -------------------------------------------------------------------------------- /easydist/metashard/halo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License.
13 | # ============================================================================== 14 | 15 | from typing import List 16 | 17 | from easydist import platform 18 | 19 | 20 | class HaloInfo: 21 | 22 | def __init__(self, halowidth: int, dim: int) -> None: 23 | self.halowidth = halowidth 24 | self.dim = dim 25 | 26 | def __str__(self) -> str: 27 | return self.halowidth.__str__() 28 | 29 | def __repr__(self) -> str: 30 | return self.__str__() 31 | 32 | 33 | def halo_padding(tensor_list_: List[platform.Tensor], haloinfo: HaloInfo) -> List[platform.Tensor]: 34 | """add halo padding to tensor_list_""" 35 | 36 | if haloinfo is None or len(tensor_list_) < 2: 37 | return tensor_list_ 38 | 39 | halo = haloinfo.halowidth 40 | dim = haloinfo.dim 41 | 42 | padded_tensor_list = [] 43 | for idx in range(len(tensor_list_)): 44 | to_concatenate = [tensor_list_[idx]] 45 | if idx >= 1: 46 | dim_size = tensor_list_[idx - 1].shape[dim] 47 | if dim_size < halo: 48 | raise RuntimeError("Cannot halo padding for this sharded_tensor") 49 | to_concatenate.insert( 50 | 0, platform.narrow(tensor_list_[idx - 1], dim, dim_size - halo, halo)) 51 | if idx <= len(tensor_list_) - 2: 52 | to_concatenate.append(platform.narrow(tensor_list_[idx + 1], dim, 0, halo)) 53 | padded_tensor_list.append(platform.concatenate(to_concatenate, dim=dim)) 54 | 55 | return padded_tensor_list 56 | -------------------------------------------------------------------------------- /easydist/platform/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import os 16 | import logging 17 | import importlib 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | EASYDIST_BACKEND = None 22 | 23 | __all__ = [ 24 | "add", "equal", "zeros_like", "min", "max", "allclose", "concatenate", "chunk", "narrow", 25 | "Tensor", "tree_flatten", "tree_unflatten", "clone", "from_numpy" 26 | ] 27 | 28 | 29 | def backend_valid(_backend): 30 | return _backend in {"torch", "jax", "tvm"} 31 | 32 | 33 | def init_backend(backend="torch"): 34 | assert backend_valid(backend) 35 | global EASYDIST_BACKEND 36 | EASYDIST_BACKEND = backend 37 | modules = importlib.import_module("." + backend, __name__) 38 | for val in __all__: 39 | exec("globals()['%s'] = modules.%s" % (val, val)) 40 | logger.info(f"========= EasyDist init with backend {backend}. 
=========") 41 | 42 | 43 | def get_backend(): 44 | global EASYDIST_BACKEND 45 | return EASYDIST_BACKEND 46 | 47 | 48 | for val in __all__: 49 | exec("globals()['%s'] = None" % val) 50 | -------------------------------------------------------------------------------- /easydist/platform/jax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | from functools import partial 16 | 17 | import jax 18 | 19 | add = jax.numpy.add 20 | equal = jax.numpy.array_equal 21 | zeros_like = jax.numpy.zeros_like 22 | min = jax.numpy.minimum 23 | max = jax.numpy.maximum 24 | allclose = partial(jax.numpy.allclose, rtol=5e-3, atol=5e-03) 25 | 26 | 27 | def concatenate(tensors, dim=0): 28 | return jax.numpy.concatenate(tensors, axis=dim) 29 | 30 | 31 | def chunk(input, chunks, dim=0): 32 | return jax.numpy.array_split(input, chunks, axis=dim) 33 | 34 | 35 | def narrow(input, dim, start, length): 36 | indices = jax.numpy.asarray(range(start, start + length)) 37 | return jax.numpy.take(input, indices, axis=dim) 38 | 39 | 40 | Tensor = jax.Array 41 | 42 | tree_flatten = jax.tree_util.tree_flatten 43 | 44 | 45 | def tree_unflatten(values, spec): 46 | return jax.tree_util.tree_unflatten(spec, values) 47 | 48 | 49 | clone = jax.numpy.copy 50 | 51 | from_numpy = jax.numpy.array 52 | -------------------------------------------------------------------------------- /easydist/platform/torch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | from functools import partial 16 | 17 | import torch 18 | import torch.utils._pytree as pytree 19 | 20 | add = torch.Tensor.add 21 | equal = torch.equal 22 | zeros_like = torch.zeros_like 23 | min = torch.min 24 | max = torch.max 25 | allclose = partial(torch.allclose, rtol=1e-3, atol=1e-07) 26 | concatenate = torch.concatenate 27 | chunk = torch.chunk 28 | narrow = torch.narrow 29 | 30 | Tensor = torch.Tensor 31 | 32 | tree_flatten = pytree.tree_flatten 33 | tree_unflatten = pytree.tree_unflatten 34 | 35 | clone = torch.clone 36 | from_numpy = torch.from_numpy 37 | -------------------------------------------------------------------------------- /easydist/platform/tvm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import tvm 16 | import numpy 17 | import torch.utils._pytree as pytree 18 | 19 | 20 | def add(a, b): 21 | return tvm.nd.array(numpy.add(a.numpy(), b.numpy())) 22 | 23 | 24 | def equal(a, b): 25 | return numpy.array_equal(a.numpy(), b.numpy()) 26 | 27 | 28 | def zeros_like(input_): 29 | return tvm.nd.array(numpy.zeros_like(input_.numpy())) 30 | 31 | 32 | def min(a, b): 33 | return tvm.nd.array(numpy.minimum(a.numpy(), b.numpy())) 34 | 35 | 36 | def max(a, b): 37 | return tvm.nd.array(numpy.maximum(a.numpy(), b.numpy())) 38 | 39 | 40 | def allclose(a, b): 41 | return numpy.allclose(a.numpy(), b.numpy(), rtol=5e-3, atol=5e-03) 42 | 43 | 44 | def concatenate(tensors, dim=0): 45 | return tvm.nd.array(numpy.concatenate([t.numpy() for t in tensors], axis=dim)) 46 | 47 | 48 | def chunk(input, chunks, dim=0): 49 | return [tvm.nd.array(i) for i in numpy.array_split(input.numpy(), chunks, axis=dim)] 50 | 51 | 52 | def narrow(input, dim, start, length): 53 | indices = numpy.asarray(range(start, start + length)) 54 | return tvm.nd.array(numpy.take(input.numpy(), indices, axis=dim)) 55 | 56 | 57 | Tensor = tvm.nd.NDArray 58 | 59 | tree_flatten = pytree.tree_flatten 60 | tree_unflatten = pytree.tree_unflatten 61 | 62 | 63 | def clone(input_): 64 | return tvm.nd.array(numpy.copy(input_.numpy())) 65 | 66 | 67 | from_numpy = tvm.nd.array 68 | -------------------------------------------------------------------------------- /easydist/torch/compile.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from easydist.torch.decomp_utils import EASYDIST_DECOMP_TABLE 3 | from easydist.torch.experimental.pp.split_utils import SplitPatcher 4 | import torch.utils._pytree as pytree 5 | from torch._subclasses.fake_tensor import FakeTensor 6 | from torch.fx.experimental.proxy_tensor import make_fx 7 | from functools import partial 8 | from easydist.torch.experimental.pp.split_utils import clear_pp_compile_states, 
get_updated_params_states 9 | from easydist.torch.experimental.pp.utils import save_graphviz_dot 10 | from easydist.torch.init_helper import SetParaInitHelper 11 | from easydist.torch.utils import _enable_compile, _rematerialize_optimizer 12 | from easydist.torch.scope_auto.build_scope_modules import build_scope_modules 13 | 14 | 15 | import torch 16 | from torch.nn.utils import stateless 17 | 18 | 19 | from contextlib import nullcontext 20 | from typing import cast 21 | 22 | from easydist.utils import rgetattr, rsetattr 23 | 24 | 25 | def stateless_func(func, module, opt, params, buffers, named_states, args, kwargs): 26 | clear_pp_compile_states() 27 | with stateless._reparametrize_module( 28 | cast(torch.nn.Module, module), { 29 | **params, 30 | **buffers 31 | }, tie_weights=True) if module else nullcontext(), _rematerialize_optimizer( 32 | opt, named_states, params) if opt else nullcontext(): 33 | ret = func(*args, **kwargs) 34 | if (tup := get_updated_params_states()) != (None, None): 35 | params, named_states = tup 36 | grads = {k: v.grad for k, v in params.items()} 37 | return params, buffers, named_states, grads, ret 38 | 39 | 40 | def ed_compile_func(func, tracing_mode, init_helper, args, kwargs, schedule_cls, module, opt): 41 | params, buffers = {}, {} 42 | if module is not None: 43 | params = dict(module.named_parameters()) 44 | buffers = dict(module.named_buffers()) 45 | 46 | if isinstance(init_helper, SetParaInitHelper): 47 | init_helper.module = module 48 | 49 | named_states = {} 50 | 51 | if opt is not None: 52 | # assign grad and warm up optimizer 53 | for name in dict(module.named_parameters()): 54 | with torch.no_grad(): 55 | rsetattr(module, name + ".grad", torch.zeros_like(rgetattr(module, name).data)) 56 | if isinstance(rgetattr(module, name).data, FakeTensor): 57 | mode = rgetattr(module, name).data.fake_mode 58 | 59 | opt.step() 60 | opt.zero_grad(True) 61 | 62 | for n, p in params.items(): 63 | if p in opt.state: 64 | named_states[n] = opt.state[p] # type: ignore[index] 65 | # if step in state, reduce one for warmup step. 
66 | if 'step' in named_states[n]: 67 | named_states[n]['step'] -= 1 68 | 69 | flat_named_states, _ = pytree.tree_flatten(named_states) 70 | 71 | # fix for sgd withtout momentum 72 | if all(state is None for state in flat_named_states): 73 | named_states = {} 74 | flat_named_states, _ = pytree.tree_flatten(named_states) 75 | 76 | state_tensor_num = len(params) + len(buffers) + len(flat_named_states) 77 | 78 | with _enable_compile(), SplitPatcher(module, opt) if schedule_cls else nullcontext(): 79 | traced_graph = make_fx(partial(stateless_func, func, module, opt), 80 | tracing_mode=tracing_mode, 81 | decomposition_table=EASYDIST_DECOMP_TABLE, 82 | _allow_non_fake_inputs=False)(params, buffers, named_states, args, 83 | kwargs) 84 | 85 | if len(list(traced_graph.named_buffers())) != 0: 86 | warnings.warn(f"No buffer should be found in the traced graph, please check if the model is correctly traced, found {dict(traced_graph.named_buffers())}") 87 | 88 | traced_graph = build_scope_modules(traced_graph) 89 | traced_graph.graph.eliminate_dead_code() 90 | traced_graph.recompile() 91 | 92 | save_graphviz_dot(traced_graph, 'traced_graph') 93 | 94 | return params,buffers,named_states,state_tensor_num,traced_graph 95 | -------------------------------------------------------------------------------- /easydist/torch/cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/easydist/4ced4cd7ca7a39722670beae16332426d5430238/easydist/torch/cuda/__init__.py -------------------------------------------------------------------------------- /easydist/torch/cuda/mem_allocator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import ctypes 16 | 17 | import easydist 18 | from easydist.torch.meta_allocator import profiling_allocator 19 | 20 | 21 | class _CUDAAllocator: 22 | def __init__(self, allocator): 23 | self._allocator = allocator 24 | 25 | def allocator(self): 26 | return self._allocator 27 | 28 | 29 | class EffectiveCUDAAllocator(_CUDAAllocator): 30 | def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str): 31 | allocator = ctypes.CDLL(path_to_so_file) 32 | alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value 33 | free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value 34 | assert alloc_fn is not None 35 | assert free_fn is not None 36 | self._allocator = profiling_allocator._cuda_customEffectiveAllocator(alloc_fn, free_fn) 37 | 38 | 39 | def change_current_allocator(allocator: _CUDAAllocator) -> None: 40 | profiling_allocator._cuda_changeCurrentAllocator(allocator.allocator()) 41 | 42 | def init_meta_allocator(): 43 | if not easydist.config.enable_memory_opt: 44 | return 45 | swap_to_profiling_allocator() 46 | 47 | def swap_to_profiling_allocator(): 48 | # swap from caching allocator to profiling allocator 49 | 50 | profiling_allocator._compile_if_needed() 51 | 52 | path_to_profiling_allocator = profiling_allocator.module.__file__ 53 | raw_allocator = ctypes.CDLL(path_to_profiling_allocator) 54 | init_fn = ctypes.cast(getattr(raw_allocator, 'init_fn'), ctypes.c_void_p).value 55 | new_alloc = EffectiveCUDAAllocator( 56 | path_to_profiling_allocator, 'meta_malloc', 'meta_free') 57 | profiling_allocator._save_back_allocator() 58 | change_current_allocator(new_alloc) 59 | new_alloc.allocator().set_init_fn(init_fn) 60 | 61 | -------------------------------------------------------------------------------- /easydist/torch/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/easydist/4ced4cd7ca7a39722670beae16332426d5430238/easydist/torch/experimental/__init__.py -------------------------------------------------------------------------------- /easydist/torch/experimental/pp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | # TODO @botbw: dependency and circular import issues -------------------------------------------------------------------------------- /easydist/torch/experimental/pp/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import logging 16 | import operator 17 | from copy import deepcopy 18 | from typing import Any, Dict 19 | 20 | import torch.fx as fx 21 | from torch.fx.passes.graph_drawer import FxGraphDrawer 22 | 23 | import easydist.config as mdconfig 24 | 25 | 26 | def ordered_gi_users(node: fx.Node): 27 | assert all(user.op == 'call_function' and user.target == operator.getitem 28 | for user in node.users), "All users of the node must be getitem" 29 | ret = [None for _ in range(len(node.users))] 30 | for user in node.users: 31 | ret[user.args[1]] = user 32 | return ret 33 | 34 | 35 | def save_graphviz_dot(gm, name): 36 | with open(f"./log/{name}.dot", "w") as f: 37 | f.write(str(FxGraphDrawer(gm, name).get_dot_graph())) 38 | 39 | 40 | def _to_tuple(x): 41 | if isinstance(x, tuple): 42 | return x 43 | return (x, ) 44 | 45 | 46 | class OneToOneMap: 47 | def __init__(self): 48 | self._map: Dict[Any, Any] = {} 49 | self._inv: Dict[Any, Any] = {} 50 | 51 | def get(self, key: Any) -> Any: 52 | return self._map[key] 53 | 54 | def inv_get(self, key: Any) -> Any: 55 | return self._inv[key] 56 | 57 | def add(self, key: Any, value: Any) -> Any: 58 | if key in self._map or value in self._inv: 59 | raise RuntimeError(f"{key}: {value} is not one to one mapping, found {key in self._map} {value in self._inv}") 60 | self._map[key] = value 61 | self._inv[value] = key 62 | 63 | def items(self): 64 | return self._map.items() 65 | 66 | def keys(self): 67 | return self._map.keys() 68 | 69 | def inv_items(self): 70 | return self._inv.items() 71 | 72 | def inv_keys(self): 73 | return self._inv.keys() 74 | 75 | def apply(self, other: "OneToOneMap") -> "OneToOneMap": 76 | mapping = OneToOneMap() 77 | for k, v in self.items(): 78 | mapping.add(k, other.get(v)) 79 | return mapping 80 | 81 | def map_dict_key(self, dictt: Dict) -> Dict: 82 | ret = {} 83 | for k, v in dictt.items(): 84 | ret[self._map[k]] = v 85 | return ret 86 | 87 | def inverse(self) -> "OneToOneMap": 88 | inversed = deepcopy(self) 89 | inversed._map, inversed._inv = inversed._inv, inversed._map 90 | return inversed 91 | 92 | def __repr__(self): 93 | return f"{self._map=}\n{self._inv=}" 94 | 95 | @staticmethod 96 | def from_dict(dict: Dict) -> "OneToOneMap": 97 | mapping = OneToOneMap() 98 | for k, v in dict.items(): 99 | mapping.add(k, v) 100 | return mapping 101 | -------------------------------------------------------------------------------- /easydist/torch/graph_profile_db.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import os 16 | import pickle 17 | import logging 18 | 19 | import easydist.config as mdconfig 20 | 21 | _logger = logging.getLogger(__name__) 22 | 23 | 24 | class PerfDB: 25 | 26 | def __init__(self) -> None: 27 | self.db = dict() 28 | if os.path.exists(mdconfig.prof_db_path): 29 | self.db = pickle.load(open(mdconfig.prof_db_path, 'rb')) 30 | _logger.info(f"Load Perf DB from {mdconfig.prof_db_path}") 31 | 32 | def get_op_perf(self, key_l1, key_l2): 33 | if key_l1 in self.db: 34 | return self.db[key_l1].get(key_l2, None) 35 | return None 36 | 37 | def record_op_perf(self, key_l1, key_l2, value): 38 | if key_l1 not in self.db: 39 | self.db[key_l1] = dict() 40 | self.db[key_l1][key_l2] = value 41 | 42 | def persistent(self): 43 | _logger.info(f"Persistent Perf DB to {mdconfig.prof_db_path}") 44 | 45 | if not os.path.exists(mdconfig.easydist_dir): 46 | os.makedirs(mdconfig.easydist_dir, exist_ok=True) 47 | 48 | pickle.dump(self.db, open(mdconfig.prof_db_path, 'wb')) 49 | -------------------------------------------------------------------------------- /easydist/torch/meta_allocator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import os 16 | import logging 17 | 18 | import pynvml 19 | 20 | from torch.utils.cpp_extension import load, _join_cuda_home 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | def _compile(): 25 | 26 | print("torch load") 27 | 28 | device_index = os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',')[0] 29 | 30 | # (NOTE) workaround of torch.cuda.get_cuda_arch_list() which will init allocator 31 | pynvml.nvmlInit() 32 | handle = pynvml.nvmlDeviceGetHandleByIndex(int(device_index)) 33 | capability = pynvml.nvmlDeviceGetCudaComputeCapability(handle) 34 | 35 | os.environ['TORCH_CUDA_ARCH_LIST'] = f'{capability[0]}.{capability[1]}+PTX' 36 | 37 | sources_files = [ 38 | 'csrc/profiling_allocator.cpp', 39 | 'csrc/stream_tracer.cpp', 40 | 'csrc/cupti_callback_api.cpp', 41 | 'csrc/python_tracer_init.cpp', 42 | 'csrc/effective_cuda_allocator.cpp' 43 | ] 44 | 45 | profiling_allocator_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "profiler") 46 | _comiled_module = load(name="profiling_allocator", 47 | extra_include_paths=[_join_cuda_home('extras', 'CUPTI', 'include')], 48 | sources=[os.path.join(profiling_allocator_dir, f) for f in sources_files], 49 | extra_cflags=['-DUSE_CUDA=1', '-D_GLIBCXX_USE_CXX11_ABI=0'], 50 | with_cuda=True, verbose=True) 51 | 52 | logger.info(f"[profiling_allocator] compiled in {_comiled_module.__file__}") 53 | 54 | return _comiled_module 55 | 56 | 57 | class LazyModule(): 58 | def __init__(self): 59 | self.module = None 60 | 61 | def _compile_if_needed(self): 62 | if self.module is None: 63 | self.module = _compile() 64 | 65 | def __getattr__(self, name: str): 66 | self._compile_if_needed() 67 | return self.module.__getattribute__(name) 68 | 69 | profiling_allocator = LazyModule() 70 | -------------------------------------------------------------------------------- /easydist/torch/passes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | from .fix_embedding import fix_embedding 16 | from .fix_bias import fix_addmm_bias, fix_convoluation_bias 17 | from .fix_meta_device import fix_meta_device 18 | from .eliminate_detach import eliminate_detach 19 | from .sharding import sharding_transform, sharding_transform_dtensor 20 | from .tile_comm import tile_comm 21 | from .comm_optimize import comm_optimize 22 | from .rule_override import rule_override_by_graph 23 | from .runtime_prof import runtime_prof 24 | from .edinfo_utils import create_edinfo, annotation_edinfo 25 | from .process_tag import process_tag 26 | from .allocator_profiler import AllocatorProfiler, ModuleProfilingInfo 27 | from .fix_view import decouple_view 28 | from .pp_passes import get_partition 29 | __all__ = [ 30 | "fix_embedding", "fix_addmm_bias", "fix_convoluation_bias", "eliminate_detach", 31 | "sharding_transform", "sharding_transform_dtensor", "fix_meta_device", "tile_comm", 32 | "comm_optimize", "rule_override_by_graph", "runtime_prof", "create_edinfo", 33 | "AllocatorProfiler", "ModuleProfilingInfo", "annotation_edinfo", "process_tag", 34 | "decouple_view", "get_partition" 35 | ] 36 | -------------------------------------------------------------------------------- /easydist/torch/passes/edinfo_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import operator 16 | 17 | import torch 18 | from torch.fx.node import Node, _get_qualified_name 19 | import torch.utils._pytree as pytree 20 | 21 | from easydist.torch.utils import EDInfo, EDNodeType, create_meta_from_node 22 | from easydist.torch.passes.sharding import CREATE_ATEN_OP, COMM_FUNCS 23 | 24 | 25 | def create_edinfo(fx_module: torch.fx.GraphModule, sharding_info, shape_info) -> torch.fx.GraphModule: 26 | 27 | for node in fx_module.graph.nodes: 28 | 29 | if node.op == 'call_function' and 'val' not in node.meta: 30 | node.meta = create_meta_from_node(node) 31 | 32 | 33 | if not hasattr(node, "ed_info"): 34 | node.ed_info = EDInfo(ori_meta=node.meta) 35 | 36 | if node.op == "call_function": 37 | op_name = _get_qualified_name(node.target) 38 | 39 | node_sharding_info = None 40 | if op_name in sharding_info: 41 | 42 | def _gen_meta(arg: Node): 43 | return torch.empty(shape_info[arg.name]["shape"], 44 | dtype=shape_info[arg.name]["dtype"], 45 | device="meta") 46 | 47 | args_meta = pytree.tree_map_only(Node, _gen_meta, node.args) 48 | args_meta = str(tuple(args_meta)) + ' | ' + str(node.kwargs) 49 | if args_meta in sharding_info[op_name]: 50 | node_sharding_info = sharding_info[op_name][args_meta] 51 | 52 | node.ed_info.spmd_annotation = node_sharding_info 53 | 54 | elif node.op in ["placeholder", "get_attr"]: 55 | if hasattr(node, "meta") and 'val' in node.meta: 56 | node_sharding_info = None 57 | if node.op in sharding_info: 58 | arg_meta_tensor = torch.empty(shape_info[node.name]["shape"], 59 | dtype=shape_info[node.name]["dtype"], 60 | device="meta") 61 | args_meta = str(arg_meta_tensor) 62 | if args_meta in sharding_info[node.op]: 63 | node_sharding_info = sharding_info[node.op][args_meta] 64 | 65 | node.ed_info.spmd_annotation = node_sharding_info 66 | 67 | return fx_module 68 | 69 | 70 | def annotation_edinfo(traced_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: 71 | for node in traced_graph.graph.nodes: 72 | if not hasattr(node, "ed_info"): 73 | node.ed_info = EDInfo(ori_meta=node.meta) 74 | 75 | if node.op == 'placeholder': 76 | node.ed_info.node_type = EDNodeType.AUXILIARY 77 | elif node.op == 'call_function': 78 | # create meta for custom function 79 | if node.target not in CREATE_ATEN_OP: 80 | node.meta = create_meta_from_node(node) 81 | # annotate node type 82 | if node.target in COMM_FUNCS: 83 | node.ed_info.node_type = EDNodeType.COMMUNICATION 84 | # (TODO) hard code here to avoid to runtime profile torch.ops.aten._fused_adam.default 85 | elif node.target in [operator.getitem, torch.ops.aten._fused_adam.default]: 86 | node.ed_info.node_type = EDNodeType.AUXILIARY 87 | else: 88 | node.ed_info.node_type = EDNodeType.COMPUTATION 89 | elif node.op == 'output': 90 | node.ed_info.node_type = EDNodeType.AUXILIARY 91 | 92 | return traced_graph -------------------------------------------------------------------------------- /easydist/torch/passes/eliminate_detach.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import torch.fx as fx 16 | from torch.fx.node import _get_qualified_name 17 | 18 | 19 | def eliminate_detach(fx_graph: fx.GraphModule): 20 | for node in fx_graph.graph.nodes: 21 | if node.op == 'call_function': 22 | if _get_qualified_name(node.target) == 'torch.ops.aten.detach.default': 23 | node.replace_all_uses_with(node.args[0]) 24 | 25 | fx_graph.graph.eliminate_dead_code() 26 | 27 | return fx_graph 28 | -------------------------------------------------------------------------------- /easydist/torch/passes/fix_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import copy 16 | 17 | import torch 18 | from torch.fx.node import _get_qualified_name 19 | 20 | from easydist.torch.utils import create_meta_from_node 21 | 22 | def fix_addmm_bias(fx_module: torch.fx.GraphModule): 23 | 24 | for node in fx_module.graph.nodes: 25 | if node.op == 'call_function': 26 | if "torch.ops.aten.addmm.default" in _get_qualified_name(node.target): 27 | node.target = torch.ops.aten.mm.default 28 | bias = node.args[0] 29 | node.args = (node.args[1], node.args[2]) 30 | 31 | with fx_module.graph.inserting_after(node): 32 | add_bias_node = fx_module.graph.call_function(torch.ops.aten.add.Tensor, 33 | args=(node, bias)) 34 | 35 | node.replace_all_uses_with(add_bias_node) 36 | 37 | add_bias_node.update_arg(0, node) 38 | 39 | add_bias_node.meta = create_meta_from_node(add_bias_node) 40 | 41 | fx_module.recompile() 42 | 43 | return fx_module 44 | 45 | 46 | def fix_convoluation_bias(fx_module: torch.fx.GraphModule): 47 | 48 | for node in fx_module.graph.nodes: 49 | if node.op == 'call_function': 50 | if "torch.ops.aten.convolution.default" in _get_qualified_name(node.target): 51 | if node.args[2] is not None: 52 | node.target = torch.ops.aten.convolution.default 53 | bias = node.args[2] 54 | node.args = (node.args[0], node.args[1], None, *node.args[3:]) 55 | 56 | with fx_module.graph.inserting_after(node): 57 | bias_new = fx_module.graph.call_function(torch.ops.aten.view.default, 58 | args=(bias, [1, -1, 1, 1])) 59 | 60 | with fx_module.graph.inserting_after(bias_new): 61 | add_bias_node = fx_module.graph.call_function(torch.ops.aten.add.Tensor, 62 | args=(node, bias_new)) 63 | 64 | node.replace_all_uses_with(add_bias_node) 65 | 66 | add_bias_node.update_arg(0, node) 67 | 68 | bias_new.meta = create_meta_from_node(bias_new) 69 | add_bias_node.meta = create_meta_from_node(add_bias_node) 70 | 71 | fx_module.recompile() 72 | 73 | return fx_module 74 | -------------------------------------------------------------------------------- /easydist/torch/passes/fix_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import torch 16 | from torch.fx.node import _get_qualified_name 17 | 18 | 19 | def md_embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): 20 | if int(torch.max(indices).item()) >= weight.shape[0]: 21 | raise RuntimeError("embedding indice overflow") 22 | return torch.ops.aten.embedding.default(weight, indices, padding_idx, scale_grad_by_freq, 23 | sparse) 24 | 25 | 26 | def fix_embedding(fx_module: torch.fx.GraphModule, recover=False): 27 | 28 | for node in fx_module.graph.nodes: 29 | if node.op == 'call_function': 30 | if "torch.ops.aten.embedding.default" in _get_qualified_name(node.target): 31 | node.target = md_embedding 32 | 33 | if recover and "md_embedding" in _get_qualified_name(node.target): 34 | node.target = torch.ops.aten.embedding.default 35 | 36 | fx_module.recompile() 37 | 38 | return fx_module 39 | -------------------------------------------------------------------------------- /easydist/torch/passes/fix_meta_device.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import os 16 | import copy 17 | 18 | import torch 19 | 20 | import easydist.config as mdconfig 21 | 22 | 23 | def fix_meta_device(fx_module: torch.fx.GraphModule): 24 | 25 | for node in fx_module.graph.nodes: 26 | if node.op == 'call_function': 27 | if "device" in node.kwargs: 28 | new_kwargs = dict(copy.deepcopy(node.kwargs)) 29 | device = mdconfig.easydist_device 30 | new_kwargs["device"] = torch.device(device=device) 31 | assert isinstance(new_kwargs, dict) 32 | node.kwargs = new_kwargs 33 | 34 | fx_module.recompile() 35 | 36 | return fx_module 37 | -------------------------------------------------------------------------------- /easydist/torch/passes/fix_view.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import functools 16 | import operator 17 | import torch 18 | import torch.fx as fx 19 | 20 | from easydist.torch.preset_propagation import VIEW_OPS 21 | 22 | 23 | def _fix_view_node(input_shape, output_shape): 24 | if -1 in output_shape: 25 | numel = functools.reduce(operator.mul, input_shape) 26 | dim_size = -1 * numel // functools.reduce(operator.mul, output_shape) 27 | output_shape[output_shape.index(-1)] = dim_size 28 | 29 | intermediate_shape = [] 30 | i = j = 0 31 | while i < len(input_shape) and j < len(output_shape): 32 | accu_i = input_shape[i] 33 | accu_j = output_shape[j] 34 | while accu_i != accu_j: 35 | if accu_i < accu_j: 36 | i += 1 37 | accu_i *= input_shape[i] 38 | else: 39 | j += 1 40 | accu_j *= output_shape[j] 41 | intermediate_shape.append(accu_i) 42 | i += 1 43 | j += 1 44 | while i < len(input_shape): 45 | intermediate_shape.append(input_shape[i]) 46 | i += 1 47 | while j < len(output_shape): 48 | intermediate_shape.append(output_shape[j]) 49 | j += 1 50 | assert i == len(input_shape) and j == len(output_shape) 51 | return intermediate_shape 52 | 53 | 54 | def decouple_view(fx_module: fx.GraphModule): 55 | for node in fx_module.graph.nodes: 56 | if node.op == 'call_function' and node.target in VIEW_OPS: 57 | target_op = node.target 58 | input_shape = list(node.args[0].meta['val'].shape) 59 | output_shape = list(node.args[1]) 60 | intermediate_shape = _fix_view_node(input_shape, output_shape) 61 | if input_shape != intermediate_shape and output_shape != intermediate_shape: 62 | node.args = (node.args[0], intermediate_shape) 63 | fake_mode = node.meta['val'].fake_mode 64 | node.meta['val'] = fake_mode.from_tensor( 65 | torch.zeros(intermediate_shape, 66 | dtype=node.meta['val'].dtype)) 67 | with fx_module.graph.inserting_after(node): 68 | intermediate_view = fx_module.graph.call_function( 69 | target_op, args=(node, output_shape)) 70 | intermediate_view.meta['val'] = fake_mode.from_tensor( 71 | torch.zeros(output_shape, 72 | dtype=node.meta['val'].dtype)) 73 | node.replace_all_uses_with(intermediate_view, delete_user_cb=lambda x: x != intermediate_view) 74 | 75 | fx_module.recompile() 76 | return fx_module 77 | 78 | # def fix_sharded_view(fx_module: fx.GraphModule): 79 | # for node in fx_module.graph.nodes: 80 | # if (node.op == 'call_function' and node.target in VIEW_OPS and 81 | # node.args[0].op == 'call_function' and node.args[0].target == scatter_wrapper): 82 | # input_val = node.args[0].meta['val'] 83 | # output_shape = list(node.args[1]) 84 | # try: 85 | # _ = input_val.view(output_shape) 86 | # except Exception: 87 | 88 | # # global [768 (shard), 768] => view as [3, 256, 768] 89 | # # 0 [384, 768] => view as [1, 256, 768] 90 | # # 1 [384, 768] => view as [2, 256, 768] -------------------------------------------------------------------------------- /easydist/torch/passes/pp_passes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | from typing import List, Set 16 | 17 | import torch 18 | import torch.fx as fx 19 | 20 | 21 | def get_partition(fx_module: fx.GraphModule) -> List[Set[str]]: 22 | partitions = [] 23 | cur_par = set() 24 | for node in fx_module.graph.nodes: 25 | cur_par.add(node.name) 26 | if node.op == 'call_function' and node.target in [ 27 | torch.ops.easydist.fw_bw_split.default, 28 | torch.ops.easydist.step_split.default 29 | ]: 30 | partitions.append(cur_par) 31 | cur_par = set() 32 | partitions.append(cur_par) 33 | return partitions 34 | -------------------------------------------------------------------------------- /easydist/torch/passes/process_tag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import torch 16 | import torch._custom_ops 17 | 18 | from easydist.torch.device_mesh import get_device_mesh 19 | from easydist.torch.passes.sharding import all_reduce_start, all_reduce_end 20 | 21 | 22 | @torch._custom_ops.custom_op("easydist::tag") 23 | def tag(input: torch.Tensor, tag: str) -> torch.Tensor: 24 | ... 
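# ---- Illustrative usage sketch (added for clarity; not part of the original source, and the module below is hypothetical) ----
# A tensor-parallel layer can mark a partial result that still needs an allreduce by routing it
# through the `easydist::tag` custom op registered above. The `process_tag` pass later in this
# file rewrites every node tagged with "allreduce[sum]" into an all_reduce_start / all_reduce_end
# pair over the 'tp' device mesh.
#
#     class RowParallelLinear(torch.nn.Module):              # hypothetical example module
#         def forward(self, x):
#             partial = x @ self.weight                       # each rank holds a shard of the product
#             return torch.ops.easydist.tag(partial, "allreduce[sum]")
#
# The tag string must match the "allreduce[sum]" literal checked in process_tag() below.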
25 | 26 | @torch._custom_ops.impl_abstract("easydist::tag") 27 | def tag_impl_abstract(input: torch.Tensor, tag: str) -> torch.Tensor: 28 | return torch.empty_like(input) 29 | 30 | 31 | @torch._custom_ops.impl("easydist::tag") 32 | def tag_impl(input: torch.Tensor, tag: str) -> torch.Tensor: 33 | return input 34 | 35 | 36 | def process_tag(traced_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: 37 | 38 | tp_mesh = get_device_mesh('tp') 39 | 40 | for node in traced_graph.graph.nodes: 41 | if node.target == torch.ops.easydist.tag.default: 42 | if node.args[1] == "allreduce[sum]": 43 | reduceOp = "sum" 44 | ranks = tp_mesh.mesh.flatten().tolist() 45 | with traced_graph.graph.inserting_before(node): 46 | all_reduce_start_node = traced_graph.graph.call_function(all_reduce_start, 47 | args=(node.args[0], 48 | reduceOp, 49 | ranks)) 50 | all_reduce_end_node = traced_graph.graph.call_function( 51 | all_reduce_end, args=(all_reduce_start_node, reduceOp, ranks)) 52 | 53 | node.replace_all_uses_with(all_reduce_end_node) 54 | 55 | traced_graph.graph.eliminate_dead_code() 56 | traced_graph.recompile() 57 | 58 | return traced_graph -------------------------------------------------------------------------------- /easydist/torch/profiler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | from .stream_tracer import ( 16 | StreamTracer, 17 | ) 18 | 19 | __all__ = [ 20 | "StreamTracer", 21 | ] 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /easydist/torch/profiler/csrc/cupti_callback_api.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024, Alibaba Group; 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | Unless required by applicable law or agreed to in writing, software 9 | distributed under the License is distributed on an "AS IS" BASIS, 10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and 12 | limitations under the License. 
13 | ==============================================================================*/ 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #pragma once 22 | 23 | //#define TRACER_VERBOSE 24 | 25 | extern std::string g_cur_op_name; 26 | extern bool g_stream_tracing_active; 27 | extern bool g_in_op_core_context; 28 | 29 | typedef struct StreamTraceData_st { 30 | void addOpStream(const std::string& op_name, uint32_t stream_id, 31 | bool is_core) { 32 | if (is_core) { 33 | op_streams_[op_name].emplace_back(stream_id); 34 | } else { 35 | op_extra_streams_[op_name].emplace_back(stream_id); 36 | } 37 | } 38 | 39 | #ifdef TRACER_VERBOSE 40 | void addOpKernel(const std::string& op_name, const std::string& kernel_name, 41 | bool is_core) { 42 | if (is_core) { 43 | op_kernels_[op_name].emplace_back(kernel_name); 44 | } else { 45 | op_extra_kernels_[op_name].emplace_back(kernel_name); 46 | } 47 | } 48 | 49 | std::string toString() const { 50 | std::string ret; 51 | for (auto& item : op_streams_) { 52 | ret += "op: " + item.first + ", kernel num: " + \ 53 | std::to_string(item.second.size()) + "\n"; 54 | auto& kernels = op_kernels_.at(item.first); 55 | assert(kernels.size() == item.second.size()); 56 | for (int i=0; i> op_streams_; 87 | std::map> op_extra_streams_; 88 | 89 | #ifdef TRACER_VERBOSE 90 | std::map> op_kernels_; 91 | std::map> op_extra_kernels_; 92 | #endif 93 | } StreamTraceData; 94 | 95 | class StreamTracerCallbackApi { 96 | public: 97 | StreamTracerCallbackApi() = default; 98 | StreamTracerCallbackApi(const StreamTracerCallbackApi&) = delete; 99 | StreamTracerCallbackApi& operator=(const StreamTracerCallbackApi&) = delete; 100 | ~StreamTracerCallbackApi(); 101 | 102 | static std::shared_ptr singleton(); 103 | 104 | void initCallbackApi(); 105 | void start(); 106 | void stop(); 107 | StreamTraceData getTraceData() { 108 | return trace_data_; 109 | } 110 | private: 111 | CUpti_SubscriberHandle subscriber_ {0}; 112 | StreamTraceData trace_data_; 113 | }; 114 | 115 | -------------------------------------------------------------------------------- /easydist/torch/profiler/csrc/effective_cuda_allocator.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024, Alibaba Group; 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | Unless required by applicable law or agreed to in writing, software 9 | distributed under the License is distributed on an "AS IS" BASIS, 10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and 12 | limitations under the License. 
13 | ==============================================================================*/ 14 | 15 | #pragma once 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | 24 | 25 | namespace torch::cuda::CUDAPluggableAllocator { 26 | 27 | std::shared_ptr 28 | createCustomEffectiveAllocator( 29 | std::function alloc_fn, 30 | std::function free_fn); 31 | 32 | struct EffectiveCUDAAllocator : public CUDAPluggableAllocator { 33 | EffectiveCUDAAllocator( 34 | std::function alloc_fn, 35 | std::function free_fn); 36 | 37 | EffectiveCUDAAllocator(EffectiveCUDAAllocator& other); 38 | 39 | void* malloc(size_t size, int device, cudaStream_t stream); 40 | 41 | #if TORCH_VERSION_MAJOR>=2 && TORCH_VERSION_MINOR>=2 42 | c10::DataPtr allocate(size_t size) override; 43 | #else 44 | c10::DataPtr allocate(size_t size) const override; 45 | #endif 46 | virtual void raw_delete(void* ptr) override; 47 | 48 | bool enable_runtime_trace_ = false; // debug only 49 | std::unordered_map addr_op_map_; // debug only 50 | }; 51 | 52 | } // namespace torch::cuda::CUDAPluggableAllocator 53 | 54 | -------------------------------------------------------------------------------- /easydist/torch/profiler/csrc/profiling_allocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "c10/cuda/CUDACachingAllocator.h" 6 | 7 | enum AllocatorMode { 8 | PROFILE, 9 | RUNTIME 10 | }; 11 | // CType Func 12 | extern "C"{ 13 | void init_fn(int device_count); 14 | 15 | void* meta_malloc(ssize_t size, int device, cudaStream_t stream); 16 | 17 | void meta_free(void* ptr, ssize_t size, int device, cudaStream_t stream); 18 | } 19 | // Pybind Func 20 | 21 | void set_start_recording(bool flag); 22 | 23 | void set_allocator_mode(AllocatorMode mode); 24 | 25 | void set_customized_flag(bool flag); 26 | 27 | void set_mem_size(long memory_size); 28 | 29 | void set_temp_mem_size(long temp_memory_size); 30 | 31 | void set_raw_mem_allocs(std::vector> py_mem_allocs); 32 | 33 | std::vector> get_allocator_profiling_info(); 34 | 35 | void save_back_allocator(); -------------------------------------------------------------------------------- /easydist/torch/profiler/csrc/stream_tracer.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024, Alibaba Group; 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | Unless required by applicable law or agreed to in writing, software 9 | distributed under the License is distributed on an "AS IS" BASIS, 10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and 12 | limitations under the License. 
13 | ==============================================================================*/ 14 | 15 | #include 16 | 17 | #include "stream_tracer.h" 18 | #include "cupti_callback_api.h" 19 | 20 | std::string g_cur_op_name("N/A"); 21 | bool g_stream_tracing_active = false; 22 | bool g_in_op_core_context = false; 23 | 24 | void prepareStreamTracer() { 25 | auto cbapi = StreamTracerCallbackApi::singleton(); 26 | cbapi->initCallbackApi(); 27 | } 28 | 29 | void enableStreamTracer() { 30 | auto cbapi = StreamTracerCallbackApi::singleton(); 31 | cbapi->start(); 32 | } 33 | 34 | void disableStreamTracer() { 35 | auto cbapi = StreamTracerCallbackApi::singleton(); 36 | cbapi->stop(); 37 | } 38 | 39 | void activateStreamTracer() { 40 | g_stream_tracing_active = true; 41 | } 42 | 43 | void inactivateStreamTracer() { 44 | g_stream_tracing_active = false; 45 | } 46 | 47 | void enterOpCore() { 48 | g_in_op_core_context = true; 49 | } 50 | 51 | void leaveOpCore() { 52 | g_in_op_core_context = false; 53 | } 54 | 55 | void setCurOpName(const char* cur_op_name) { 56 | assert(cur_op_name); 57 | g_cur_op_name = std::string(cur_op_name); 58 | } 59 | 60 | StreamTraceData 61 | getStreamTraceData() { 62 | auto cbapi = StreamTracerCallbackApi::singleton(); 63 | return cbapi->getTraceData(); 64 | } 65 | 66 | -------------------------------------------------------------------------------- /easydist/torch/profiler/csrc/stream_tracer.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024, Alibaba Group; 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | Unless required by applicable law or agreed to in writing, software 9 | distributed under the License is distributed on an "AS IS" BASIS, 10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and 12 | limitations under the License. 13 | ==============================================================================*/ 14 | 15 | #pragma once 16 | 17 | #include "cupti_callback_api.h" 18 | 19 | 20 | void prepareStreamTracer(); 21 | void enableStreamTracer(); 22 | void disableStreamTracer(); 23 | void activateStreamTracer(); 24 | void inactivateStreamTracer(); 25 | void enterOpCore(); 26 | void leaveOpCore(); 27 | void setCurOpName(const char* cur_op_name); 28 | StreamTraceData getStreamTraceData(); 29 | 30 | 31 | -------------------------------------------------------------------------------- /easydist/torch/profiler/stream_tracer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import torch 16 | from easydist.torch.meta_allocator import profiling_allocator 17 | 18 | __all__ = [ 19 | "StreamTracer", 20 | ] 21 | 22 | 23 | class StreamTracer(): 24 | def __init__( 25 | self, 26 | enabled=True, 27 | ): 28 | self.enabled: bool = enabled 29 | if not self.enabled: 30 | return 31 | 32 | self.entered = False 33 | 34 | def __enter__(self): 35 | if not self.enabled: 36 | return 37 | if self.entered: 38 | raise RuntimeError("Stream tracer's context manager is not reentrant") 39 | self.prepare_trace() 40 | self.start_trace() 41 | return self 42 | 43 | def prepare_trace(self): 44 | self.entered = True 45 | profiling_allocator._prepare_stream_tracer() 46 | 47 | def start_trace(self): 48 | self.entered = True 49 | profiling_allocator._enable_stream_tracer() 50 | 51 | def __exit__(self, exc_type, exc_val, exc_tb): 52 | if not self.enabled: 53 | return 54 | torch.cuda.synchronize() 55 | profiling_allocator._disable_stream_tracer() 56 | return False 57 | 58 | def get_stream_trace_data(self): 59 | return profiling_allocator._get_stream_trace_data() 60 | 61 | 62 | -------------------------------------------------------------------------------- /easydist/torch/schedule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/easydist/4ced4cd7ca7a39722670beae16332426d5430238/easydist/torch/schedule/__init__.py -------------------------------------------------------------------------------- /easydist/torch/schedule/graph_mem_plan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | from typing import List, Tuple 16 | class GraphMemPlan: 17 | def __init__( 18 | self, 19 | mem_size: int, 20 | temp_mem_size: int 21 | ): 22 | self.mem_size = mem_size 23 | self.temp_mem_size = temp_mem_size 24 | self.raw_mem_allocs: List[Tuple[int, int, bool, str]] = [] # element: (addr, size, is_temp_mem, node_name) 25 | 26 | def append_addr_size(self, addr: int, size: int, is_temp_mem: bool, node_name: str): 27 | self.raw_mem_allocs.append((addr, size, is_temp_mem, node_name)) 28 | 29 | def get_addr(self, idx: int): 30 | assert idx < len(self.raw_mem_allocs) 31 | return self.raw_mem_allocs[idx][0] 32 | 33 | def __str__(self) -> str: 34 | mem_plan_str = "" 35 | for raw_mem_alloc in self.raw_mem_allocs: 36 | if raw_mem_alloc[2]: 37 | is_temp = "True" 38 | else: 39 | is_temp = "False" 40 | mem_plan_str += f"addr: {raw_mem_alloc[0]}, size: {raw_mem_alloc[1]}, is_temp: {is_temp}, node: {raw_mem_alloc[3]}\n" 41 | 42 | return mem_plan_str 43 | 44 | 45 | -------------------------------------------------------------------------------- /easydist/torch/schedule/memory_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import numpy as np 16 | import logging 17 | import torch 18 | from collections import defaultdict 19 | from typing import Callable, Set 20 | 21 | import easydist.config as mdconfig 22 | from easydist.torch.schedule.schedule_result import ScheduleResult 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | _funcs_served_by_back_allocator: Set[Callable] = { 28 | torch.ops.aten.sum.dim_IntList 29 | } 30 | 31 | class MemoryScheduler: 32 | def __init__( 33 | self, 34 | fx_module, # torch.fx.GraphModule 35 | graph_mem_info, # GraphMemInfo 36 | op_streams, 37 | align_scale 38 | ): 39 | self.fx_module = fx_module 40 | self.graph_mem_info = graph_mem_info 41 | self.nodes_by_back_allocator = set() 42 | 43 | self.nodes_to_schedule = [] 44 | self.args = [] 45 | self.outputs = [] 46 | self.op_streams = op_streams 47 | self.schedule_result = None 48 | 49 | for node in fx_module.graph.nodes: 50 | if node.op == 'placeholder' or node.op == 'get_attr': 51 | self.args.append(node) 52 | elif node.op == 'output': 53 | self.outputs.append(node) 54 | else: 55 | self.nodes_to_schedule.append(node) 56 | assert node.name in op_streams, f"node {node.name} misses stream id" 57 | 58 | if node.target in _funcs_served_by_back_allocator: 59 | self.nodes_by_back_allocator.add(node) 60 | 61 | self.node_set = set(self.nodes_to_schedule) 62 | 63 | tensor_sizes = set() 64 | for node in self.nodes_to_schedule: 65 | out_vars = self.graph_mem_info.get_out_vars(node) 66 | for out_var in out_vars: 67 | tensor_sizes.add(out_var.size()) 68 | 69 | logger.info(f"memory align value: {align_scale}") 70 | 71 | align_sizes = [(ten_size+align_scale-1)//align_scale for ten_size in tensor_sizes] 72 | 73 |
self.align_scale = align_scale 74 | self.gcd = np.gcd.reduce(list(align_sizes)) 75 | self.gcd *= align_scale 76 | logger.info(f"gcd of memory sizes: {self.gcd}") 77 | 78 | def gen_mem_addresses(self): 79 | if not mdconfig.enable_reschedule: 80 | self.schedule_result = ScheduleResult() 81 | for node in self.fx_module.graph.nodes: 82 | if ( 83 | node.op != 'output' and 84 | node.op != 'placeholder' and 85 | node.op != 'get_attr' 86 | ): 87 | phy_stream_id = self.op_streams[node.name] 88 | self.schedule_result.schedule_node_at_end(node, phy_stream_id) 89 | 90 | #print(f"node schedule result:\n{str(self.schedule_result)}") 91 | required_memory, temp_memory, schedules, ordered_schedules, mem_alloc_info, inter_op_mems, back_alloced_nodes = \ 92 | self.create_min_mem_plan() 93 | 94 | return (required_memory, temp_memory, schedules, ordered_schedules, mem_alloc_info, inter_op_mems, back_alloced_nodes) 95 | 96 | -------------------------------------------------------------------------------- /easydist/torch/schedule/schedule_result.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import torch 16 | 17 | class NodeSchedule: 18 | def __init__( 19 | self, 20 | log_stream_id, 21 | local_idx, 22 | ): 23 | self.log_stream_id = log_stream_id 24 | self.local_idx = local_idx 25 | 26 | class ScheduleResult: 27 | def __init__( 28 | self 29 | ): 30 | self.node_sequences = [] # node sequence list, i.e. 
list of list 31 | self.node_idx_maps = [] # per stream maps 32 | self.phy_log_stream_id_map = {} 33 | self.phy_stream_ids = [] 34 | self.num_streams = 0 35 | self.global_node_schedules = {} # node -> node schedule(NodeSchedule) 36 | 37 | def schedule_node_at_end(self, node: torch.fx.Node, phy_stream_id: int): 38 | if phy_stream_id not in self.phy_log_stream_id_map: 39 | log_stream_id = self.num_streams 40 | self.num_streams += 1 41 | self.phy_stream_ids.append(phy_stream_id) 42 | self.phy_log_stream_id_map[phy_stream_id] = log_stream_id 43 | else: 44 | log_stream_id = self.phy_log_stream_id_map[phy_stream_id] 45 | while log_stream_id >= len(self.node_sequences): 46 | self.node_sequences.append([]) 47 | self.node_idx_maps.append({}) 48 | 49 | sequence = self.node_sequences[log_stream_id] 50 | node_idx_map = self.node_idx_maps[log_stream_id] 51 | local_idx = len(sequence) 52 | node_schedule = NodeSchedule(log_stream_id, local_idx) 53 | sequence.append(node) 54 | node_idx_map[node] = local_idx 55 | assert node not in self.global_node_schedules 56 | self.global_node_schedules[node] = node_schedule 57 | 58 | def get_node_idx_map(self, stream_id: int): 59 | return self.node_idx_maps[stream_id] 60 | 61 | def get_schedule(self, node: torch.fx.Node): 62 | assert node in self.global_node_schedules, f"node {node} is missed in schedule map" 63 | return self.global_node_schedules[node] 64 | 65 | def get_log_stream_id(self, node: torch.fx.Node): 66 | assert node in self.global_node_schedules 67 | return self.global_node_schedules[node].log_stream_id 68 | 69 | def get_sequence(self, log_stream_id: int): 70 | return self.node_sequences[log_stream_id] 71 | 72 | def get_node(self, log_stream_id: int, local_idx: int): 73 | return self.node_sequences[log_stream_id][local_idx] 74 | 75 | def __str__(self) -> str: 76 | ret = "" 77 | for log_id, sequence in enumerate(self.node_sequences): 78 | ret += f"logic stream id: {log_id}, real stream id: {self.phy_stream_ids[log_id]}\n" 79 | for idx, node in enumerate(sequence): 80 | ret += f" {idx}: {node.name}\n" 81 | return ret 82 | 83 | -------------------------------------------------------------------------------- /easydist/torch/split_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | from typing import List, Union, Tuple, Any, Dict, Callable 16 | 17 | import torch 18 | 19 | 20 | _before_split: Dict[type, Callable[[Any], List[torch.Tensor]]] = {} 21 | _after_split: Dict[type, Callable[[List[torch.Tensor]], Any]] = {} 22 | 23 | def before_split_register(*classes): 24 | 25 | def _register(func: Callable): 26 | for cls in classes: 27 | assert cls not in _before_split, f"split function for {cls} already registered" 28 | _before_split[cls] = func 29 | return func 30 | 31 | return _register 32 | 33 | 34 | def after_split_register(*classes): 35 | 36 | def _register(func: Callable): 37 | for cls in classes: 38 | assert cls not in _after_split, f"split function for {cls} already registered" 39 | _after_split[cls] = func 40 | return func 41 | 42 | return _register 43 | 44 | 45 | @before_split_register(torch.Tensor) 46 | def tensor_before_split(ctx: dict, input: torch.Tensor) -> List[torch.Tensor]: 47 | return [input] 48 | 49 | 50 | @after_split_register(torch.Tensor) 51 | def tensor_after_split(ctx: dict, output: Tuple[torch.Tensor]) -> torch.Tensor: 52 | return output[0] 53 | 54 | 55 | @before_split_register(list) 56 | def list_before_split(ctx: dict, input: List[Union[torch.Tensor, Any]]) -> List[torch.Tensor]: 57 | ctx['is_tensor'] = [] 58 | ctx['non_tensor_vals'] = [] 59 | tup = [] 60 | for x in input: 61 | ctx['is_tensor'].append(isinstance(x, torch.Tensor)) 62 | if ctx['is_tensor'][-1]: 63 | tup.append(x) 64 | else: 65 | ctx['non_tensor_vals'].append(x) 66 | 67 | return tup 68 | 69 | 70 | @after_split_register(list) 71 | def list_after_split(ctx: dict, output: Tuple[torch.Tensor]) -> List[Union[torch.Tensor, Any]]: 72 | ret = [] 73 | output = list(output) 74 | for is_tensor in ctx['is_tensor']: 75 | if is_tensor: 76 | ret.append(output.pop(0)) 77 | else: 78 | ret.append(ctx['non_tensor_vals'].pop(0)) 79 | return ret 80 | 81 | 82 | @before_split_register(tuple) 83 | def tuple_before_split(ctx: dict, input: Tuple[Union[torch.Tensor, Any]]) -> List[torch.Tensor]: 84 | return list_before_split(ctx, list(input)) 85 | 86 | 87 | @after_split_register(tuple) 88 | def tuple_after_split(ctx: dict, output: Tuple[torch.Tensor]) -> Tuple[Union[torch.Tensor, Any]]: 89 | return tuple(list_after_split(ctx, output)) 90 | 91 | 92 | -------------------------------------------------------------------------------- /easydist/torch/symphonia/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/easydist/4ced4cd7ca7a39722670beae16332426d5430238/easydist/torch/symphonia/__init__.py -------------------------------------------------------------------------------- /easydist/torch/symphonia/torch_actor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | 4 | import ray 5 | 6 | 7 | class DistributedTorchRayActor: 8 | 9 | def __init__(self, world_size, rank, local_rank, master_addr, master_port): 10 | self._world_size = world_size 11 | self._rank = rank 12 | self._local_rank = local_rank 13 | self._master_addr = master_addr if master_addr else self._get_current_node_ip() 14 | self._master_port = master_port if master_port else self._get_free_port() 15 | os.environ["MASTER_ADDR"] = self._master_addr 16 | os.environ["MASTER_PORT"] = str(self._master_port) 17 | os.environ["WORLD_SIZE"] = str(self._world_size) 18 | os.environ["RANK"] = str(self._rank) 19 | # NOTE: 
Ray will automatically set the CUDA_VISIBLE_DEVICES 20 | # environment variable for each actor, so always set device to 0 21 | os.environ["LOCAL_RANK"] = "0" 22 | 23 | @staticmethod 24 | def _get_current_node_ip(): 25 | address = ray._private.services.get_node_ip_address() 26 | # strip ipv6 address 27 | return address.strip("[]") 28 | 29 | @staticmethod 30 | def _get_free_port(): 31 | with socket.socket() as sock: 32 | sock.bind(("", 0)) 33 | return sock.getsockname()[1] 34 | 35 | def get_master_addr_port(self): 36 | return self._master_addr, self._master_port 37 | 38 | def entrypoint(self): 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /easydist/torch/tensorfield/README.md: -------------------------------------------------------------------------------- 1 | # TensorField 2 | 3 | TensorField is a module used for unified memory management across processes (libraries). In multi-process scenarios, such as with PyTorch, each process has its own private memory pool. Through the CacheAllocator mechanism, memory is reserved as much as possible without being released, and this reserved memory cannot be used by other processes, resulting in significantly reduced memory usage efficiency. The `tfield-server` is a unified memory management process that helps multiple processes manage memory through a shared memory pool. -------------------------------------------------------------------------------- /easydist/torch/tensorfield/__init__.py: -------------------------------------------------------------------------------- 1 | from .interface import (init_tensorfield_allocator, finalize_tensorfield_allocator, TFieldClient, 2 | init_on_tfeild, load_from_tfeild, is_enabled) 3 | 4 | __all__ = [ 5 | "init_tensorfield_allocator", "finalize_tensorfield_allocator", "TFieldClient", 6 | "init_on_tfeild", "load_from_tfeild", "is_enabled" 7 | ] 8 | -------------------------------------------------------------------------------- /easydist/torch/tensorfield/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cupy 3 | 4 | torch_cupy_dtype_mapping = { 5 | torch.float32: cupy.float32, 6 | torch.float64: cupy.float64, 7 | torch.int32: cupy.int32, 8 | torch.int64: cupy.int64, 9 | torch.uint8: cupy.uint8, 10 | torch.int8: cupy.int8, 11 | torch.int16: cupy.int16, 12 | torch.float16: cupy.float16, 13 | } 14 | 15 | 16 | def count_param_or_buffer(model: torch.nn.Module) -> int: 17 | # Count the total size of the model parameters and buffers 18 | alloc_size = 0 19 | 20 | def count_fn(param_or_buffer): 21 | nonlocal alloc_size 22 | alloc_size += param_or_buffer.numel() * param_or_buffer.element_size() 23 | return param_or_buffer 24 | 25 | model = model._apply(count_fn) 26 | 27 | return alloc_size 28 | 29 | 30 | def get_tensor_from_ptr(param_or_buffer: torch.Tensor, base_ptr: int, 31 | copy_weight: bool) -> torch.Tensor: 32 | param_or_buffer_size = param_or_buffer.numel() * param_or_buffer.element_size() 33 | 34 | cupy_pointer = cupy.cuda.MemoryPointer(cupy.cuda.UnownedMemory(base_ptr, 35 | param_or_buffer_size, 36 | owner=None), 37 | offset=0) 38 | 39 | if param_or_buffer.dtype not in torch_cupy_dtype_mapping: 40 | raise ValueError( 41 | f"Unsupported dtype: {param_or_buffer.dtype}. 
Supported dtypes: {torch_cupy_dtype_mapping.keys()}" 42 | ) 43 | cupy_dtype = torch_cupy_dtype_mapping[param_or_buffer.dtype] 44 | 45 | cupy_tensor = cupy.ndarray(shape=param_or_buffer.size(), dtype=cupy_dtype, memptr=cupy_pointer) 46 | param_or_buffer_cuda = torch.as_tensor(cupy_tensor, dtype=param_or_buffer.dtype, device='cuda') 47 | 48 | if copy_weight: 49 | param_or_buffer_cuda.copy_(param_or_buffer.data, non_blocking=True) 50 | 51 | return param_or_buffer_cuda 52 | -------------------------------------------------------------------------------- /easydist/torch/tensorfield/mem_pool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | 4 | import cupy 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class IPCMemoryPool: 10 | _instance = None 11 | _lock = threading.Lock() 12 | 13 | def __new__(cls, *args, **kwargs): 14 | if cls._instance is None: 15 | with cls._lock: 16 | if cls._instance is None: 17 | cls._instance = super(IPCMemoryPool, cls).__new__(cls) 18 | return cls._instance 19 | 20 | def __init__(self, device_id): 21 | self.device_id = device_id 22 | 23 | self.allocaed_memory = {} 24 | self.alloc_handle = {} 25 | 26 | self.mem_pool = cupy.cuda.MemoryPool() 27 | 28 | def malloc(self, size_bytes, stream): 29 | 30 | cupy.cuda.runtime.setDevice(self.device_id) 31 | 32 | with cupy.cuda.ExternalStream(stream): 33 | mem_pointer = self.mem_pool.malloc(size_bytes) 34 | handle = cupy.cuda.runtime.ipcGetMemHandle(mem_pointer.ptr) 35 | 36 | # (NOTE) handle maybe same for pointer in the range of memory block from one malloc call 37 | # so we store the base pointer of the memory block in self.alloc_handle 38 | if handle not in self.alloc_handle: 39 | self.alloc_handle[handle] = mem_pointer.ptr 40 | 41 | if mem_pointer.ptr in self.allocaed_memory: 42 | logger.warn(f"Memory already allocated at {mem_pointer.ptr}") 43 | self.allocaed_memory[mem_pointer.ptr] = mem_pointer 44 | 45 | offset = mem_pointer.ptr - self.alloc_handle[handle] 46 | 47 | if offset < 0: 48 | raise ValueError(f"Offset is negative: {offset}") 49 | 50 | return mem_pointer.ptr, handle, offset 51 | 52 | def free(self, ptr, stream): 53 | if ptr in self.allocaed_memory: 54 | del self.allocaed_memory[ptr] 55 | else: 56 | logger.warn(f"Memory not found with handle {ptr}") 57 | 58 | def usage_statistics(self): 59 | statistics = { 60 | "total_bytes": int(self.mem_pool.total_bytes()), 61 | "used_bytes": int(self.mem_pool.used_bytes()), 62 | } 63 | return statistics 64 | -------------------------------------------------------------------------------- /easydist/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import functools 16 | 17 | 18 | def rsetattr(obj, attr, val): 19 | pre, _, post = attr.rpartition('.') 20 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 21 | 22 | 23 | def rgetattr(obj, attr, *args): 24 | 25 | def _getattr(obj, attr): 26 | return getattr(obj, attr, *args) 27 | 28 | return functools.reduce(_getattr, [obj] + attr.split('.')) 29 | -------------------------------------------------------------------------------- /easydist/utils/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .spawn import * 2 | from .mock import * 3 | 4 | ALL_PLATFORM = ["torch", "jax", "tvm"] 5 | -------------------------------------------------------------------------------- /easydist/utils/testing/mock.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | 16 | class MockDeviceMesh: 17 | 18 | def __init__(self): 19 | pass 20 | 21 | 22 | class TorchMockDeviceMesh(MockDeviceMesh): 23 | 24 | def __init__(self, *arg, debug_only=False): 25 | super().__init__() 26 | self.shape = tuple(arg) 27 | self.debug_only = debug_only 28 | 29 | def size(self, i): 30 | return self.shape[i] 31 | 32 | def __str__(self) -> str: 33 | return f"TorchMockDeviceMesh(shape={self.shape})" 34 | 35 | def __repr__(self) -> str: 36 | return self.__str__() 37 | 38 | 39 | class JaxDeviceID: 40 | 41 | def __init__(self, *arg): 42 | self.shape = tuple(arg) 43 | 44 | 45 | class JaxMockDeviceMesh(MockDeviceMesh): 46 | 47 | def __init__(self, *arg): 48 | super().__init__() 49 | self.device_ids = JaxDeviceID(*arg) 50 | 51 | 52 | def assert_partial_func_equal(func1, func2): 53 | assert func1.args == func2.args 54 | assert func1.keywords == func2.keywords 55 | assert func1.func.__name__ == func2.func.__name__ 56 | -------------------------------------------------------------------------------- /easydist/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import time 16 | 17 | import numpy as np 18 | 19 | from easydist.platform import get_backend 20 | import easydist.config as mdconfig 21 | 22 | 23 | class EDTimer: 24 | 25 | def __init__(self, 26 | func, 27 | trials=3, 28 | warmup_trials=3, 29 | times_per_trials=1, 30 | in_ms=True, 31 | device=None) -> None: 32 | self.func = func 33 | self.warmup_trials = warmup_trials 34 | self.trials = trials 35 | self.times_per_trials = times_per_trials 36 | self.in_ms = in_ms 37 | 38 | self.device = device 39 | if self.device == None: 40 | self.device = mdconfig.easydist_device 41 | 42 | self.backend = get_backend() 43 | 44 | def time(self, return_all=False): 45 | all_elapsed_time = None 46 | if self.backend == "jax": 47 | all_elapsed_time = self.time_jax() 48 | elif self.backend == "torch": 49 | if self.device == "cuda": 50 | all_elapsed_time = self.time_torch_cuda() 51 | elif self.device == "cpu": 52 | all_elapsed_time = self.time_cpu() 53 | if all_elapsed_time is not None: 54 | if return_all is True: 55 | return all_elapsed_time 56 | return np.mean(all_elapsed_time) 57 | return None 58 | 59 | def time_cpu(self): 60 | for _ in range(self.warmup_trials): 61 | self.func() 62 | 63 | elapsed_time = [] 64 | for _ in range(self.trials): 65 | start_t = time.perf_counter() 66 | for _ in range(self.times_per_trials): 67 | self.func() 68 | elapsed_time.append(time.perf_counter() - start_t) 69 | 70 | elapsed_time = np.array(elapsed_time) / self.times_per_trials 71 | 72 | # time elapsed in **milliseconds** 73 | if self.in_ms: 74 | return elapsed_time * 1000 75 | return elapsed_time 76 | 77 | def time_torch_cuda(self): 78 | import torch 79 | 80 | start_evt = [] 81 | end_evt = [] 82 | for _ in range(0, self.trials): 83 | start_evt.append(torch.cuda.Event(enable_timing=True)) 84 | end_evt.append(torch.cuda.Event(enable_timing=True)) 85 | 86 | for trial_idx in range(0, self.trials + self.warmup_trials): 87 | evt_idx = trial_idx - self.warmup_trials 88 | 89 | if evt_idx >= 0: 90 | start_evt[evt_idx].record() 91 | 92 | for _ in range(self.times_per_trials): 93 | self.func() 94 | 95 | if evt_idx >= 0: 96 | end_evt[evt_idx].record() 97 | 98 | torch.cuda.synchronize() 99 | elapsed_time = [] 100 | for evt_idx in range(0, self.trials): 101 | # time elapsed in **milliseconds** 102 | elapsed_time.append(start_evt[evt_idx].elapsed_time(end_evt[evt_idx])) 103 | elapsed_time = np.array(elapsed_time) / self.times_per_trials 104 | 105 | if self.in_ms: 106 | return elapsed_time 107 | return elapsed_time / 1000 108 | 109 | def time_jax(self): 110 | import jax 111 | for _ in range(self.warmup_trials): 112 | self.func() 113 | (jax.device_put(0.) + 0).block_until_ready() 114 | 115 | elapsed_time = [] 116 | for _ in range(self.trials): 117 | 118 | start_t = time.perf_counter() 119 | 120 | for _ in range(self.times_per_trials): 121 | self.func() 122 | 123 | (jax.device_put(0.) 
+ 0).block_until_ready() 124 | 125 | elapsed_time.append(time.perf_counter() - start_t) 126 | elapsed_time = np.array(elapsed_time) / self.times_per_trials 127 | 128 | # time elapsed in **milliseconds** 129 | if self.in_ms: 130 | return elapsed_time * 1000 131 | return elapsed_time 132 | -------------------------------------------------------------------------------- /easydist/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import re 16 | 17 | VERSION = "0.1.0" 18 | 19 | 20 | def is_release_version(): 21 | return bool(re.match(r"^\d+\.\d+\.\d+$", VERSION)) 22 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## PyTorch Examples 4 | 5 | Please use `torchrun` to launch the PyTorch examples. Take `simple_function.py` for example: 6 | 7 | ```shell 8 | # for single-node environment (2 GPUs for example) 9 | torchrun --nproc_per_node 2 --master_port 9543 ./torch/simple_function.py 10 | 11 | # for multi-node environment (2 nodes for example, 2 GPUs per node): 12 | # Machine1: 13 | torchrun --nnodes 2 --node_rank 0 \ 14 | --master_addr [Machine1 IP] --master_port 9543 \ 15 | ./torch/simple_function.py 16 | # Machine2: 17 | torchrun --nnodes 2 --node_rank 1 \ 18 | --master_addr [Machine1 IP] --master_port 9543 \ 19 | ./torch/simple_function.py 20 | ``` 21 | 22 | For more details of `torchrun`, please refer to [Torch Distributed Elastic](https://pytorch.org/docs/stable/elastic/run.html). 23 | 24 | 25 | ## Jax Examples 26 | 27 | Please use `mpirun` to launch the Jax examples. Take `simple_function.py` for example: 28 | 29 | ```shell 30 | # for single-node environment (2 GPUs for example) 31 | mpirun -np 2 python ./examples/jax/simple_function.py 32 | ``` 33 | 34 | For multi-node environments, you may need to read the docs on [Multi Process in Jax](https://jax.readthedocs.io/en/latest/multi_process.html). Also, the function `easydist_setup_jax` in `easydist/jax/__init__.py` may need to be modified to launch processes in a clustered environment such as SLURM.
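
All of the PyTorch example scripts below share the same distributed setup boilerplate before calling an `easydist_compile`-wrapped step. The condensed sketch below is illustrative only: it mirrors the pattern used in the scripts in this directory (some of them additionally build a `DeviceMesh` via `set_device_mesh` or adjust `mdconfig`), and the `step` function here is a placeholder workload rather than one of the shipped examples.

```python
import os

import torch

from easydist import easydist_setup
from easydist.torch.api import easydist_compile

# set up easydist first, then torch.distributed from the env vars provided by torchrun
easydist_setup(backend="torch", device="cuda")
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


@easydist_compile()
def step(x, y):
    # placeholder workload; the real examples run a model forward/backward and an optimizer step
    return torch.mm(x, y)


if __name__ == "__main__":
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    print(step(x, y).sum())
```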
-------------------------------------------------------------------------------- /examples/jax/simple_function.py: -------------------------------------------------------------------------------- 1 | # mpirun -np 2 python ./examples/jax/simple_function.py 2 | 3 | import logging 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from easydist import easydist_setup, mdconfig 9 | from easydist.jax.api import easydist_compile 10 | 11 | 12 | @easydist_compile() 13 | def foo_func(x, y): 14 | tanh = jnp.tanh(x) 15 | return jnp.exp(tanh) @ y + tanh 16 | 17 | 18 | def main(): 19 | mdconfig.log_level = logging.INFO 20 | easydist_setup(backend="jax", device="cuda") 21 | 22 | key = jax.random.PRNGKey(0) 23 | key, subkey = jax.random.split(key) 24 | randn_x = jax.random.normal(key, (10, 10)) 25 | randn_y = jax.random.normal(subkey, (10, 10)) 26 | 27 | jax_out = foo_func.original_func(randn_x, randn_y) 28 | md_out = foo_func(randn_x, randn_y) 29 | 30 | if not jax.numpy.allclose(jax_out, md_out): 31 | raise RuntimeError("simlpe function test failed!!") 32 | 33 | print("simlpe function example pass.") 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /examples/jax/simple_model.py: -------------------------------------------------------------------------------- 1 | # mpirun -np 2 python ./examples/jax/simple_model.py --mode inference 2 | 3 | import argparse 4 | import logging 5 | 6 | import jax 7 | import optax 8 | from flax import linen as nn 9 | from flax.training import train_state 10 | from jax import random 11 | 12 | from easydist import easydist_setup, mdconfig 13 | from easydist.jax.api import easydist_compile 14 | 15 | 16 | class Foo(nn.Module): 17 | """A simple CNN model.""" 18 | 19 | @nn.compact 20 | def __call__(self, x): 21 | x = nn.LayerNorm()(x) 22 | x = nn.Dense(features=6)(x) 23 | x = nn.relu(x) 24 | return x 25 | 26 | 27 | def inference_example(module, params, input): 28 | 29 | @easydist_compile() 30 | def inference_step(params, input): 31 | out = module.apply({'params': params}, input) 32 | return out 33 | 34 | jax_out = inference_step.original_func(params, input) 35 | md_out = inference_step(params, input) 36 | 37 | if not jax.numpy.allclose(jax_out, md_out): 38 | raise RuntimeError("simlpe model test failed!!") 39 | 40 | print("simlpe model inference example pass.") 41 | 42 | 43 | def train_example(module, params, input): 44 | 45 | tx = optax.adam(learning_rate=0.01) 46 | 47 | state = train_state.TrainState.create(apply_fn=module.apply, params=params, tx=tx) 48 | 49 | @easydist_compile() 50 | def train_step(state, batch): 51 | """Train for a single step.""" 52 | 53 | def loss_fn(params): 54 | logits = state.apply_fn({'params': params}, batch) 55 | loss = logits.mean() 56 | return loss 57 | 58 | grad_fn = jax.grad(loss_fn) 59 | grads = grad_fn(state.params) 60 | state = state.apply_gradients(grads=grads) 61 | return state 62 | 63 | jax_state = train_step.original_func(state, input) 64 | md_state = train_step(state, input) 65 | 66 | if not jax.tree_util.tree_all(jax.tree_map(jax.numpy.allclose, jax_state, md_state)): 67 | raise RuntimeError("simlpe model train example failed!!") 68 | 69 | print("simlpe model train example pass.") 70 | 71 | 72 | def main(): 73 | parser = argparse.ArgumentParser(description="Simple example of parallelize model.") 74 | 75 | parser.add_argument("--mode", 76 | type=str, 77 | default=None, 78 | choices=["train", "inference"], 79 | required=True) 80 | 81 | args = 
parser.parse_args() 82 | 83 | mdconfig.log_level = logging.INFO 84 | easydist_setup(backend="jax", device="cuda") 85 | 86 | model = Foo() 87 | 88 | root_key = jax.random.PRNGKey(seed=0) 89 | main_key, params_key = jax.random.split(key=root_key, num=2) 90 | rand_input = random.normal(main_key, (4, 6)) 91 | variables = model.init(params_key, rand_input) 92 | params = variables['params'] 93 | 94 | if args.mode == "train": 95 | train_example(model, params, rand_input) 96 | if args.mode == "inference": 97 | inference_example(model, params, rand_input) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /examples/torch/cifar10.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/pytorch/tutorials/blob/main/beginner_source/blitz/cifar10_tutorial.py 2 | 3 | import logging 4 | import os 5 | import random 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torchvision 11 | from torchvision.models import resnet18 12 | import torchvision.transforms as transforms 13 | 14 | from easydist import easydist_setup, mdconfig 15 | from easydist.torch.api import easydist_compile 16 | 17 | random.seed(42) 18 | torch.manual_seed(42) 19 | 20 | def main(): 21 | 22 | # setting up easydist and torch.distributed 23 | mdconfig.log_level = logging.INFO 24 | easydist_setup(backend="torch", device="cuda") 25 | 26 | torch.distributed.init_process_group(backend="nccl") 27 | local_rank = int(os.environ["LOCAL_RANK"]) 28 | torch.cuda.set_device(local_rank) 29 | 30 | transform = transforms.Compose( 31 | [transforms.ToTensor(), 32 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 33 | 34 | batch_size = 128 35 | 36 | trainset = torchvision.datasets.CIFAR10(root='./data', 37 | train=True, 38 | download=True, 39 | transform=transform) 40 | trainloader = torch.utils.data.DataLoader(trainset, 41 | batch_size=batch_size, 42 | shuffle=True, 43 | drop_last=True, 44 | num_workers=2) 45 | 46 | net = resnet18().cuda() 47 | 48 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 49 | 50 | @easydist_compile 51 | def train_step(net, optimizer, inputs, labels): 52 | 53 | criterion = nn.CrossEntropyLoss() 54 | 55 | outputs = net(inputs) 56 | loss = criterion(outputs, labels) 57 | loss.backward() 58 | optimizer.step() 59 | 60 | optimizer.zero_grad() 61 | 62 | return loss 63 | 64 | for epoch in range(2): # loop over the dataset multiple times 65 | 66 | running_loss = 0.0 67 | for i, data in enumerate(trainloader, 0): 68 | # get the inputs; data is a list of [inputs, labels] 69 | inputs, labels = data[0].cuda(), data[1].cuda() 70 | 71 | loss = train_step(net, optimizer, inputs, labels) 72 | 73 | # print statistics 74 | running_loss += loss.item() 75 | if i % 100 == 99: # print every 100 mini-batches 76 | print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}') 77 | running_loss = 0.0 78 | 79 | print('Finished Training') 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /examples/torch/gnn/data.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/Diego999/pyGAT 2 | 3 | import dgl 4 | 5 | 6 | def load_data(dataset="cora"): 7 | 8 | if dataset == "cora": 9 | dataset = dgl.data.CoraGraphDataset() 10 | adj = dataset[0].adj().to_dense() 11 | 12 | features 
= dataset[0].ndata["feat"] 13 | labels = dataset[0].ndata["label"] 14 | 15 | train_mask = dataset[0].ndata["train_mask"] 16 | val_mask = dataset[0].ndata["val_mask"] 17 | test_mask = dataset[0].ndata["test_mask"] 18 | 19 | idx_train = (train_mask == True).nonzero().flatten() 20 | idx_val = (val_mask == True).nonzero().flatten() 21 | idx_test = (test_mask == True).nonzero().flatten() 22 | 23 | return adj, features, labels, idx_train, idx_val, idx_test, train_mask 24 | 25 | elif dataset == "wiki-cs": 26 | dataset = dgl.data.WikiCSDataset() 27 | adj = dataset[0].adj().to_dense() 28 | 29 | features = dataset[0].ndata["feat"] 30 | labels = dataset[0].ndata["label"] 31 | 32 | train_mask = dataset[0].ndata["train_mask"][:, 0] 33 | val_mask = dataset[0].ndata["val_mask"][:, 0] 34 | test_mask = dataset[0].ndata["test_mask"] 35 | 36 | idx_train = (train_mask == True).nonzero().flatten() 37 | idx_val = (val_mask == True).nonzero().flatten() 38 | idx_test = (test_mask == True).nonzero().flatten() 39 | 40 | return adj, features, labels, idx_train, idx_val, idx_test, train_mask 41 | -------------------------------------------------------------------------------- /examples/torch/gnn/gat.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/Diego999/pyGAT 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GATLayer(nn.Module): 9 | 10 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 11 | super(GATLayer, self).__init__() 12 | self.dropout = dropout 13 | self.in_features = in_features 14 | self.out_features = out_features 15 | self.alpha = alpha 16 | self.concat = concat 17 | 18 | self.linear_w = nn.Linear(in_features=in_features, out_features=out_features, bias=None) 19 | nn.init.xavier_uniform_(self.linear_w.weight.data, gain=1.414) 20 | 21 | self.linear_a_1 = nn.Linear(in_features=out_features, out_features=1, bias=None) 22 | self.linear_a_2 = nn.Linear(in_features=out_features, out_features=1, bias=None) 23 | nn.init.xavier_uniform_(self.linear_a_1.weight.data, gain=1.414) 24 | nn.init.xavier_uniform_(self.linear_a_2.weight.data, gain=1.414) 25 | 26 | self.leakyrelu = nn.LeakyReLU(self.alpha) 27 | 28 | self.softmax = nn.Softmax(dim=-1) 29 | self.dropout = nn.Dropout(self.dropout) 30 | 31 | self.register_buffer("zero_vec", None) 32 | 33 | def forward(self, h, adj): 34 | wh = self.linear_w(h) 35 | wh1 = self.linear_a_1(wh) 36 | wh2 = self.linear_a_2(wh) 37 | 38 | e = self.leakyrelu(wh1 + wh2.T) 39 | 40 | if self.zero_vec is None: 41 | self.zero_vec = -10e10 * torch.ones_like(e) 42 | attention = torch.where(adj > 0, e, self.zero_vec) 43 | attention = self.dropout(self.softmax(attention)) 44 | 45 | h_new = torch.matmul(attention, wh) 46 | 47 | if self.concat: 48 | return F.elu(h_new) 49 | else: 50 | return h_new 51 | 52 | 53 | class GAT(nn.Module): 54 | 55 | def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads): 56 | """Dense version of GAT.""" 57 | super(GAT, self).__init__() 58 | self.dropout = dropout 59 | 60 | self.attentions = nn.ModuleList([ 61 | GATLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads) 62 | ]) 63 | 64 | self.last_layer = GATLayer(nhid * nheads, 65 | nclass, 66 | dropout=dropout, 67 | alpha=alpha, 68 | concat=False) 69 | 70 | self.dropout_1 = nn.Dropout(self.dropout) 71 | self.dropout_2 = nn.Dropout(self.dropout) 72 | 73 | def forward(self, x, adj): 74 | x = self.dropout_1(x) 75 | 76 | 
multi_head_att = [] 77 | for att in self.attentions: 78 | multi_head_att.append(att(x, adj)) 79 | 80 | x = torch.cat(multi_head_att, dim=1) 81 | x = self.dropout_2(x) 82 | x = F.elu(self.last_layer(x, adj)) 83 | 84 | return F.log_softmax(x, dim=1) 85 | -------------------------------------------------------------------------------- /examples/torch/gnn/train.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/Diego999/pyGAT 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import sys 7 | import random 8 | import time 9 | 10 | import rich 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | from easydist import easydist_setup, mdconfig 15 | from easydist.torch.api import easydist_compile 16 | 17 | sys.path.append(os.path.abspath(__file__)) 18 | from gat import GAT 19 | from data import load_data 20 | 21 | random.seed(42) 22 | torch.manual_seed(42) 23 | 24 | 25 | def accuracy(output, labels): 26 | preds = output.max(1)[1].type_as(labels) 27 | correct = preds.eq(labels).double() 28 | correct = correct.sum() 29 | return correct / len(labels) 30 | 31 | 32 | def get_args(): 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument('--hidden', type=int, default=4096) 36 | parser.add_argument('--nb_heads', type=int, default=4) 37 | parser.add_argument('--dropout', type=float, default=0.5) 38 | parser.add_argument('--alpha', type=float, default=0.2) 39 | 40 | parser.add_argument('--lr', type=float, default=0.0005) 41 | parser.add_argument('--weight_decay', type=float, default=5e-4) 42 | 43 | parser.add_argument('--dataset', type=str, choices=["wiki-cs", "cora"], default="wiki-cs") 44 | parser.add_argument('--epochs', type=int, default=800) 45 | 46 | args = parser.parse_args() 47 | rich.print(f"Training config: {args}") 48 | 49 | return args 50 | 51 | 52 | def main(): 53 | 54 | # setting up easydist and torch.distributed 55 | mdconfig.log_level = logging.INFO 56 | easydist_setup(backend="torch", device="cuda") 57 | 58 | torch.distributed.init_process_group(backend="nccl") 59 | local_rank = int(os.environ["LOCAL_RANK"]) 60 | torch.cuda.set_device(local_rank) 61 | 62 | args = get_args() 63 | 64 | adj, features, labels, idx_train, idx_val, _, train_mask = load_data(args.dataset) 65 | 66 | model = GAT(nfeat=features.shape[1], 67 | nhid=args.hidden, 68 | nclass=int(labels.max()) + 1, 69 | dropout=args.dropout, 70 | nheads=args.nb_heads, 71 | alpha=args.alpha).cuda() 72 | 73 | adj, features, labels, idx_train = adj.cuda(), features.cuda(), labels.cuda(), idx_train.cuda() 74 | train_mask = train_mask.cuda() 75 | 76 | optimizer = torch.optim.Adam(model.parameters(), 77 | lr=args.lr, 78 | weight_decay=args.weight_decay, 79 | foreach=True, 80 | capturable=True) 81 | 82 | loss_scale = len(idx_train) / len(labels) 83 | 84 | @easydist_compile() 85 | def train_step(model, optimizer, adj, features, labels, mask): 86 | output = model(features, adj) 87 | loss_train = F.nll_loss(output * mask[:, None], labels * mask) / loss_scale 88 | loss_train.backward() 89 | optimizer.step() 90 | optimizer.zero_grad() 91 | 92 | return output, loss_train 93 | 94 | for epoch in range(args.epochs): 95 | 96 | start_t = time.perf_counter() 97 | 98 | output, loss_train = train_step(model, optimizer, adj, features, labels, train_mask) 99 | 100 | acc_train = accuracy(output[idx_train], labels[idx_train]) 101 | 102 | loss_val = F.nll_loss(output[idx_val], labels[idx_val]) 103 | acc_val = accuracy(output[idx_val], labels[idx_val]) 
104 | 105 | rich.print('epoch={:04d} | '.format(epoch + 1), 106 | 'loss_train={:.6f}'.format(loss_train.data.item()), 107 | 'acc_train={:.6f}'.format(acc_train.data.item()), 108 | 'loss_val={:.6f}'.format(loss_val.data.item()), 109 | 'acc_val={:.6f}'.format(acc_val.data.item()), 110 | 'time={:.3f}'.format(time.perf_counter() - start_t)) 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /examples/torch/gpt_train.py: -------------------------------------------------------------------------------- 1 | # EASYDIST_LOGLEVEL=INFO torchrun --nproc_per_node 8 examples/torch/gpt_train.py 2 | import copy 3 | import os 4 | 5 | import torch 6 | from torch.distributed._tensor import DeviceMesh 7 | from torch.distributed.distributed_c10d import _get_default_group 8 | from torch.distributed.utils import _sync_module_states 9 | 10 | from benchmark.bench_case import GPTCase 11 | from benchmark.torch.model.gpt import GPT 12 | from easydist import easydist_setup 13 | from easydist.torch.api import easydist_compile 14 | from easydist.torch.device_mesh import set_device_mesh 15 | 16 | 17 | def broadcast_module(model): 18 | _sync_module_states(model, 19 | _get_default_group(), 20 | broadcast_bucket_size=int(250 * 1024 * 1024), 21 | src=0, 22 | params_and_buffers_to_ignore=set()) 23 | 24 | return model 25 | 26 | GPT_CASE = GPTCase( 27 | num_layers=4, 28 | hidden_dim=1024, 29 | num_heads=32, 30 | seq_size=128 31 | ) 32 | 33 | def train_example(): 34 | 35 | # when using cuda_graph, because of the warm-up and cuda graph capture, 36 | # the result of the first step is equivalent to the original result of the third step 37 | @easydist_compile(tracing_mode="fake", cuda_graph=False) 38 | def train_step(input, model, opt): 39 | out = model(input).mean() 40 | out.backward() 41 | opt.step() 42 | opt.zero_grad(True) 43 | return out 44 | 45 | # (NOTE) initialize cuda context first see https://github.com/pytorch/pytorch/issues/92627 46 | torch.ones(1).cuda() 47 | with torch.device('cuda'): 48 | model = GPT( 49 | depth=GPT_CASE.num_layers, 50 | dim=GPT_CASE.hidden_dim, 51 | num_heads=GPT_CASE.num_heads, 52 | ) 53 | 54 | randn_input = torch.randn(GPT_CASE.batch_size, GPT_CASE.seq_size, GPT_CASE.hidden_dim) 55 | 56 | # broadcast the parameter and input 57 | model = broadcast_module(model) 58 | torch.distributed.broadcast(randn_input, src=0) 59 | 60 | opt = torch.optim.SGD(model.parameters(), lr=0.001, foreach=True) 61 | 62 | model_2 = copy.deepcopy(model) 63 | opt_2 = torch.optim.SGD(model_2.parameters(), lr=0.001, foreach=True) 64 | 65 | torch_step_1_result = train_step.original_func(randn_input, model, opt) 66 | torch_step_2_result = train_step.original_func(randn_input, model, opt) 67 | 68 | md_step_1_result = train_step(randn_input, model_2, opt_2) 69 | md_step_2_result = train_step(randn_input, model_2, opt_2) 70 | 71 | assert torch.allclose(torch_step_1_result, 72 | md_step_1_result), f"GPT model training test failed. {torch_step_1_result} {md_step_1_result}" 73 | assert torch.allclose(torch_step_2_result, 74 | md_step_2_result), f"GPT model training test failed. 
{torch_step_1_result} {md_step_1_result}" 75 | 76 | print("GPT model training example pass.") 77 | 78 | 79 | def main(): 80 | # setting up easydist and torch.distributed 81 | easydist_setup(backend="torch", device="cuda", allow_tf32=False) 82 | 83 | torch.distributed.init_process_group(backend="nccl") 84 | local_rank = int(os.environ["LOCAL_RANK"]) 85 | world_size = int(os.environ["WORLD_SIZE"]) 86 | torch.cuda.set_device(local_rank) 87 | 88 | mesh = torch.arange(world_size).reshape(2, 2, 2) 89 | set_device_mesh(DeviceMesh("cuda", mesh, mesh_dim_names=["spmd0", "spmd1", "spmd2"])) 90 | 91 | train_example() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /examples/torch/resnet18.py: -------------------------------------------------------------------------------- 1 | # EASYDIST_LOGLEVEL=DEBUG torchrun --nproc_per_node 8 examples/torch/resnet18.py 2 | import argparse 3 | import copy 4 | import os 5 | from contextlib import nullcontext 6 | 7 | import torch 8 | from torch._subclasses.fake_tensor import FakeTensorMode 9 | from torch.distributed.distributed_c10d import _get_default_group 10 | from torch.distributed.utils import _sync_module_states 11 | from torch.utils.checkpoint import checkpoint 12 | from torch.distributed._tensor import DeviceMesh 13 | 14 | from torchvision.models import resnet18 15 | 16 | from easydist import easydist_setup, mdconfig 17 | from easydist.torch.api import easydist_compile 18 | from easydist.torch.device_mesh import set_device_mesh 19 | from easydist.torch.experimental.pp.compile_pipeline import annotate_split_points 20 | 21 | 22 | def broadcast_module(model): 23 | _sync_module_states(model, 24 | _get_default_group(), 25 | broadcast_bucket_size=int(250 * 1024 * 1024), 26 | src=0, 27 | params_and_buffers_to_ignore=set()) 28 | 29 | return model 30 | 31 | 32 | def train_example(): 33 | 34 | # when using cuda_graph, because of the warm-up and cuda graph capture, 35 | # the result of the first step is equivalent to the original result of the third step 36 | @easydist_compile(tracing_mode="fake", cuda_graph=False) 37 | def train_step(input, model, opt): 38 | out = model(input).mean() 39 | out.backward() 40 | opt.step() 41 | opt.zero_grad(True) 42 | return out 43 | 44 | fake_mode = FakeTensorMode() 45 | 46 | # (NOTE) initialize cuda context first see https://github.com/pytorch/pytorch/issues/92627 47 | torch.ones(1).cuda() 48 | with torch.device('cuda'): 49 | model = resnet18() 50 | 51 | randn_input = torch.randn(16, 3, 224, 224) 52 | 53 | # broadcast the parameter and input 54 | model = broadcast_module(model) 55 | torch.distributed.broadcast(randn_input, src=0) 56 | 57 | opt = torch.optim.SGD(model.parameters(), lr=0.001, foreach=True) 58 | 59 | model_2 = copy.deepcopy(model) 60 | opt_2 = torch.optim.SGD(model_2.parameters(), lr=0.001, foreach=True) 61 | 62 | torch_step_1_result = train_step.original_func(randn_input, model, opt) 63 | torch_step_2_result = train_step.original_func(randn_input, model, opt) 64 | 65 | md_step_1_result = train_step(randn_input, model_2, opt_2) 66 | md_step_2_result = train_step(randn_input, model_2, opt_2) 67 | 68 | assert torch.allclose(torch_step_1_result, 69 | md_step_1_result), "resnet model training test failed." 70 | assert torch.allclose(torch_step_2_result, 71 | md_step_2_result), "resnet model training test failed." 
72 | 73 | print("resnet model training example pass.") 74 | 75 | 76 | def main(): 77 | # setting up easydist and torch.distributed 78 | easydist_setup(backend="torch", device="cuda", allow_tf32=False) 79 | 80 | torch.distributed.init_process_group(backend="nccl") 81 | local_rank = int(os.environ["LOCAL_RANK"]) 82 | world_size = int(os.environ["WORLD_SIZE"]) 83 | torch.cuda.set_device(local_rank) 84 | 85 | mesh = torch.arange(world_size).reshape(2, 2, 2) 86 | set_device_mesh(DeviceMesh("cuda", mesh, mesh_dim_names=["spmd0", "spmd1", "spmd2"])) 87 | 88 | train_example() 89 | 90 | if __name__ == "__main__": 91 | main() -------------------------------------------------------------------------------- /examples/torch/simple_ddp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | 5 | import numpy 6 | import torch 7 | from torch.distributed.distributed_c10d import _get_default_group 8 | from torch.distributed.utils import _sync_module_states 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed._tensor import DeviceMesh 11 | 12 | from easydist import easydist_setup, mdconfig 13 | from easydist.torch.api import easydist_compile 14 | from easydist.torch.device_mesh import set_device_mesh 15 | 16 | 17 | def broadcast_module(model): 18 | _sync_module_states(model, 19 | _get_default_group(), 20 | broadcast_bucket_size=int(250 * 1024 * 1024), 21 | src=0, 22 | params_and_buffers_to_ignore=set()) 23 | 24 | return model 25 | 26 | 27 | class Foo(torch.nn.Module): 28 | 29 | def __init__(self): 30 | super().__init__() 31 | self.norm = torch.nn.LayerNorm(1024) 32 | self.linear = torch.nn.Linear(1024, 1024) 33 | 34 | def forward(self, x): 35 | x = self.norm(x) 36 | x = self.linear(x) 37 | return x.relu() 38 | 39 | 40 | def main(): 41 | 42 | # setting up easydist and torch.distributed 43 | mdconfig.log_level = logging.INFO 44 | easydist_setup(backend="torch", device="cuda", allow_tf32=False) 45 | 46 | torch.distributed.init_process_group(backend="nccl") 47 | local_rank = int(os.environ["LOCAL_RANK"]) 48 | torch.cuda.set_device(local_rank) 49 | 50 | world_size = torch.distributed.get_world_size() 51 | mesh_shape = numpy.array(range(world_size)).reshape(-1, 1).tolist() 52 | mesh = DeviceMesh("cuda", mesh_shape, mesh_dim_names=["dp", "placeholder"]) 53 | set_device_mesh(mesh) 54 | 55 | # when using cuda_graph, because of the warm-up and cuda graph capture, 56 | # the result of the first step is equivalent to the original result of the third step 57 | @easydist_compile(tracing_mode="fake", cuda_graph=False, parallel_mode="ddp") 58 | def train_step(input, model, opt): 59 | out = model(input).mean() 60 | out.backward() 61 | opt.step() 62 | opt.zero_grad(True) 63 | 64 | return out 65 | 66 | with torch.device('cuda'): 67 | model = Foo() 68 | randn_input = torch.randn(1024, 1024) 69 | 70 | model = broadcast_module(model) 71 | model_2 = copy.deepcopy(model) 72 | 73 | ddp_model = DDP(model) 74 | 75 | opt = torch.optim.Adam(ddp_model.parameters(), lr=0.001, fused=True, capturable=True) 76 | opt_2 = torch.optim.Adam(model_2.parameters(), lr=0.001, fused=True, capturable=True) 77 | 78 | torch_step_1_result = train_step.original_func(randn_input, ddp_model, opt) 79 | torch_step_2_result = train_step.original_func(randn_input, ddp_model, opt) 80 | 81 | md_step_1_result = train_step(randn_input, model_2, opt_2) 82 | md_step_2_result = train_step(randn_input, model_2, opt_2) 83 | 84 | assert 
torch.allclose(torch_step_1_result, 85 | md_step_1_result), "simple model training test failed." 86 | assert torch.allclose(torch_step_2_result, 87 | md_step_2_result), "simple model training test failed." 88 | 89 | print("simple ddp training example pass.") 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /examples/torch/simple_function.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | import torch.distributed 6 | from torch.distributed._tensor import DeviceMesh 7 | 8 | from easydist import easydist_setup, mdconfig 9 | from easydist.torch.api import easydist_compile, set_device_mesh 10 | 11 | 12 | def main(): 13 | mdconfig.log_level = logging.INFO 14 | easydist_setup(backend="torch", device="cuda", allow_tf32=False) 15 | 16 | torch.distributed.init_process_group(backend="nccl") 17 | local_rank = int(os.environ["LOCAL_RANK"]) 18 | torch.cuda.set_device(local_rank) 19 | 20 | world_size = torch.distributed.get_world_size() 21 | device_mesh = DeviceMesh('cuda', torch.arange(world_size).reshape(-1, 1), mesh_dim_names=['spmd0', 'spmd1']) 22 | set_device_mesh(device_mesh) 23 | 24 | randn_x = torch.randn(10, 10, requires_grad=True).cuda() 25 | randn_y = torch.randn(10, 10, requires_grad=True).cuda() 26 | torch.distributed.broadcast(randn_x, src=0) 27 | torch.distributed.broadcast(randn_y, src=0) 28 | 29 | @easydist_compile(cuda_graph=False) 30 | def foo_func(x, y): 31 | tanh = torch.tanh(x) 32 | return torch.mm(torch.exp(tanh), y) + tanh 33 | 34 | torch_out = foo_func.original_func(randn_x, randn_y) 35 | md_out = foo_func(randn_x, randn_y) 36 | 37 | if not torch.allclose(torch_out, md_out): 38 | raise RuntimeError("simlpe function test failed!!") 39 | 40 | print("simlpe function example pass.") 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /examples/torch/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | # ENABLE_COMPILE_CACHE=1 torchrun --nproc_per_node 4 examples/torch/stable_diffusion.py 2 | import copy 3 | import functools 4 | import logging 5 | import os 6 | 7 | import diffusers 8 | import torch 9 | import torch.utils._pytree as pytree 10 | from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline 11 | 12 | from easydist import easydist_setup, mdconfig 13 | from easydist.torch.api import easydist_compile 14 | from easydist.torch.device_mesh import set_device_mesh 15 | from torch.distributed._tensor import DeviceMesh 16 | 17 | pytree._register_pytree_node( 18 | diffusers.models.unet_2d_condition.UNet2DConditionOutput, lambda x: ([x.sample], None), 19 | lambda values, _: diffusers.models.unet_2d_condition.UNet2DConditionOutput(values[0])) 20 | 21 | def main(): 22 | # setting up easydist and torch.distributed 23 | mdconfig.log_level = logging.INFO 24 | easydist_setup(backend="torch", device="cuda") 25 | 26 | torch.distributed.init_process_group(backend="nccl") 27 | local_rank = int(os.environ["LOCAL_RANK"]) 28 | world_size = int(os.environ["WORLD_SIZE"]) 29 | torch.cuda.set_device(local_rank) 30 | torch.manual_seed(42) 31 | 32 | mesh = torch.arange(world_size).reshape(1, -1) 33 | set_device_mesh(DeviceMesh("cuda", mesh, mesh_dim_names=["spmd0", "spmd1"])) 34 | 35 | model_id = "stabilityai/stable-diffusion-2" 36 | 37 | # Use the Euler scheduler here 
instead 38 | scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") 39 | pipe = StableDiffusionPipeline.from_pretrained(model_id, 40 | scheduler=scheduler, 41 | torch_dtype=torch.float16) 42 | pipe = pipe.to("cuda") 43 | 44 | @easydist_compile(use_hint=True) 45 | @torch.inference_mode() 46 | def sharded_unet(model, *args, **kwargs): 47 | return model(*args, **kwargs) 48 | 49 | pipe.unet.forward = functools.partial(sharded_unet, copy.copy(pipe.unet)) 50 | 51 | prompt = "a photo of Pride and Prejudice" 52 | image = pipe(prompt, width=1024, height=1024).images[0] 53 | image.save("pride_and_prejudice.png") 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /examples/torch/tensorfeild/cifar10.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/pytorch/tutorials/blob/main/beginner_source/blitz/cifar10_tutorial.py 2 | 3 | import random 4 | import time 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torchvision 11 | from torchvision.models import resnet18 12 | import torchvision.transforms as transforms 13 | 14 | from easydist.torch import tensorfield 15 | 16 | random.seed(42) 17 | torch.manual_seed(42) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--tensorfield', action='store_true') 23 | parser.add_argument('--bs', default=512, type=int, help='Batch Size') 24 | parser.add_argument('--epochs', default=2, type=int, help='Number of Epochs') 25 | args = parser.parse_args() 26 | 27 | if args.tensorfield: 28 | tensorfield.init_tensorfield_allocator() 29 | 30 | transform = transforms.Compose( 31 | [transforms.ToTensor(), 32 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 33 | 34 | batch_size = args.bs 35 | epochs = args.epochs 36 | 37 | trainset = torchvision.datasets.CIFAR10(root='./data', 38 | train=True, 39 | download=True, 40 | transform=transform) 41 | trainloader = torch.utils.data.DataLoader(trainset, 42 | batch_size=batch_size, 43 | shuffle=True, 44 | drop_last=True, 45 | num_workers=2) 46 | 47 | net = resnet18().cuda() 48 | 49 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 50 | 51 | def train_step(net, optimizer, inputs, labels): 52 | 53 | criterion = nn.CrossEntropyLoss() 54 | 55 | outputs = net(inputs) 56 | loss = criterion(outputs, labels) 57 | loss.backward() 58 | optimizer.step() 59 | 60 | optimizer.zero_grad() 61 | 62 | return loss 63 | 64 | for epoch in range(epochs): # loop over the dataset multiple times 65 | 66 | torch.cuda.synchronize() 67 | start_t = time.perf_counter() 68 | 69 | running_loss = 0.0 70 | for i, data in enumerate(trainloader, 0): 71 | # get the inputs; data is a list of [inputs, labels] 72 | inputs, labels = data[0].cuda(), data[1].cuda() 73 | 74 | loss = train_step(net, optimizer, inputs, labels) 75 | # print statistics 76 | running_loss += loss.item() 77 | if i % 10 == 9: # print every 10 mini-batches 78 | print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 10:.3f}') 79 | running_loss = 0.0 80 | 81 | torch.cuda.synchronize() 82 | print("epoch time elapsed: ", time.perf_counter() - start_t) 83 | 84 | print('Finished Training') 85 | 86 | if args.tensorfield: 87 | tensorfield.finalize_tensorfield_allocator() 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 |
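# Usage sketch (editorial note, not part of the original script): the flags below are the ones
# registered by this file's argparse setup; the single-GPU launch command itself is an assumption
# about a typical run, with or without the tensorfield allocator enabled.
#
#   python examples/torch/tensorfeild/cifar10.py --tensorfield --bs 512 --epochs 2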
-------------------------------------------------------------------------------- /examples/torch/tensorfeild/matmul.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from easydist.torch import tensorfield 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--mnk', default=4096, type=int, help='Matrix Size (N x N)') 11 | parser.add_argument('--precision', 12 | default='fp16', 13 | type=str, 14 | help='Precision (fp16, fp32)', 15 | choices=['fp16', 'fp32']) 16 | parser.add_argument('--trials', default=10, type=int, help='Number of Trials to Execute') 17 | parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard') 18 | parser.add_argument('--tensorfield', action='store_true') 19 | parser.add_argument('--profile', action='store_true') 20 | args = parser.parse_args() 21 | 22 | if args.tensorfield: 23 | tensorfield.init_tensorfield_allocator() 24 | print("Tensorfield allocator initialized.") 25 | 26 | if args.profile: 27 | log_dir = f'./log/matmul_tfield_{args.tensorfield}_profile' 28 | prof = torch.profiler.profile(schedule=torch.profiler.schedule(wait=1, 29 | warmup=args.warmup_trials, 30 | active=args.trials, 31 | repeat=1), 32 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 33 | log_dir, use_gzip=True), 34 | profile_memory=False, 35 | record_shapes=False, 36 | with_stack=False) 37 | prof.start() 38 | 39 | start_evt, end_evt = [], [] 40 | for _ in range(0, args.trials): 41 | start_evt.append(torch.cuda.Event(enable_timing=True)) 42 | end_evt.append(torch.cuda.Event(enable_timing=True)) 43 | 44 | for trial in range(0, args.trials + args.warmup_trials): 45 | evt_idx = trial - args.warmup_trials 46 | 47 | if evt_idx >= 0: 48 | start_evt[evt_idx].record() 49 | 50 | precision = torch.float32 if args.precision == 'fp32' else torch.float16 51 | 52 | tensor1 = torch.rand(args.mnk, args.mnk, device='cuda', dtype=precision) 53 | tensor2 = torch.rand(args.mnk, args.mnk, device='cuda', dtype=precision) 54 | _ = torch.mm(tensor1, tensor2) 55 | 56 | if evt_idx >= 0: 57 | end_evt[evt_idx].record() 58 | 59 | if args.profile: 60 | prof.step() 61 | 62 | torch.cuda.synchronize() 63 | 64 | if args.profile: 65 | prof.stop() 66 | 67 | elapsed_time_ms = np.zeros(args.trials) 68 | for trial in range(0, args.trials): 69 | elapsed_time_ms[trial] = start_evt[trial].elapsed_time(end_evt[trial]) 70 | 71 | print( 72 | f"Average time elapsed: {np.mean(elapsed_time_ms)} ms (variance: {np.var(elapsed_time_ms)}" 73 | ) 74 | 75 | if args.tensorfield: 76 | tensorfield.finalize_tensorfield_allocator() 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /examples/torch/tensorfeild/param_group_ray.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import ray 4 | import torch 5 | 6 | from easydist.torch.tensorfield import TFieldClient, init_on_tfeild, load_from_tfeild 7 | from easydist.torch.tensorfield.server import TFeildActor 8 | 9 | 10 | class SimpleNet(torch.nn.Module): 11 | 12 | def __init__(self): 13 | super(SimpleNet, self).__init__() 14 | self.fc1 = torch.nn.Linear(1024, 4096) 15 | self.fc2 = torch.nn.Linear(4096, 1024) 16 | 17 | def forward(self, x): 18 | x = self.fc1(x) 19 | x = self.fc2(x) 20 | return x 21 | 22 | 23 | @ray.remote 24 | class SimpleNetCreateActor: 25 | 26 | def __init__(self): 27 | 
self.device_id = int(os.environ.get('CUDA_VISIBLE_DEVICES')) 28 | self.socket_file = f"/tmp/tensorfield.{self.device_id}.sock" 29 | 30 | def entrypoint(self): 31 | client = TFieldClient(self.socket_file) 32 | model = SimpleNet() 33 | model = init_on_tfeild(client, model, "simple_net") 34 | client.close() 35 | return True 36 | 37 | 38 | @ray.remote 39 | class SimpleNetRunActor: 40 | 41 | def __init__(self): 42 | self.device_id = int(os.environ.get('CUDA_VISIBLE_DEVICES')) 43 | self.socket_file = f"/tmp/tensorfield.{self.device_id}.sock" 44 | 45 | def entrypoint(self): 46 | client = TFieldClient(self.socket_file) 47 | model = SimpleNet() 48 | model = load_from_tfeild(client, model, "simple_net", copy_weight=False) 49 | input_tensor = torch.rand(1024, 1024).cuda() 50 | output_tensor = model(input_tensor) 51 | print(output_tensor) 52 | client.close() 53 | return True 54 | 55 | 56 | @ray.remote 57 | class SimpleNetDestroyActor: 58 | 59 | def __init__(self): 60 | self.device_id = int(os.environ.get('CUDA_VISIBLE_DEVICES')) 61 | self.socket_file = f"/tmp/tensorfield.{self.device_id}.sock" 62 | 63 | def entrypoint(self): 64 | client = TFieldClient(self.socket_file) 65 | client.free_param_group("simple_net") 66 | client.close() 67 | return True 68 | 69 | 70 | def main(): 71 | ray.init() 72 | 73 | pg = ray.util.placement_group([{"CPU": 8, "GPU": 1}], strategy="STRICT_PACK") 74 | 75 | tfeild_actor = TFeildActor.options( 76 | num_cpus=4, 77 | num_gpus=0.75, 78 | scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( 79 | placement_group=pg, placement_group_bundle_index=0), 80 | ).remote() 81 | 82 | tfeild_actor.start.remote() 83 | 84 | import time 85 | time.sleep(5) 86 | 87 | create_actor = SimpleNetCreateActor.options( 88 | num_cpus=4, 89 | num_gpus=0.25, 90 | scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( 91 | placement_group=pg, placement_group_bundle_index=0), 92 | ).remote() 93 | 94 | run_actor = SimpleNetRunActor.options( 95 | num_cpus=4, 96 | num_gpus=0.25, 97 | scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( 98 | placement_group=pg, placement_group_bundle_index=0), 99 | ).remote() 100 | 101 | destroy_actor = SimpleNetDestroyActor.options( 102 | num_cpus=4, 103 | num_gpus=0.25, 104 | scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( 105 | placement_group=pg, placement_group_bundle_index=0), 106 | ).remote() 107 | 108 | return_code = create_actor.entrypoint.remote() 109 | print(ray.get(return_code)) 110 | 111 | ray.kill(create_actor) 112 | 113 | return_code = run_actor.entrypoint.remote() 114 | print(ray.get(return_code)) 115 | 116 | ray.kill(run_actor) 117 | 118 | return_code = destroy_actor.entrypoint.remote() 119 | print(ray.get(return_code)) 120 | 121 | ray.kill(destroy_actor) 122 | 123 | ray.kill(tfeild_actor) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /examples/torch/tensorfeild/param_group_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import copy 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from easydist.torch.tensorfield import TFieldClient, init_on_tfeild 9 | import easydist.config as mdconfig 10 | 11 | 12 | class SimpleNet(nn.Module): 13 | 14 | def __init__(self): 15 | super(SimpleNet, self).__init__() 16 | self.fc1 = nn.Linear(1024, 4096) 17 | self.fc2 
= nn.Linear(4096, 1024) 18 | 19 | def forward(self, x): 20 | x = self.fc1(x) 21 | x = self.fc2(x) 22 | return x 23 | 24 | 25 | def main(): 26 | mdconfig.log_level = logging.DEBUG 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--socket-file', type=str, default='/tmp/tensorfield.sock') 30 | args = parser.parse_args() 31 | 32 | client = TFieldClient(args.socket_file) 33 | 34 | model = SimpleNet() 35 | model_reference = copy.deepcopy(model).cuda() 36 | 37 | model_tfeild = init_on_tfeild(client=client, model=model, param_group_name="param_group_test") 38 | 39 | reference_param = {name: param for name, param in model_reference.named_parameters()} 40 | tfeild_param = {name: param for name, param in model_tfeild.named_parameters()} 41 | 42 | for name, param in reference_param.items(): 43 | assert torch.allclose(param, tfeild_param[name]) 44 | 45 | input_tensor = torch.rand(1024, 1024).cuda() 46 | 47 | output_tfeild = model_reference(input_tensor) 48 | 49 | import time 50 | torch.cuda.synchronize() 51 | start = time.perf_counter() 52 | for i in range(10): 53 | output_reference = model_reference(input_tensor) 54 | torch.cuda.synchronize() 55 | print("reference time: ", time.perf_counter() - start) 56 | 57 | output_reference = model_tfeild(input_tensor) 58 | 59 | torch.cuda.synchronize() 60 | start = time.perf_counter() 61 | for i in range(10): 62 | output_tfeild = model_tfeild(input_tensor) 63 | 64 | torch.cuda.synchronize() 65 | print("tfeild time: ", time.perf_counter() - start) 66 | 67 | assert torch.allclose(output_reference, output_tfeild) 68 | 69 | client.free_param_group("param_group_test") 70 | client.close() 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /examples/torch/test_dynamo_export.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch._dynamo 5 | 6 | 7 | def foo(x, y): 8 | a = torch.sin(x) 9 | b = torch.cos(x) 10 | return a + b 11 | 12 | 13 | def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): 14 | print("custom backend called with FX graph:") 15 | print(gm.graph) 16 | return gm.forward 17 | 18 | 19 | model_exp = torch._dynamo.export(foo, 20 | torch.randn(10, 10), 21 | torch.randn(10, 10), 22 | aten_graph=True, 23 | tracing_mode="fake") 24 | model_exp[0].print_readable() 25 | # opt_foo1 = torch.compile(foo, backend=custom_backend, fullgraph=True) 26 | a, b = torch.randn(10, 10), torch.randn(10, 10) 27 | print(foo(a, b)) 28 | 29 | print(model_exp[0](a, b)) 30 | -------------------------------------------------------------------------------- /examples/tvm/test_simple.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import rich 4 | import tvm 5 | import tvm.testing 6 | from tvm import te 7 | import numpy as np 8 | 9 | import easydist as md 10 | 11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 12 | datefmt='%m/%d %H:%M:%S', 13 | level=logging.DEBUG) 14 | 15 | md.platform.init_backend("tvm") 16 | 17 | tgt = tvm.target.Target(target="llvm", host="llvm") 18 | 19 | n = te.var("n") 20 | A = te.placeholder((n, ), name="A") 21 | B = te.placeholder((n, ), name="B") 22 | C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") 23 | 24 | s = te.create_schedule(C.op) 25 | 26 | fadd = tvm.build(s, [A, B, C], tgt, name="myadd") 27 | 28 | 29 | def fadd_wrapped(a, b): 30 | 
c = md.platform.zeros_like(a) 31 | assert a.shape == b.shape 32 | fadd(a, b, c) 33 | return c 34 | 35 | 36 | dev = tvm.device(tgt.kind.name, 0) 37 | 38 | n = 1024 39 | a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) 40 | b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) 41 | 42 | meta_op = md.metashard.MetaOp(fadd_wrapped, ((a, b), {})) 43 | sharding_annotion, combination_ann = meta_op.sharding_discovery() 44 | 45 | rich.print(sharding_annotion) 46 | rich.print(combination_ann) 47 | 48 | c = fadd_wrapped(a, b) 49 | print(c) 50 | print(a.numpy() + b.numpy()) 51 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | torch: Tests for PyTorch 4 | all_platform: Tests for all platforms 5 | world_2: Tests for world size == 2 6 | world_3: Tests for world size == 3 7 | world_4: Tests for world size == 4 8 | world_8: Tests for world size == 8 9 | long_duration: Tests that take long time to complete 10 | -------------------------------------------------------------------------------- /requirements/core-requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | rich 3 | mip>=1.13 4 | sortedcontainers 5 | ortools 6 | pandas # for ortools 7 | pydot==1.4.2 8 | intervaltree==3.1.0 9 | matplotlib 10 | pybind11 11 | ninja 12 | nvidia-ml-py 13 | ray[default] 14 | cupy-cuda11x 15 | tqdm 16 | transformers 17 | bitarray 18 | 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import os 16 | from importlib.machinery import SourceFileLoader 17 | 18 | import setuptools 19 | 20 | version = ( 21 | SourceFileLoader("easydist.version", os.path.join( 22 | "easydist", "version.py")).load_module().VERSION 23 | ) 24 | 25 | 26 | def is_comment_or_empty(line): 27 | stripped = line.strip() 28 | return stripped == "" or stripped.startswith("#") 29 | 30 | 31 | def remove_comments_and_empty_lines(lines): 32 | return [line for line in lines if not is_comment_or_empty(line)] 33 | 34 | 35 | def get_core_requirements(): 36 | with open(os.path.join("requirements", "core-requirements.txt")) as f: 37 | core_requirements = remove_comments_and_empty_lines( 38 | f.read().splitlines()) 39 | return core_requirements 40 | 41 | 42 | def get_long_description(): 43 | with open("README.md", "r") as fh: 44 | long_description = fh.read() 45 | return long_description 46 | 47 | setuptools.setup( 48 | name="pai-easydist", 49 | version=version, 50 | author="Shenggan Cheng", 51 | author_email="shenggan.c@u.nus.edu", 52 | description="Efficient Automatic Training System for Super-Large Models", 53 | long_description=get_long_description(), 54 | long_description_content_type="text/markdown", 55 | url="https://github.com/alibaba/easydist", 56 | packages=setuptools.find_packages(), 57 | include_package_data=True, 58 | install_requires=get_core_requirements(), 59 | extras_require={ 60 | "torch": [ 61 | "torch", 62 | "torchvision", 63 | ], 64 | "jax": [ 65 | "jax[cuda11_pip]", 66 | "flax", 67 | ], 68 | "dev": [ 69 | "pre-commit", 70 | "autoflake", 71 | "isort", 72 | ] 73 | }, 74 | entry_points = { 75 | 'console_scripts': [ 76 | 'tfield-server = easydist.torch.tensorfield.server:main' 77 | ] 78 | } 79 | ) 80 | -------------------------------------------------------------------------------- /tests/test_combination/test_gather.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import numpy 16 | import pytest 17 | 18 | from easydist import platform 19 | from easydist.metashard.combination import CombinationFunc 20 | from easydist.utils.testing import ALL_PLATFORM 21 | from easydist import easydist_setup 22 | 23 | 24 | @pytest.mark.all_platform 25 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 26 | def test_gather(backend): 27 | easydist_setup(backend) 28 | shard_tensor = [platform.from_numpy(numpy.ones((3, 4)))] * 4 29 | global_tensor_1 = platform.from_numpy(numpy.ones((12, 4))) 30 | gather_dim1 = CombinationFunc.gather(shard_tensor, dim=0) 31 | 32 | assert platform.allclose(global_tensor_1, gather_dim1) 33 | 34 | global_tensor_2 = platform.from_numpy(numpy.ones((3, 16))) 35 | gather_dim2 = CombinationFunc.gather(shard_tensor, dim=1) 36 | 37 | assert platform.allclose(global_tensor_2, gather_dim2) 38 | 39 | 40 | @pytest.mark.all_platform 41 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 42 | def test_gather_halo(backend): 43 | easydist_setup(backend) 44 | shard_tensor = [platform.from_numpy(numpy.array([1, 1, 1]))] * 3 45 | global_tensor_1 = platform.from_numpy(numpy.array([1, 1, 2, 1, 2, 1, 1])) 46 | gather_halo_1 = CombinationFunc.gather(shard_tensor, dim=0, halowidth=1) 47 | 48 | assert platform.allclose(global_tensor_1, gather_halo_1) 49 | 50 | global_tensor_2 = platform.from_numpy(numpy.array([1, 1, 1, 1, 1])) 51 | gather_halo_2 = CombinationFunc.gather(shard_tensor, dim=0, halowidth=-1) 52 | 53 | assert platform.allclose(global_tensor_2, gather_halo_2) 54 | 55 | 56 | @pytest.mark.all_platform 57 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 58 | def test_gather_chunk(backend): 59 | easydist_setup(backend) 60 | shard_tensor = [platform.from_numpy(numpy.array([1, 2, 3]))] * 3 61 | global_tensor = platform.from_numpy(numpy.array([1, 1, 1, 2, 2, 2, 3, 3, 3])) 62 | gather_chunk = CombinationFunc.gather(shard_tensor, dim=0, chunk=3) 63 | 64 | assert platform.allclose(global_tensor, gather_chunk) 65 | -------------------------------------------------------------------------------- /tests/test_combination/test_help_func.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import numpy 16 | import pytest 17 | from easydist import platform 18 | from easydist.metashard.combination import aligned_prefix, shape_aligned_otherdim 19 | 20 | 21 | @pytest.mark.all_platform 22 | def test_aligned_prefix(): 23 | t1 = platform.from_numpy(numpy.array([1, 2, 3, 4])) 24 | t2 = platform.from_numpy(numpy.array([1, 2, 3, 4])) 25 | assert 4 == aligned_prefix(t1, t2, dim_idx=0) 26 | 27 | t1 = platform.from_numpy(numpy.array([1, 2, 3, 4])) 28 | t2 = platform.from_numpy(numpy.array([2, 2, 3, 4])) 29 | assert 0 == aligned_prefix(t1, t2, dim_idx=0) 30 | 31 | t1 = platform.from_numpy(numpy.array([[1, 2, 3, 4], [1, 2, 3, 4]])) 32 | t2 = platform.from_numpy(numpy.array([[1, 2, 3, 4], [1, 2, 3, 5]])) 33 | assert 1 == aligned_prefix(t1, t2, dim_idx=0) 34 | assert 3 == aligned_prefix(t1, t2, dim_idx=1) 35 | 36 | 37 | @pytest.mark.all_platform 38 | def test_aligned_otherdim(): 39 | shape_1 = (10, 11, 12) 40 | shape_2 = (10, 13, 12) 41 | assert shape_aligned_otherdim(shape_1, shape_2, 1) == True 42 | assert shape_aligned_otherdim(shape_1, shape_2, 2) == False 43 | 44 | shape_1 = (10, 11, 12) 45 | shape_2 = (10, 13, 13) 46 | assert shape_aligned_otherdim(shape_1, shape_2, 1) == False 47 | assert shape_aligned_otherdim(shape_1, shape_2, 2) == False 48 | 49 | shape_1 = (10, 11, 12) 50 | shape_2 = (10, 11, 12, 13) 51 | assert shape_aligned_otherdim(shape_1, shape_2, 2) == False 52 | assert shape_aligned_otherdim(shape_1, shape_2, 3) == False 53 | -------------------------------------------------------------------------------- /tests/test_combination/test_identity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import numpy 16 | import pytest 17 | 18 | from easydist import platform 19 | from easydist.metashard.combination import CombinationFunc 20 | from easydist.utils.testing import ALL_PLATFORM 21 | from easydist import easydist_setup 22 | 23 | 24 | @pytest.mark.all_platform 25 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 26 | def test_identity(backend): 27 | easydist_setup(backend) 28 | shard_tensor = [platform.from_numpy(numpy.array([1, 2, 3]))] * 4 29 | global_tensor = platform.from_numpy(numpy.array([1, 2, 3])) 30 | combination_tensor = CombinationFunc.identity(shard_tensor) 31 | 32 | assert platform.allclose(global_tensor, combination_tensor) 33 | 34 | 35 | @pytest.mark.all_platform 36 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 37 | def test_identity_2(backend): 38 | easydist_setup(backend) 39 | shard_tensor = [platform.from_numpy(numpy.array([1, 2, 3]))] * 4 40 | global_tensor = platform.from_numpy(numpy.array([1, 2, 4])) 41 | combination_tensor = CombinationFunc.identity(shard_tensor) 42 | 43 | assert not platform.allclose(global_tensor, combination_tensor) 44 | 45 | 46 | @pytest.mark.all_platform 47 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 48 | def test_identity_3(backend): 49 | easydist_setup(backend) 50 | shard_tensor = [platform.from_numpy(numpy.array([1, 2, 3]))] * 3 + [ 51 | platform.from_numpy(numpy.array([1, 2, 4])) 52 | ] 53 | combination_tensor = CombinationFunc.identity(shard_tensor) 54 | 55 | assert combination_tensor is None 56 | -------------------------------------------------------------------------------- /tests/test_combination/test_reduce.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import numpy 16 | import pytest 17 | 18 | from easydist import platform 19 | from easydist.metashard.combination import CombinationFunc, ReduceOp 20 | from easydist.utils.testing import ALL_PLATFORM 21 | from easydist import easydist_setup 22 | 23 | @pytest.mark.all_platform 24 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 25 | def test_reduce(backend): 26 | easydist_setup(backend) 27 | shard_tensor = [platform.from_numpy(numpy.array([i, i, i])) for i in range(4)] 28 | max_tensor = platform.from_numpy(numpy.array([3, 3, 3])) 29 | combination_max = CombinationFunc.reduce(shard_tensor, ops=ReduceOp.MAX) 30 | 31 | assert platform.allclose(max_tensor, combination_max) 32 | 33 | min_tensor = platform.from_numpy(numpy.array([0, 0, 0])) 34 | combination_min = CombinationFunc.reduce(shard_tensor, ops=ReduceOp.MIN) 35 | 36 | assert platform.allclose(min_tensor, combination_min) 37 | 38 | sum_tensor = platform.from_numpy(numpy.array([6, 6, 6])) 39 | combination_sum = CombinationFunc.reduce(shard_tensor, ops=ReduceOp.SUM) 40 | 41 | assert platform.allclose(sum_tensor, combination_sum) 42 | -------------------------------------------------------------------------------- /tests/test_combination/test_try_combination_single.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import numpy 16 | import pytest 17 | import functools 18 | 19 | from easydist import platform 20 | import easydist.config as mdconfig 21 | from easydist.metashard.combination import CombinationFunc, ReduceOp, try_combination_single 22 | from easydist.utils.testing import ALL_PLATFORM, assert_partial_func_equal 23 | from easydist import easydist_setup 24 | 25 | 26 | @pytest.mark.all_platform 27 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 28 | def test_reduce(backend): 29 | easydist_setup(backend) 30 | shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(4)] 31 | 32 | for op_type in [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.SUM]: 33 | comb_func = functools.partial(CombinationFunc.reduce, ops=op_type) 34 | global_tensor = comb_func(shard_tensor) 35 | 36 | return_func = try_combination_single(shard_tensor, global_tensor) 37 | 38 | assert_partial_func_equal(comb_func, return_func) 39 | 40 | 41 | @pytest.mark.all_platform 42 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 43 | def test_gather(backend): 44 | easydist_setup(backend) 45 | shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(4)] 46 | 47 | for dim_ in [0, 1]: 48 | comb_func = functools.partial(CombinationFunc.gather, dim=dim_) 49 | global_tensor = comb_func(shard_tensor) 50 | 51 | return_func = try_combination_single(shard_tensor, global_tensor) 52 | 53 | assert_partial_func_equal(comb_func, return_func) 54 | 55 | 56 | @pytest.mark.all_platform 57 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 58 | def test_gather_halo(backend): 59 | easydist_setup(backend) 60 | mdconfig.extend_space = True 61 | 62 | shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(3)] 63 | 64 | for dim_, halo_ in zip([0, 1], [1, 2]): 65 | comb_func = functools.partial(CombinationFunc.gather, dim=dim_, halowidth=halo_) 66 | global_tensor = comb_func(shard_tensor) 67 | 68 | return_func = try_combination_single(shard_tensor, global_tensor) 69 | 70 | assert_partial_func_equal(comb_func, return_func) 71 | 72 | 73 | @pytest.mark.all_platform 74 | @pytest.mark.parametrize("backend", ALL_PLATFORM) 75 | def test_gather_chunk(backend): 76 | easydist_setup(backend) 77 | mdconfig.extend_space = True 78 | 79 | shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(3)] 80 | 81 | for dim_, chunk_ in zip([0, 1], [3, 2]): 82 | comb_func = functools.partial(CombinationFunc.gather, dim=dim_, chunk=chunk_) 83 | global_tensor = comb_func(shard_tensor) 84 | 85 | return_func = try_combination_single(shard_tensor, global_tensor) 86 | 87 | assert_partial_func_equal(comb_func, return_func) 88 | -------------------------------------------------------------------------------- /tests/test_strategy/jax/simple_function1.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from easydist import easydist_setup, mdconfig 8 | from easydist.jax.api import easydist_compile 9 | 10 | 11 | @easydist_compile(compile_only=True) 12 | def foo_func(x, y): 13 | tanh = jnp.tanh(x) 14 | return jnp.exp(tanh) @ y + tanh 15 | 16 | 17 | def main(): 18 | mdconfig.log_level = logging.INFO 19 | easydist_setup(backend="jax", device="cuda") 20 | 21 | key = jax.random.PRNGKey(0) 22 | key, subkey = jax.random.split(key) 23 | randn_x = jax.random.normal(key, (10, 10)) 24 | 
randn_y = jax.random.normal(subkey, (10, 10)) 25 | 26 | foo_func(randn_x, randn_y) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /tests/test_strategy/jax/test_simple_function1.sh: -------------------------------------------------------------------------------- 1 | export ENABLE_GRAPH_COARSEN="True" 2 | export COARSEN_LEVEL=0 3 | 4 | mpirun -np 2 python simple_function1.py |& tee 1n2g.log 5 | 6 | expected_cost=0.0 7 | costs=`grep -e "\[Communication Cost\]:[^\n]*" 1n2g.log -o | grep -e "[0-9.]*" -o | awk '{print $1}'`; 8 | cost=`echo $costs | awk '{print $1}'`; 9 | echo -e "\n*****************************************" 10 | echo "Communication cost is $cost" 11 | 12 | if [ `echo "${cost} > ${expected_cost}"|bc` -eq 1 ]; 13 | then echo -e "Failed!\nExpected communication cost is ${expected_cost}." 14 | else echo "Successful!" 15 | fi 16 | echo -e "*****************************************\n" 17 | 18 | 19 | -------------------------------------------------------------------------------- /tests/test_strategy/torch/test_simple_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.distributed 4 | 5 | from easydist import easydist_setup 6 | from easydist.torch.api import easydist_compile 7 | from easydist.torch.device_mesh import set_device_mesh, NDDeviceMesh 8 | from easydist.utils.testing import spawn 9 | from torch.distributed._tensor import DeviceMesh 10 | 11 | 12 | class Foo(torch.nn.Module): 13 | 14 | def __init__(self): 15 | super().__init__() 16 | self.norm = torch.nn.LayerNorm(8) 17 | self.linear = torch.nn.Linear(8, 8) 18 | 19 | def forward(self, x): 20 | x = self.norm(x) 21 | x = self.linear(x) 22 | return x.relu() 23 | 24 | 25 | @easydist_compile(tracing_mode="fake", cuda_graph=False, compile_only=True) 26 | def train_step(input, model, opt): 27 | out = model(input).mean() 28 | out.backward() 29 | opt.step() 30 | opt.zero_grad(True) 31 | 32 | return out 33 | 34 | 35 | def train_example(): 36 | 37 | torch.ones(1).cuda() 38 | with torch.device('cuda'): 39 | model = Foo() 40 | randn_input = torch.randn(16, 8) 41 | 42 | torch.distributed.broadcast(randn_input, src=0) 43 | 44 | opt = torch.optim.Adam(model.parameters(), lr=0.001, foreach=True, capturable=True) 45 | 46 | # trace train step func 47 | mesh = NDDeviceMesh(DeviceMesh( 48 | "cuda", [0, 1], mesh_dim_names=["spmd"] 49 | )) 50 | set_device_mesh(mesh) 51 | 52 | train_step(randn_input, model, opt) 53 | 54 | def main(): 55 | # setting up easydist and torch.distributed 56 | easydist_setup(backend="torch", device="cuda", allow_tf32=False) 57 | torch.cuda.set_device(torch.distributed.get_rank()) 58 | train_example() 59 | 60 | @pytest.mark.torch 61 | def test_simple_model(): 62 | spawn(main, nprocs=2) 63 | 64 | if __name__ == "__main__": 65 | test_simple_model() 66 | -------------------------------------------------------------------------------- /tests/test_torch/test_pp/test_reslink.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import random 16 | 17 | import numpy as np 18 | 19 | import torch 20 | import torch.distributed as dist 21 | 22 | from torch.distributed._tensor import DeviceMesh 23 | from tqdm import tqdm 24 | 25 | from easydist import easydist_setup 26 | from easydist.torch.api import easydist_compile 27 | from easydist.torch.utils import seed 28 | from easydist.torch.device_mesh import set_device_mesh 29 | from easydist.torch.experimental.pp.runtime import ScheduleDAPPLE, ScheduleGPipe 30 | from easydist.torch.experimental.pp.compile_pipeline import annotate_split_points 31 | from easydist.utils.testing import spawn 32 | 33 | import pytest 34 | 35 | class Foo(torch.nn.Module): 36 | 37 | def __init__(self): 38 | super().__init__() 39 | self.layer0 = torch.nn.Linear(1024, 1024) 40 | self.layer1 = torch.nn.Linear(1024, 1024) 41 | self.layer2 = torch.nn.Linear(1024, 1024) 42 | self.layer3 = torch.nn.Linear(1024, 1) 43 | 44 | def forward(self, x): 45 | res = self.layer0(x) 46 | x = self.layer1(res) 47 | x = self.layer2(x) 48 | x = self.layer3(x + res) 49 | return x 50 | 51 | 52 | def main(schedule_cls): 53 | rank = dist.get_rank() 54 | world_size = dist.get_world_size() 55 | pp_size = world_size 56 | per_chunk_sz = 1 57 | num_chunks = 16 58 | batch_size = per_chunk_sz * num_chunks 59 | seed(42) 60 | easydist_setup(backend="torch", device="cuda", allow_tf32=False) 61 | 62 | device = torch.device('cuda') 63 | torch.cuda.set_device(rank) 64 | 65 | set_device_mesh(DeviceMesh("cuda", torch.arange(pp_size), mesh_dim_names=['pp'])) 66 | 67 | module = Foo().train().to(device) 68 | opt = torch.optim.Adam(module.parameters(), foreach=True, capturable=True) 69 | 70 | annotate_split_points(module, {'layer0', 'layer1', 'layer2'}) 71 | 72 | @easydist_compile(parallel_mode="pp", 73 | tracing_mode="fake", 74 | cuda_graph=False, 75 | schedule_cls=schedule_cls, 76 | num_chunks=num_chunks, 77 | return_to_all_stages=False) 78 | def train_step(input, label, model, opt): 79 | out = model(input) 80 | loss = out.mean() 81 | loss.backward() 82 | opt.step() 83 | opt.zero_grad() 84 | return out, loss 85 | 86 | dataset_size = 100 87 | train_dataloader = [(torch.randn( 88 | batch_size, 1024, device=device), torch.randint(0, 10, (batch_size, ), device=device)) 89 | ] * (dataset_size // batch_size) 90 | 91 | x_batch, y_batch = next(iter(train_dataloader)) 92 | epochs = 1 93 | 94 | for _ in range(epochs): 95 | for x_batch, y_batch in tqdm(train_dataloader, 96 | dynamic_ncols=True) if rank == 0 else train_dataloader: 97 | _ = train_step(x_batch, y_batch, module, opt) 98 | 99 | @pytest.mark.torch 100 | @pytest.mark.world_4 101 | @pytest.mark.parametrize("schedule_cls", [ScheduleGPipe, ScheduleDAPPLE]) 102 | @pytest.mark.timeout(50) 103 | def test_reslink(schedule_cls): 104 | spawn(main, (schedule_cls,), nprocs=4) 105 | -------------------------------------------------------------------------------- /tests/test_torch/test_simple.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | import pytest 16 | import torch 17 | import functorch 18 | from functorch.compile import aot_function 19 | import rich 20 | 21 | from easydist.utils.testing.mock import TorchMockDeviceMesh 22 | from easydist.torch import EDTorchShardingAnn, set_device_mesh 23 | from easydist.torch.passes import fix_addmm_bias, eliminate_detach 24 | from easydist import easydist_setup 25 | 26 | 27 | def fn_1(x, y): 28 | return torch.concat([x, y], dim=1) 29 | 30 | 31 | def fn_2(x, y): 32 | return torch.mm(torch.exp(torch.tanh(x)), y) 33 | 34 | 35 | @functorch.compile.make_boxed_compiler 36 | def compiler_fn(fx_module: torch.fx.GraphModule, inps): 37 | fx_module = fix_addmm_bias(fx_module) 38 | fx_module = eliminate_detach(fx_module) 39 | fx_module.recompile() 40 | print(fx_module.graph) 41 | 42 | sharding_interpreter = EDTorchShardingAnn(fx_module) 43 | sharding_info, fwd_shape_info = sharding_interpreter.run(*inps) 44 | rich.print("sharding_info:\n", sharding_info) 45 | rich.print("fwd_shape_info:\n", fwd_shape_info) 46 | 47 | return fx_module 48 | 49 | @pytest.mark.skip 50 | @pytest.mark.parametrize("fn", [fn_1, fn_2]) 51 | def test_simple_case(fn): 52 | easydist_setup("torch") 53 | 54 | mock_mesh = TorchMockDeviceMesh(1, 2, debug_only=True) 55 | set_device_mesh(mock_mesh) 56 | 57 | x = torch.randn(10, 10, requires_grad=True) 58 | y = torch.randn(10, 10, requires_grad=True) 59 | aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn) 60 | res = aot_print_fn(x, y) 61 | 62 | grad_res = torch.ones_like(res) 63 | res.backward(grad_res) 64 | 65 | 66 | if __name__ == '__main__': 67 | test_simple_case() 68 | -------------------------------------------------------------------------------- /tests/test_torch/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.distributed_c10d import _get_default_group 3 | from torch.distributed.utils import _sync_module_states 4 | 5 | from easydist.utils import rgetattr, rsetattr 6 | 7 | from benchmark.torch.model.gpt import GPT 8 | from benchmark.bench_case import GPTCase 9 | 10 | def train_step(input, model, opt): 11 | out = model(input) 12 | loss = out.mean() 13 | loss.backward() 14 | if opt: 15 | opt.step() 16 | opt.zero_grad() 17 | return out 18 | 19 | 20 | def fw_bw_step(input, model): 21 | out = model(input) 22 | loss = out.mean() 23 | loss.backward() 24 | grads = {k: p.grad.clone().detach() for k, p in model.named_parameters()} 25 | return out, grads 26 | 27 | 28 | def train_step_chunked(input, model, opt, num_chunks, show_micro_grad=False): 29 | output, prev_grads, micro_batch_grads = [], None, [] 30 | for chunk in input.chunk(num_chunks): 31 | out, grads = 
fw_bw_step(chunk, model) 32 | output.append(out) 33 | if prev_grads is None: 34 | micro_batch_grads.append(grads) 35 | else: 36 | micro_batch_grads.append({k: grads[k] - prev_grads[k] for k in grads}) 37 | prev_grads = grads 38 | 39 | output = torch.concat(output) 40 | opt.step() 41 | opt.zero_grad() 42 | return output, micro_batch_grads, prev_grads 43 | 44 | 45 | def broadcast_module(model): 46 | _sync_module_states(model, 47 | _get_default_group(), 48 | broadcast_bucket_size=int(250 * 1024 * 1024), 49 | src=0, 50 | params_and_buffers_to_ignore=set()) 51 | 52 | return model 53 | 54 | 55 | TEST_GPT_CASE = GPTCase( 56 | num_layers=4, 57 | hidden_dim=128, 58 | num_heads=4, 59 | seq_size=128 60 | ) 61 | 62 | class TEST_GPT(GPT): 63 | def __init__(self): 64 | super().__init__( 65 | depth=TEST_GPT_CASE.num_layers, 66 | dim=TEST_GPT_CASE.hidden_dim, 67 | num_heads=TEST_GPT_CASE.num_heads 68 | ) 69 | 70 | 71 | def get_module_opt_states(module, opt, init_opt_state): 72 | params = dict(module.named_parameters()) 73 | buffers = dict(module.named_buffers()) 74 | named_states = {} 75 | 76 | if init_opt_state: 77 | # assign grad and warm up optimizer 78 | for name in dict(module.named_parameters()): 79 | with torch.no_grad(): 80 | rsetattr(module, name + ".grad", torch.zeros_like(rgetattr(module, name).data)) 81 | 82 | opt.step() 83 | opt.zero_grad(True) 84 | 85 | for n, p in params.items(): 86 | if p in opt.state: 87 | named_states[n] = opt.state[p] # type: ignore[index] 88 | # if step in state, reduce one for warmup step. 89 | if init_opt_state and 'step' in named_states[n]: 90 | named_states[n]['step'] -= 1 91 | 92 | return params, buffers, named_states 93 | 94 | 95 | class Foo(torch.nn.Module): 96 | 97 | def __init__(self): 98 | super().__init__() 99 | self.norm = torch.nn.BatchNorm1d(1024) 100 | self.linear = torch.nn.Linear(1024, 1024) 101 | 102 | def forward(self, x): 103 | x = self.norm(x) 104 | x = self.linear(x) 105 | return x 106 | 107 | 108 | class Foo1(torch.nn.Module): 109 | 110 | def __init__(self): 111 | super().__init__() 112 | self.norm = torch.nn.BatchNorm1d(1024) 113 | self.linear0_0 = torch.nn.Linear(1024, 512) 114 | self.linear0_1 = torch.nn.Linear(512, 256) 115 | self.linear1 = torch.nn.Linear(256, 1024) 116 | 117 | def forward(self, x): 118 | x = self.norm(x) 119 | x0 = self.linear0_0(x) 120 | x0 = self.linear0_1(x0) 121 | x1 = self.linear1(x0) 122 | y = x + x1 123 | return y 124 | -------------------------------------------------------------------------------- /tests/test_unfiyshard/test_unifyop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import functools 16 | 17 | import pytest 18 | import torch 19 | 20 | from easydist.metashard.combination import CombinationFunc 21 | from easydist.metashard import ShardAnnotation, ShardDim, MetaOp 22 | from easydist.utils.testing.mock import assert_partial_func_equal 23 | from easydist import easydist_setup 24 | 25 | @pytest.mark.torch 26 | def test_metaop_preset(): 27 | easydist_setup("torch") 28 | input_args = (torch.rand((3, 4, 768)), 3, 2), {} 29 | meta_op = MetaOp(torch.ops.aten.chunk, input_args) 30 | preset_anno = ShardAnnotation([[ShardDim(0), ShardDim(0), ShardDim(1, chunk=3)]]) 31 | comb_func = meta_op.sharding_discovery_with_preset(preset_anno) 32 | 33 | right_answer = [functools.partial(CombinationFunc.gather, dim=2)] * 3 34 | 35 | assert comb_func != None 36 | assert len(comb_func) == len(right_answer) 37 | 38 | for func1, func2 in zip(comb_func, right_answer): 39 | assert_partial_func_equal(func1, func2) 40 | -------------------------------------------------------------------------------- /tests/test_unfiyshard/test_view_propagation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Alibaba Group; 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | 15 | import functools 16 | 17 | import pytest 18 | 19 | from easydist.metashard import ShardAnnotation, ShardDim 20 | from easydist.metashard.combination import CombinationFunc 21 | from easydist.metashard.view_propagation import view_propagation_preset 22 | from easydist.utils.testing.mock import assert_partial_func_equal 23 | 24 | @pytest.mark.torch 25 | def test_view_propagation_preset(): 26 | preset_anno = ShardAnnotation([[ShardDim(1, chunk=5), ShardDim(0)]]) 27 | comb_func = view_propagation_preset([10, 8], [5, 2, 8], preset_anno) 28 | answer = functools.partial(CombinationFunc.gather, dim=1) 29 | assert_partial_func_equal(comb_func, answer) 30 | 31 | preset_anno = ShardAnnotation([[ShardDim(0), ShardDim(1, chunk=2)]]) 32 | comb_func = view_propagation_preset([10, 8], [10, 2, 2, 2], preset_anno) 33 | answer = functools.partial(CombinationFunc.gather, dim=2) 34 | assert_partial_func_equal(comb_func, answer) 35 | 36 | preset_anno = ShardAnnotation([[ShardDim(0), ShardDim(1, chunk=4)]]) 37 | comb_func = view_propagation_preset([10, 8], [10, 2, 2, 2], preset_anno) 38 | answer = functools.partial(CombinationFunc.gather, dim=3) 39 | assert_partial_func_equal(comb_func, answer) 40 | 41 | preset_anno = ShardAnnotation([[ShardDim(1, chunk=3), ShardDim(0)]]) 42 | comb_func = view_propagation_preset([10, 8], [5, 2, 8], preset_anno) 43 | assert comb_func is None 44 | 45 | preset_anno = ShardAnnotation([[ShardDim(0), ShardDim(1, chunk=2)]]) 46 | comb_func = view_propagation_preset([10, 8], [5, 2, 8], preset_anno) 47 | assert comb_func is None 48 | --------------------------------------------------------------------------------
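# Editorial note (appended, not part of the repository files): the custom markers used across
# tests/ are registered in pytest.ini above, so a selective run is assumed to look like
#
#   pytest tests -m "torch and not long_duration"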