├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── NOTICE
├── README.md
├── assets
│   └── arch.svg
├── benchmark
│   ├── README.md
│   ├── bench_case.py
│   ├── jax
│   │   ├── bench_jax.py
│   │   ├── bench_jax_alpa.py
│   │   ├── bench_jax_dp.py
│   │   └── model
│   │       ├── __init__.py
│   │       ├── gat.py
│   │       ├── gpt.py
│   │       ├── resnet.py
│   │       └── wresnet.py
│   └── torch
│       ├── bench_torch.py
│       ├── bench_torch_tp.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── gat.py
│       │   ├── gpt.py
│       │   ├── gpt_tp.py
│       │   └── wresnet.py
│       └── pp
│           ├── gpt
│           │   └── speed
│           │       ├── batch.sh
│           │       ├── easydist_pipeline.py
│           │       ├── torchgpipe_pipeline.py
│           │       └── vanilla_torch.py
│           └── resnet101
│               ├── accuracy
│               │   ├── pipeline.py
│               │   └── vanilla_torch.py
│               └── speed
│                   ├── batch.sh
│                   ├── easydist_pipeline.py
│                   ├── resnet
│                   │   ├── __init__.py
│                   │   ├── bottleneck.py
│                   │   └── flatten_sequential.py
│                   ├── torchgpipe_pipeline.py
│                   └── vanila_torch.py
├── easydist
│   ├── __init__.py
│   ├── autoflow
│   │   ├── __init__.py
│   │   └── solver.py
│   ├── config.py
│   ├── jax
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── bridge.py
│   │   ├── device_mesh.py
│   │   ├── sharding_interpreter.py
│   │   └── utils.py
│   ├── metashard
│   │   ├── __init__.py
│   │   ├── annotation.py
│   │   ├── combination.py
│   │   ├── halo.py
│   │   ├── metair.py
│   │   ├── metaop.py
│   │   └── view_propagation.py
│   ├── platform
│   │   ├── __init__.py
│   │   ├── jax.py
│   │   ├── torch.py
│   │   └── tvm.py
│   ├── torch
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── bridge.py
│   │   ├── compile.py
│   │   ├── compile_auto.py
│   │   ├── compile_dp.py
│   │   ├── cuda
│   │   │   ├── __init__.py
│   │   │   ├── mem_allocator.py
│   │   │   └── scheduled_graph_drawer.py
│   │   ├── decomp_utils.py
│   │   ├── device_mesh.py
│   │   ├── experimental
│   │   │   ├── __init__.py
│   │   │   └── pp
│   │   │       ├── __init__.py
│   │   │       ├── api.py
│   │   │       ├── compile_pipeline.py
│   │   │       ├── ed_split_module.py
│   │   │       ├── microbatch.py
│   │   │       ├── runtime.py
│   │   │       ├── split_utils.py
│   │   │       └── utils.py
│   │   ├── graph_profile_db.py
│   │   ├── init_helper.py
│   │   ├── mem_allocation_info.py
│   │   ├── mem_anaylize.py
│   │   ├── meta_allocator.py
│   │   ├── passes
│   │   │   ├── __init__.py
│   │   │   ├── allocator_profiler.py
│   │   │   ├── comm_optimize.py
│   │   │   ├── edinfo_utils.py
│   │   │   ├── eliminate_detach.py
│   │   │   ├── fix_bias.py
│   │   │   ├── fix_embedding.py
│   │   │   ├── fix_meta_device.py
│   │   │   ├── fix_view.py
│   │   │   ├── pp_passes.py
│   │   │   ├── process_tag.py
│   │   │   ├── rule_override.py
│   │   │   ├── runtime_prof.py
│   │   │   ├── sharding.py
│   │   │   └── tile_comm.py
│   │   ├── preset_propagation.py
│   │   ├── profiler
│   │   │   ├── __init__.py
│   │   │   ├── csrc
│   │   │   │   ├── cupti_callback_api.cpp
│   │   │   │   ├── cupti_callback_api.h
│   │   │   │   ├── effective_cuda_allocator.cpp
│   │   │   │   ├── effective_cuda_allocator.h
│   │   │   │   ├── profiling_allocator.cpp
│   │   │   │   ├── profiling_allocator.h
│   │   │   │   ├── python_tracer_init.cpp
│   │   │   │   ├── stream_tracer.cpp
│   │   │   │   └── stream_tracer.h
│   │   │   └── stream_tracer.py
│   │   ├── reachability.py
│   │   ├── schedule
│   │   │   ├── __init__.py
│   │   │   ├── efficient_memory_scheduler.py
│   │   │   ├── graph_mem_plan.py
│   │   │   ├── ilp_memory_scheduler.py
│   │   │   ├── joint_learning.py
│   │   │   ├── lifetime_info.py
│   │   │   ├── memory_scheduler.py
│   │   │   ├── rcpsp.py
│   │   │   └── schedule_result.py
│   │   ├── scope_auto
│   │   │   ├── build_scope_modules.py
│   │   │   └── scope_marker.py
│   │   ├── sharding_interpreter.py
│   │   ├── split_utils.py
│   │   ├── spmd_prop_rule.py
│   │   ├── symphonia
│   │   │   ├── __init__.py
│   │   │   └── torch_actor.py
│   │   ├── tensorfield
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── csrc
│   │   │   │   └── allocator_interface.cpp
│   │   │   ├── helper.py
│   │   │   ├── interface.py
│   │   │   ├── mem_pool.py
│   │   │   └── server.py
│   │   └── utils.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── testing
│   │   │   ├── __init__.py
│   │   │   ├── mock.py
│   │   │   └── spawn.py
│   │   └── timer.py
│   └── version.py
├── examples
│   ├── README.md
│   ├── jax
│   │   ├── simple_function.py
│   │   └── simple_model.py
│   ├── torch
│   │   ├── bert_train.py
│   │   ├── cifar10.py
│   │   ├── gnn
│   │   │   ├── data.py
│   │   │   ├── gat.py
│   │   │   └── train.py
│   │   ├── gpt_train.py
│   │   ├── resnet18.py
│   │   ├── resnet_train.py
│   │   ├── simple_ddp.py
│   │   ├── simple_function.py
│   │   ├── simple_model.py
│   │   ├── stable_diffusion.py
│   │   ├── tensorfeild
│   │   │   ├── cifar10.py
│   │   │   ├── cifar10_ray.py
│   │   │   ├── matmul.py
│   │   │   ├── param_group_ray.py
│   │   │   └── param_group_test.py
│   │   └── test_dynamo_export.py
│   └── tvm
│       └── test_simple.py
├── pytest.ini
├── requirements
│   └── core-requirements.txt
├── setup.py
├── style.cfg
└── tests
    ├── test_combination
    │   ├── test_gather.py
    │   ├── test_help_func.py
    │   ├── test_identity.py
    │   ├── test_reduce.py
    │   └── test_try_combination_single.py
    ├── test_scope_auto
    │   ├── mark_scope1.py
    │   ├── mark_scope2.py
    │   └── mult_mesh1.py
    ├── test_strategy
    │   ├── jax
    │   │   ├── simple_function1.py
    │   │   └── test_simple_function1.sh
    │   └── torch
    │       └── test_simple_model.py
    ├── test_torch
    │   ├── test_hybrid.py
    │   ├── test_pp
    │   │   ├── test_reslink.py
    │   │   ├── test_runtime.py
    │   │   └── test_split.py
    │   ├── test_simple.py
    │   ├── test_spmd.py
    │   └── test_utils.py
    └── test_unfiyshard
        ├── test_unifyop.py
        └── test_view_propagation.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # log file of MetaDist
2 | *.metair
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # vscode
135 | .vscode/
136 |
137 | # setup
138 | dist/
139 | build/
140 |
141 | # temporary files
142 | md_compiled
143 | tmp
144 |
145 | # graphviz output files
146 | *.dot
147 |
148 | # cached datasets
149 | data
150 |
151 | # pt files
152 | *.pt
153 |
154 | # txt files
155 | *.txt
156 |
157 | # svg files
158 | *.svg
159 |
160 | # log folder
161 | log
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |
3 | - repo: https://github.com/PyCQA/autoflake
4 | rev: v2.3.1
5 | hooks:
6 | - id: autoflake
7 | name: remove unused variables and imports
8 | args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
9 |
10 | - repo: https://github.com/pre-commit/mirrors-yapf
11 | rev: v0.32.0
12 | hooks:
13 | - id: yapf
14 | name: format python code
15 | entry: yapf
16 | language: python
17 | files: \.py$
18 | types: [python]
19 | args: ["--style=style.cfg"]
20 |
21 | - repo: https://github.com/pycqa/isort
22 | rev: 5.13.2
23 | hooks:
24 | - id: isort
25 | name: sort python imports
26 | args: ["--profile", "black"]
27 |
28 | - repo: https://github.com/pre-commit/pre-commit-hooks
29 | rev: v4.6.0
30 | hooks:
31 | - id: check-yaml
32 | - id: check-merge-conflict
33 | - id: check-case-conflict
34 | - id: trailing-whitespace
35 | - id: end-of-file-fixer
36 | - id: mixed-line-ending
37 | args: ['--fix=lf']
38 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | Alibaba has adopted a Code of Conduct that we expect project participants to adhere to.
4 |
5 | Please refer to [Alibaba Open Source Code of Conduct](https://github.com/AlibabaDR/community/blob/master/CODE_OF_CONDUCT.md).
6 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing guidelines
2 |
3 | Contributors are welcome to submit their code and ideas. In the long run, we hope this project can be managed by developers from both inside and outside Alibaba.
4 |
5 | ### Contributor License Agreements
6 |
7 | * Sign CLA of EasyDist:
8 | Please download EasyDist [CLA](https://gist.github.com/alibaba-oss/151a13b0a72e44ba471119c7eb737d74). Follow the instructions to sign it.
9 |
10 | ### Pull Request Checklist
11 |
12 | Here is a checklist to prepare and submit your PR (pull request).
13 |
14 | * Create your own branch on GitHub by forking EasyDist.
15 | * Read the [README](README.md).
16 | * Read the [contributing guidelines](CONTRIBUTING.md).
17 | * Read the [Code of Conduct](CODE_OF_CONDUCT.md).
18 | * Ensure you have signed the
19 | [Contributor License Agreement (CLA)](https://gist.github.com/alibaba-oss/151a13b0a72e44ba471119c7eb737d74).
20 | * Push changes to your personal fork.
21 | * Create a PR with a detailed description if the commit messages do not speak for themselves.
22 | * Submit the PR for review and address all feedback.
23 | * Wait for merging (done by committers).
24 |
25 | Let's use an example to walk through the list.
26 |
27 | ## An Example of Submitting Code Change to EasyDist
28 |
29 | ### Fork Your Own Branch
30 |
31 | On the GitHub page of [EasyDist](https://github.com/alibaba/easydist), click the **Fork** button to create your own easydist repository.
32 |
33 | ### Create Local Repository
34 | ```bash
35 | git clone --recursive https://github.com/your_github/easydist.git
36 | ```
37 | ### Create a dev Branch (named your_github_id_feature_name)
38 | ```bash
39 | git branch your_github_id_feature_name
40 | ```
41 | ### Make Changes and Commit Locally
42 | ```bash
43 | git status
44 | git add files-to-change
45 | git commit -m "messages for your modifications"
46 | ```
47 |
48 | ### Rebase and Commit to Remote Repository
49 | ```bash
50 | git checkout main
51 | git pull
52 | git checkout your_github_id_feature_name
53 | git rebase main
54 | -- resolve conflict, run test --
55 | git push --recurse-submodules=on-demand origin your_github_id_feature_name
56 | ```
57 |
58 | ### Create a PR
59 | Click the **New pull request** or **Compare & pull request** button, choose to compare the branches easydist/main and your_github/your_github_id_feature_name, and write the PR description.
60 |
61 | ### Address Reviewers' Comments
62 | Resolve all problems raised by reviewers and update PR.
63 |
64 | ### Merge
65 | It is done by EasyDist committers.
66 | ___
67 |
68 | Copyright © Alibaba Group, Inc.
69 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | recursive-include easydist *.cpp *.h
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EasyDist
2 |
3 | EasyDist is an automated parallelization system and infrastructure designed for multiple ecosystems, offering the following key features:
4 |
5 | - **Usability**. With EasyDist, parallelizing your training or inference code to a larger scale becomes effortless with just a single line of change.
6 |
7 | - **Ecological Compatibility**. EasyDist serves as a centralized source of truth for operator-level SPMD rules across machine learning frameworks. It currently supports PyTorch and Jax natively, as well as the TVM Tensor Expression operator for SPMD rules.
8 |
9 | - **Infrastructure**. EasyDist decouples auto-parallel algorithms from specific machine learning frameworks and IRs. This design choice allows for the development and benchmarking of different auto-parallel algorithms in a more flexible manner, leveraging the capabilities and abstractions provided by EasyDist.
10 |
11 | ## One Line of Code for Parallelism
12 |
13 | To parallelize your training loop using EasyDist, you can use the `easydist_compile` decorator. Here's an example of how it can be used with PyTorch:
14 |
15 | ```python
16 | @easydist_compile()
17 | def train_step(net, optimizer, inputs, labels):
18 |
19 | outputs = net(inputs)
20 | loss = nn.CrossEntropyLoss()(outputs, labels)
21 | loss.backward()
22 |
23 | optimizer.step()
24 | optimizer.zero_grad()
25 |
26 | return loss
27 | ```
28 |
29 | This one-line decorator parallelizes the training step. You can find more examples in the [`./examples/`](./examples/) directory.
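For additional context, here is a minimal, self-contained sketch of how such a decorated step is typically set up and invoked. It is illustrative only: the import path of `easydist_compile` is an assumption, while `easydist_setup` is used the same way as in `benchmark/torch/bench_torch_tp.py`.

```python
# Hypothetical driver script (sketch, not taken verbatim from this repository).
import torch
import torch.nn as nn

from easydist import easydist_setup          # same setup call as in benchmark/torch/bench_torch_tp.py
from easydist.torch import easydist_compile  # assumed import path for the decorator shown above


@easydist_compile()
def train_step(net, optimizer, inputs, labels):
    outputs = net(inputs)
    loss = nn.CrossEntropyLoss()(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def main():
    # one-time setup before running the compiled step
    easydist_setup(backend="torch", device="cuda")

    net = nn.Linear(1024, 10).cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    inputs = torch.randn(8, 1024).cuda()
    labels = torch.randint(0, 10, (8,)).cuda()

    print(train_step(net, optimizer, inputs, labels))


if __name__ == "__main__":
    main()
```

Scripts like this are launched with `torchrun` so that each process sees its rank and world size, e.g. `torchrun --nproc_per_node 2 script.py`, mirroring the commands in `./benchmark/README.md`.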
30 |
31 | ## Overview
32 |
33 | EasyDist introduces the concept of MetaOp and MetaIR to decouple automatic parallelization methods from specific intermediate representations (IR) and frameworks. Additionally, it presents the ShardCombine Algorithm, which defines operator Single-Program, Multiple-Data (SPMD) sharding rules without requiring manual annotations. The architecture of EasyDist is as follows:
34 |
35 |
36 | ![EasyDist architecture](./assets/arch.svg)
37 |
38 |
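To make the ShardCombine idea above concrete, the following framework-agnostic sketch (plain PyTorch, no EasyDist APIs) illustrates the two kinds of operator-level SPMD rules it discovers for a matrix multiplication: sharding a non-reduction dimension is combined by concatenation, while sharding the reduction dimension is combined by summation (an all-reduce in the distributed setting).

```python
import torch

a = torch.randn(8, 16)
b = torch.randn(16, 4)
full = a @ b  # unsharded reference result

# Rule 1: shard a non-reduction dim (rows of `a`); combine the shard outputs by concatenation.
row_shards = a.chunk(2, dim=0)
assert torch.allclose(full, torch.cat([s @ b for s in row_shards], dim=0), atol=1e-5)

# Rule 2: shard the reduction dim of both operands; combine the partial results by summation.
a_cols, b_rows = a.chunk(2, dim=1), b.chunk(2, dim=0)
assert torch.allclose(full, sum(ai @ bi for ai, bi in zip(a_cols, b_rows)), atol=1e-5)
```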
39 | ## Installation
40 |
41 | To install EasyDist, you can use pip and install from PyPI:
42 |
43 | ```shell
44 | # For PyTorch users
45 | pip install pai-easydist[torch]
46 |
47 | # For Jax users
48 | pip install pai-easydist[jax] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
49 | ```
50 |
51 | If you prefer to install EasyDist from source, you can clone the GitHub repository and then install it with the appropriate extras:
52 |
53 | ```shell
54 | git clone https://github.com/alibaba/easydist.git && cd easydist
55 |
56 | # EasyDist with PyTorch installation
57 | pip install -e '.[torch]'
58 |
59 | # EasyDist with Jax installation
60 | pip install -e '.[jax]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
61 | ```
62 |
63 | ## Contributing
64 |
65 | See CONTRIBUTING.md for details.
66 |
67 | ## Contributors
68 |
69 | EasyDist is developed by Alibaba Group and NUS HPC-AI Lab. This work is supported by [Alibaba Innovative Research (AIR)](https://damo.alibaba.com/air/).
70 |
71 | ## License
72 |
73 | EasyDist is licensed under the Apache License (Version 2.0). See LICENSE file.
74 | This product contains some third-party testcases under other open source licenses.
75 | See the NOTICE file for more information.
76 |
77 |
--------------------------------------------------------------------------------
/benchmark/README.md:
--------------------------------------------------------------------------------
1 | ## Benchmark
2 |
3 | ```shell
4 | torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/torch/bench_torch.py
5 | torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/torch/bench_torch_tp.py
6 | ```
7 |
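The pipeline-parallelism benchmarks under `benchmark/torch/pp/` are launched the same way; the commands below show a single configuration taken from `benchmark/torch/pp/gpt/speed/batch.sh` (which sweeps micro-batch sizes and chunk counts):

```shell
# EasyDist pipeline (GPipe schedule) on 4 GPUs
torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --micro-batch-size 1 --num-chunks 1 --schedule gpipe

# torchgpipe and vanilla PyTorch baselines run as plain Python scripts
python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --micro-batch-size 1 --num-chunks 1
python benchmark/torch/pp/gpt/speed/vanilla_torch.py --micro-batch-size 1 --num-chunks 1
```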
--------------------------------------------------------------------------------
/benchmark/bench_case.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class GPTCase:
6 | batch_size: int = 4
7 | seq_size: int = 1024
8 | num_layers: int = 1
9 | hidden_dim: int = 12288
10 | num_heads: int = 48
11 | dropout_rate: float = 0.0
12 | use_bias: bool = True
13 | dtype = "float32"
14 |
15 |
16 | @dataclass
17 | class ResNetCase:
18 | batch_size: int = 128
19 |
20 |
21 | @dataclass
22 | class GATCase:
23 | num_node: int = 4096
24 | in_feature: int = 12288
25 | out_feature: int = 12288
26 |
--------------------------------------------------------------------------------
/benchmark/jax/bench_jax_alpa.py:
--------------------------------------------------------------------------------
1 | # python ./benchmark/jax/bench_jax_alpa.py
2 |
3 | import logging
4 | import os
5 | import sys
6 |
7 | import alpa
8 | import jax
9 |
10 | os.environ["EASYDIST_DEVICE"] = "cuda"
11 | os.environ["EASYDIST_BACKEND"] = "jax"
12 |
13 | from easydist.utils.timer import EDTimer
14 |
15 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
16 | from benchmark.jax.model.gpt import GPTSimple
17 | from benchmark.jax.model.wresnet import resnet18
18 | from benchmark.jax.model.gat import GATLayer
19 | from benchmark.bench_case import GPTCase, ResNetCase, GATCase
20 |
21 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
22 | datefmt='%m/%d %H:%M:%S',
23 | level=logging.DEBUG)
24 |
25 |
26 | def get_gpt_case():
27 | case = GPTCase()
28 | model = GPTSimple(case)
29 |
30 | root_key = jax.random.PRNGKey(seed=0)
31 | main_key, params_key = jax.random.split(key=root_key)
32 | input_ = jax.random.normal(
33 | main_key, (case.batch_size, case.seq_size, case.hidden_dim)) # Dummy input data
34 | variables = model.init(params_key, input_, deterministic=True)
35 | params = variables['params']
36 |
37 | # DataParallel() Zero3Parallel() Zero2Parallel()
38 | @alpa.parallelize(method=alpa.ShardParallel())
39 | def train_step(params, input_):
40 | lr = 0.0001
41 |
42 | def loss_fn(params):
43 | dropout_key = jax.random.PRNGKey(seed=0)
44 | return model.apply({
45 | 'params': params
46 | },
47 | input_,
48 | deterministic=False,
49 | rngs={
50 | 'dropout': dropout_key
51 | }).mean()
52 |
53 | grad_fn = jax.grad(loss_fn)
54 | grads = grad_fn(params)
55 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads)
56 | return params
57 |
58 | return train_step, [params, input_]
59 |
60 |
61 | def get_resnet_case():
62 | case = ResNetCase()
63 | model = resnet18()
64 |
65 | key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2)
66 | input_ = jax.random.normal(key1, (case.batch_size, 224, 224, 3)) # Dummy input data
67 | variables = model.init(key2, input_) # Initialization call
68 | params, batch_stats = variables['params'], variables['batch_stats']
69 |
70 | @alpa.parallelize(method=alpa.ShardParallel())
71 | def train_step(params, batch_stats, input_):
72 | lr = 0.0001
73 |
74 | def loss_fn(params, batch_stats):
75 | out_, batch_stats = model.apply({
76 | 'params': params,
77 | 'batch_stats': batch_stats
78 | },
79 | input_,
80 | mutable=['batch_stats'])
81 | return out_.mean()
82 |
83 | grad_fn = jax.grad(loss_fn)
84 | grads = grad_fn(params, batch_stats)
85 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads)
86 | return params
87 |
88 | return train_step, [params, batch_stats, input_]
89 |
90 |
91 | def get_gat_case():
92 |
93 | case = GATCase()
94 | model = GATLayer(case.in_feature, case.out_feature)
95 |
96 | key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), num=3)
97 | h = jax.random.normal(key1, (case.num_node, case.in_feature)) # Dummy input data
98 | adj = jax.random.normal(key2, (case.num_node, case.num_node)) # Dummy input data
99 | variables = model.init(key3, h, adj) # Initialization call
100 | params = variables['params']
101 |
102 | @alpa.parallelize(method=alpa.ShardParallel())
103 | def train_step(params, h, adj):
104 | lr = 0.0001
105 |
106 | def loss_fn(params):
107 | return model.apply({'params': params}, h, adj).mean()
108 |
109 | grad_fn = jax.grad(loss_fn)
110 | grads = grad_fn(params)
111 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads)
112 | return params
113 |
114 | return train_step, [params, h, adj]
115 |
116 |
117 | def bench_alpa(func, args):
118 |
119 | def train_step():
120 | func(*args)
121 |
122 | timer = EDTimer(train_step, in_ms=False)
123 |
124 | elaps_time = timer.time()
125 |
126 | print(f"Time: {elaps_time}")
127 |
128 |
129 | def main():
130 | os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
131 | print(jax.devices())
132 |
133 | func, args = get_gat_case()
134 |
135 | bench_alpa(func, args)
136 |
137 |
138 | if __name__ == '__main__':
139 | main()
140 |
--------------------------------------------------------------------------------
/benchmark/jax/bench_jax_dp.py:
--------------------------------------------------------------------------------
1 | # python ./benchmark/jax/bench_jax_dp.py
2 |
3 | import os
4 | import sys
5 | import logging
6 | from functools import partial
7 |
8 | import jax
9 |
10 | from easydist import easydist_setup
11 | from easydist.utils.timer import EDTimer
12 |
13 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
14 | from benchmark.bench_case import GPTCase, ResNetCase
15 | from benchmark.jax.model.gpt import GPTSimple
16 | from benchmark.jax.model.wresnet import resnet18
17 |
18 |
19 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
20 | datefmt='%m/%d %H:%M:%S',
21 | level=logging.INFO)
22 |
23 |
24 | def get_gpt_case():
25 | case = GPTCase()
26 | model = GPTSimple(case)
27 |
28 | root_key = jax.random.PRNGKey(seed=0)
29 | main_key, params_key = jax.random.split(key=root_key)
30 | input_ = jax.random.normal(
31 | main_key, (case.batch_size, case.seq_size, case.hidden_dim)) # Dummy input data
32 | variables = model.init(params_key, input_, deterministic=True)
33 | params = variables['params']
34 |
35 | @partial(jax.pmap, axis_name="batch")
36 | def train_step(params, input_):
37 | lr = 0.0001
38 |
39 | def loss_fn(params):
40 | dropout_key = jax.random.PRNGKey(seed=0)
41 | return model.apply({
42 | 'params': params
43 | },
44 | input_,
45 | deterministic=False,
46 | rngs={
47 | 'dropout': dropout_key
48 | }).mean()
49 |
50 | grad_fn = jax.grad(loss_fn)
51 | grads = grad_fn(params)
52 | grads = jax.lax.pmean(grads, axis_name="batch")
53 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads)
54 | return params
55 |
56 | devices = jax.local_devices()
57 | params = jax.device_put_replicated(params, devices)
58 |
59 | def shard_batch(x):
60 | x = x.reshape((len(devices), -1) + x.shape[1:])
61 | return jax.device_put_sharded(list(x), devices)
62 |
63 | input_ = jax.tree_map(shard_batch, input_)
64 |
65 | return train_step, [params, input_]
66 |
67 |
68 | def get_resnet_case():
69 | case = ResNetCase()
70 | model = resnet18()
71 |
72 | key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2)
73 | input_ = jax.random.normal(key1, (case.batch_size, 224, 224, 3)) # Dummy input data
74 | variables = model.init(key2, input_) # Initialization call
75 | params, batch_stats = variables['params'], variables['batch_stats']
76 |
77 | @partial(jax.pmap, axis_name="batch")
78 | def train_step(params, batch_stats, input_):
79 | lr = 0.0001
80 |
81 | def loss_fn(params, batch_stats):
82 | out_, batch_stats = model.apply({
83 | 'params': params,
84 | 'batch_stats': batch_stats
85 | },
86 | input_,
87 | mutable=['batch_stats'])
88 | return out_.mean()
89 |
90 | grad_fn = jax.grad(loss_fn)
91 | grads = grad_fn(params, batch_stats)
92 | grads = jax.lax.pmean(grads, axis_name="batch")
93 | params = jax.tree_map(lambda x, y: x - lr * y, params, grads)
94 | return params
95 |
96 | devices = jax.local_devices()
97 | params = jax.device_put_replicated(params, devices)
98 | batch_stats = jax.device_put_replicated(batch_stats, devices)
99 |
100 | def shard_batch(x):
101 | x = x.reshape((len(devices), -1) + x.shape[1:])
102 | return jax.device_put_sharded(list(x), devices)
103 |
104 | input_ = jax.tree_map(shard_batch, input_)
105 |
106 | return train_step, [params, batch_stats, input_]
107 |
108 |
109 | def bench_pmap_dp(func, args):
110 |
111 | def train_step():
112 | func(*args)
113 |
114 | timer = EDTimer(train_step, in_ms=False)
115 |
116 | elaps_time = timer.time()
117 |
118 | print(f"Time: {elaps_time}")
119 |
120 |
121 | def main():
122 | # setup easydist
123 | easydist_setup(backend="jax", device="cuda")
124 |
125 | print(jax.devices())
126 |
127 | func, args = get_gpt_case()
128 |
129 | bench_pmap_dp(func, args)
130 |
131 |
132 | if __name__ == '__main__':
133 | main()
134 |
--------------------------------------------------------------------------------
/benchmark/jax/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt import GPT, GPTSimple, GPTBlock, GPTConfig
2 | from .resnet import ResNet, ResNet18
3 | from .wresnet import resnet18, resnet34, resnet50, wresnet50, wresnet101
4 | from .gat import GATLayer
5 |
6 | __all__ = [
7 | "GPT", "GPTSimple", "GPTBlock", "GPTConfig", "ResNet", "ResNet18", "resnet18", "resnet34",
8 | "resnet50", "wresnet50", "wresnet101", "GATLayer"
9 | ]
10 |
--------------------------------------------------------------------------------
/benchmark/jax/model/gat.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from flax import linen as nn
3 |
4 |
5 | class GATLayer(nn.Module):
6 | in_features: int
7 | out_features: int
8 |
9 | @nn.compact
10 | def __call__(self, h, adj):
11 | wh = nn.Dense(features=self.out_features)(h)
12 | wh1 = nn.Dense(features=1)(wh)
13 | wh2 = nn.Dense(features=1)(wh)
14 | e = nn.leaky_relu(wh1 + wh2.T)
15 |
16 | zero_vec = -10e10 * jax.numpy.ones_like(e)
17 | attention = jax.numpy.where(adj > 0, e, zero_vec)
18 | attention = nn.softmax(attention)
19 |
20 | h_new = jax.numpy.matmul(attention, wh)
21 |
22 | return nn.elu(h_new)
23 |
--------------------------------------------------------------------------------
/benchmark/jax/model/resnet.py:
--------------------------------------------------------------------------------
1 | """Flax implementation of ResNet V1."""
2 |
3 | # Copyright 2023 The Flax Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # See issue #620.
18 | # pytype: disable=wrong-arg-count
19 |
20 | from functools import partial
21 | from typing import Any, Callable, Sequence, Tuple
22 |
23 | from jax import numpy as jnp
24 | from flax import linen as nn
25 |
26 | ModuleDef = Any
27 |
28 |
29 | class ResNetBlock(nn.Module):
30 | """ResNet block."""
31 | filters: int
32 | conv: ModuleDef
33 | norm: ModuleDef
34 | act: Callable
35 | strides: Tuple[int, int] = (1, 1)
36 |
37 | @nn.compact
38 | def __call__(
39 | self,
40 | x,
41 | ):
42 | residual = x
43 | y = self.conv(self.filters, (3, 3), self.strides)(x)
44 | y = self.norm()(y)
45 | y = self.act(y)
46 | y = self.conv(self.filters, (3, 3))(y)
47 | y = self.norm(scale_init=nn.initializers.zeros_init())(y)
48 |
49 | if residual.shape != y.shape:
50 | residual = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')(residual)
51 | residual = self.norm(name='norm_proj')(residual)
52 |
53 | return self.act(residual + y)
54 |
55 |
56 | class BottleneckResNetBlock(nn.Module):
57 | """Bottleneck ResNet block."""
58 | filters: int
59 | conv: ModuleDef
60 | norm: ModuleDef
61 | act: Callable
62 | strides: Tuple[int, int] = (1, 1)
63 |
64 | @nn.compact
65 | def __call__(self, x):
66 | residual = x
67 | y = self.conv(self.filters, (1, 1))(x)
68 | y = self.norm()(y)
69 | y = self.act(y)
70 | y = self.conv(self.filters, (3, 3), self.strides)(y)
71 | y = self.norm()(y)
72 | y = self.act(y)
73 | y = self.conv(self.filters * 4, (1, 1))(y)
74 | y = self.norm(scale_init=nn.initializers.zeros_init())(y)
75 |
76 | if residual.shape != y.shape:
77 | residual = self.conv(self.filters * 4, (1, 1), self.strides,
78 | name='conv_proj')(residual)
79 | residual = self.norm(name='norm_proj')(residual)
80 |
81 | return self.act(residual + y)
82 |
83 |
84 | class ResNet(nn.Module):
85 | """ResNetV1."""
86 | stage_sizes: Sequence[int]
87 | block_cls: ModuleDef
88 | num_classes: int = 1000
89 | num_filters: int = 64
90 | dtype: Any = jnp.float32
91 | act: Callable = nn.relu
92 | conv: ModuleDef = nn.Conv
93 |
94 | @nn.compact
95 | def __call__(self, x, train: bool = True):
96 | conv = partial(self.conv, use_bias=False, dtype=self.dtype)
97 | norm = partial(nn.BatchNorm,
98 | use_running_average=not train,
99 | momentum=0.9,
100 | epsilon=1e-5,
101 | dtype=self.dtype)
102 |
103 | x = conv(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init')(x)
104 | x = norm(name='bn_init')(x)
105 | x = nn.relu(x)
106 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
107 | for i, block_size in enumerate(self.stage_sizes):
108 | for j in range(block_size):
109 | strides = (2, 2) if i > 0 and j == 0 else (1, 1)
110 | x = self.block_cls(self.num_filters * 2**i,
111 | strides=strides,
112 | conv=conv,
113 | norm=norm,
114 | act=self.act)(x)
115 | x = jnp.mean(x, axis=(1, 2))
116 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
117 | x = jnp.asarray(x, self.dtype)
118 | return x
119 |
120 |
121 | ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
122 |
--------------------------------------------------------------------------------
/benchmark/torch/bench_torch_tp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import torch.optim as optim
6 | from torch.nn.parallel import DistributedDataParallel as DDP
7 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
8 | from fairscale.nn.model_parallel import get_data_parallel_group
9 |
10 | from easydist.utils.timer import EDTimer
11 | from easydist import easydist_setup
12 |
13 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
14 | from benchmark.bench_case import GPTCase
15 | from benchmark.torch.model.gpt_tp import GPT
16 |
17 |
18 | def get_gpt_case(cuda=True):
19 |
20 | case = GPTCase()
21 | model = GPT(depth=case.num_layers, dim=case.hidden_dim, num_heads=case.num_heads)
22 | data_in = torch.ones(case.batch_size, case.seq_size, case.hidden_dim)
23 |
24 | if cuda:
25 | return model.cuda(), data_in.cuda()
26 |
27 | return model, data_in
28 |
29 |
30 | def bench_tp(model, data_in):
31 |
32 | ddp_model = DDP(model, process_group=get_data_parallel_group())
33 | optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
34 |
35 | def train_step():
36 | optimizer.zero_grad()
37 | out = ddp_model(data_in)
38 | out_grad = torch.ones_like(out)
39 | out.backward(out_grad)
40 | optimizer.step()
41 |
42 | torch.cuda.reset_peak_memory_stats()
43 |
44 | timer = EDTimer(train_step, in_ms=False)
45 |
46 | elaps_time = timer.time()
47 | peak_memory = torch.cuda.max_memory_allocated()
48 |
49 | print(f"Memory: {peak_memory / 1024 / 1024 / 1024} GB")
50 | print(f"Time: {elaps_time}")
51 |
52 |
53 | def main():
54 | easydist_setup(backend="torch", device="cuda")
55 | # setup distributed
56 | torch.distributed.init_process_group(backend="nccl")
57 | local_rank = int(os.environ["LOCAL_RANK"])
58 | world_size = torch.distributed.get_world_size()
59 | initialize_model_parallel(world_size)
60 | torch.cuda.set_device(local_rank)
61 |
62 | model, data_in = get_gpt_case(cuda=True)
63 |
64 | bench_tp(model, data_in)
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------
/benchmark/torch/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .gat import GATLayer
2 | from .gpt import GPT, GPTLayer, FeedForward, SelfAttention
3 | from .wresnet import resnet18, resnet34, resnet50, wresnet50, wresnet101
4 |
5 | __all__ = [
6 | 'GATLayer', 'GPT', 'GPTLayer', 'FeedForward', 'SelfAttention', 'wresnet50', 'wresnet101',
7 |     'resnet18', 'resnet34', 'resnet50'  # 'LLAMA'/'LLAMAConfig' removed: not defined in this package
8 | ]
9 |
--------------------------------------------------------------------------------
/benchmark/torch/model/gat.py:
--------------------------------------------------------------------------------
1 | # code modified from https://github.com/Diego999/pyGAT
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class GATLayer(nn.Module):
9 |
10 | def __init__(self, in_features, out_features):
11 | super(GATLayer, self).__init__()
12 | self.in_features = in_features
13 | self.out_features = out_features
14 |
15 | self.linear_w = nn.Linear(in_features=in_features, out_features=out_features, bias=None)
16 |
17 | self.linear_a_1 = nn.Linear(in_features=out_features, out_features=1, bias=None)
18 | self.linear_a_2 = nn.Linear(in_features=out_features, out_features=1, bias=None)
19 |
20 | self.leakyrelu = nn.LeakyReLU()
21 |
22 | self.softmax = nn.Softmax(dim=-1)
23 |
24 | def forward(self, h, adj):
25 | wh = self.linear_w(h)
26 | wh1 = self.linear_a_1(wh)
27 | wh2 = self.linear_a_2(wh)
28 |
29 | e = self.leakyrelu(wh1 + wh2.T)
30 | zero_vec = -10e10 * torch.ones(*e.shape, device=e.device, dtype=e.dtype)
31 | attention = torch.where(adj > 0, e, zero_vec)
32 | attention = self.softmax(attention)
33 |
34 | h_new = torch.matmul(attention, wh)
35 |
36 | return F.elu(h_new)
37 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/gpt/speed/batch.sh:
--------------------------------------------------------------------------------
1 | # threshold microbatch size
2 | for ((i=1 ; i <= 16 ; i += 1)); do
3 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --micro-batch-size $i --num-chunks 1 --schedule gpipe
4 | done
5 |
6 | for ((i=1 ; i <= 16 ; i += 1)); do
7 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --micro-batch-size $i --num-chunks 1 --schedule dapple
8 | done
9 |
10 | for ((i=1 ; i <= 16 ; i += 1)); do
11 | python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --micro-batch-size $i --num-chunks 1
12 | done
13 |
14 | for ((i=1 ; i <= 16 ; i += 1)); do
15 | python benchmark/torch/pp/gpt/speed/vanilla_torch.py --micro-batch-size $i --num-chunks 1
16 | done
17 |
18 |
19 | # num chunks
20 | for ((i=1; i <= 32 ; i *= 2)); do
21 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i --schedule gpipe
22 | done
23 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks 34 --schedule gpipe
24 |
25 | for ((i=1; i <= 64 ; i *= 2)); do
26 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i --schedule dapple
27 | done
28 | torchrun --nproc_per_node 4 benchmark/torch/pp/gpt/speed/easydist_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks 98 --schedule dapple
29 |
30 | for ((i=1; i <= 32 ; i *= 2)); do
31 | python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i
32 | done
33 | python benchmark/torch/pp/gpt/speed/torchgpipe_pipeline.py --dataset-size 5000 --micro-batch-size 16 --num-chunks 34
34 |
35 | for ((i=1; i <= 256 ; i *= 2)); do
36 | python benchmark/torch/pp/gpt/speed/vanilla_torch.py --dataset-size 5000 --micro-batch-size 16 --num-chunks $i
37 | done
38 |
39 |
40 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/resnet101/accuracy/vanilla_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | # python benchmark/torch/pp/resnet101/accuracy/vanilla_torch.py
16 | import os
17 | import random
18 | import time
19 |
20 | import numpy as np
21 |
22 | import torch
23 |
24 | from torchvision import datasets, transforms
25 | from torchvision.models import resnet18
26 | from torch.profiler import profile, record_function, ProfilerActivity
27 |
28 | from tqdm import tqdm
29 |
30 |
31 | def seed(seed=42):
32 | # Set seed for PyTorch
33 | torch.manual_seed(seed)
34 | # torch.use_deterministic_algorithms(True)
35 | if torch.cuda.is_available():
36 | torch.cuda.manual_seed(seed)
37 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
38 | # Set seed for numpy
39 | np.random.seed(seed)
40 | # Set seed for built-in Python
41 | random.seed(seed)
42 |     # Make cuDNN deterministic for reproducibility
43 | torch.backends.cudnn.deterministic = True
44 | torch.backends.cudnn.benchmark = False
45 |
46 |
47 | criterion = torch.nn.CrossEntropyLoss()
48 |
49 |
50 | def train_step(input, label, model, opt):
51 | opt.zero_grad()
52 | out = model(input)
53 | loss = criterion(out, label)
54 | loss.backward()
55 | opt.step()
56 | return out, loss
57 |
58 |
59 | def test_main():
60 | seed(1)
61 |
62 | device = torch.device('cuda')
63 |
64 | module = resnet18().train().to(device)
65 | module.fc = torch.nn.Linear(module.fc.in_features, 10).to(device)
66 | batch_size = 1024
67 |
68 | opt = torch.optim.Adam(module.parameters(), foreach=True, capturable=True)
69 |
70 | transform = transforms.Compose([
71 | transforms.ToTensor(),
72 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
73 | ])
74 | train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
75 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
76 | x_batch, y_batch = next(iter(train_dataloader))
77 | train_step(x_batch.to(device), y_batch.to(device), module, opt)
78 | epochs = 5
79 | for epoch in range(epochs):
80 | all_cnt, correct_cnt, loss_sum = 0, 0, 0
81 | time_start = time.time()
82 | for x_batch, y_batch in tqdm(train_dataloader, dynamic_ncols=True):
83 | x_batch = x_batch.to(device)
84 | y_batch = y_batch.to(device)
85 | if x_batch.size(0) != batch_size: # TODO need to solve this
86 | continue
87 | out, loss = train_step(x_batch, y_batch, module, opt)
88 | all_cnt += len(out)
89 | preds = out.argmax(-1)
90 | correct_cnt += (preds == y_batch).sum()
91 | loss_sum += loss.mean().item()
92 | print(
93 | f'epoch {epoch} train accuracy: {correct_cnt / all_cnt}, loss sum {loss_sum}, avg loss: {loss_sum / all_cnt} '
94 | f'time: {time.time() - time_start}')
95 |
96 |
97 | if __name__ == '__main__':
98 | test_main()
99 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/resnet101/speed/batch.sh:
--------------------------------------------------------------------------------
1 | for ((i=1; i <= 256 ; i *= 2)); do
2 | torchrun --nproc_per_node 4 benchmark/torch/pp/resnet101/speed/easydist_pipeline.py --micro-batch-size 128 --num-chunks $i --schedule gpipe
3 | done
4 |
5 | for ((i=1; i <= 256 ; i *= 2)); do
6 | torchrun --nproc_per_node 4 benchmark/torch/pp/resnet101/speed/easydist_pipeline.py --micro-batch-size 128 --num-chunks $i --schedule dapple
7 | done
8 |
9 |
10 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/resnet101/speed/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | """A ResNet implementation but using :class:`nn.Sequential`. :func:`resnet101`
2 | returns a :class:`nn.Sequential` instead of ``ResNet``.
3 |
4 | This code is transformed from :mod:`torchvision.models.resnet`.
5 |
6 | """
7 | from collections import OrderedDict
8 | from typing import Any, List
9 |
10 | from torch import nn
11 |
12 | from resnet.bottleneck import bottleneck
13 | from resnet.flatten_sequential import flatten_sequential
14 |
15 | __all__ = ['resnet101']
16 |
17 |
18 | def build_resnet(layers: List[int],
19 | num_classes: int = 1000,
20 | inplace: bool = False
21 | ) -> nn.Sequential:
22 | """Builds a ResNet as a simple sequential model.
23 |
24 | Note:
25 | The implementation is copied from :mod:`torchvision.models.resnet`.
26 |
27 | """
28 | inplanes = 64
29 |
30 | def make_layer(planes: int,
31 | blocks: int,
32 | stride: int = 1,
33 | inplace: bool = False,
34 | ) -> nn.Sequential:
35 | nonlocal inplanes
36 |
37 | downsample = None
38 | if stride != 1 or inplanes != planes * 4:
39 | downsample = nn.Sequential(
40 | nn.Conv2d(inplanes, planes * 4,
41 | kernel_size=1, stride=stride, bias=False),
42 | nn.BatchNorm2d(planes * 4),
43 | )
44 |
45 | layers = []
46 | layers.append(bottleneck(inplanes, planes, stride, downsample, inplace))
47 | inplanes = planes * 4
48 | for _ in range(1, blocks):
49 | layers.append(bottleneck(inplanes, planes, inplace=inplace))
50 |
51 | return nn.Sequential(*layers)
52 |
53 | # Build ResNet as a sequential model.
54 | model = nn.Sequential(OrderedDict([
55 | ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
56 | ('bn1', nn.BatchNorm2d(64)),
57 | ('relu', nn.ReLU()),
58 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
59 |
60 | ('layer1', make_layer(64, layers[0], inplace=inplace)),
61 | ('layer2', make_layer(128, layers[1], stride=2, inplace=inplace)),
62 | ('layer3', make_layer(256, layers[2], stride=2, inplace=inplace)),
63 | ('layer4', make_layer(512, layers[3], stride=2, inplace=inplace)),
64 |
65 | ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
66 | ('flat', nn.Flatten()),
67 | ('fc', nn.Linear(512 * 4, num_classes)),
68 | ]))
69 |
70 | # Flatten nested sequentials.
71 | model = flatten_sequential(model)
72 |
73 | # Initialize weights for Conv2d and BatchNorm2d layers.
74 | # Stolen from torchvision-0.4.0.
75 | def init_weight(m: nn.Module) -> None:
76 | if isinstance(m, nn.Conv2d):
77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
78 | return
79 |
80 | if isinstance(m, nn.BatchNorm2d):
81 | nn.init.constant_(m.weight, 1)
82 | nn.init.constant_(m.bias, 0)
83 | return
84 |
85 | model.apply(init_weight)
86 |
87 | return model
88 |
89 |
90 | def resnet101(**kwargs: Any) -> nn.Sequential:
91 | """Constructs a ResNet-101 model."""
92 | return build_resnet([3, 4, 23, 3], **kwargs)
93 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/resnet101/speed/resnet/bottleneck.py:
--------------------------------------------------------------------------------
1 | """A ResNet bottleneck implementation but using :class:`nn.Sequential`."""
2 | from collections import OrderedDict
3 | from typing import TYPE_CHECKING, Optional, Tuple, Union
4 |
5 | from torch import Tensor, nn
6 |
7 | from torchgpipe.skip import Namespace, pop, skippable, stash
8 |
9 | __all__ = ['bottleneck']
10 |
11 | Tensors = Tuple[Tensor, ...]
12 | TensorOrTensors = Union[Tensor, Tensors]
13 |
14 | if TYPE_CHECKING:
15 | NamedModules = OrderedDict[str, nn.Module]
16 | else:
17 | NamedModules = OrderedDict
18 |
19 |
20 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
21 | """3x3 convolution with padding"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 |
25 |
26 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
27 | """1x1 convolution"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
29 |
30 |
31 | @skippable(stash=['identity'])
32 | class Identity(nn.Module):
33 | def forward(self, tensor: Tensor) -> Tensor: # type: ignore
34 | yield stash('identity', tensor)
35 | return tensor
36 |
37 |
38 | @skippable(pop=['identity'])
39 | class Residual(nn.Module):
40 | """A residual block for ResNet."""
41 |
42 | def __init__(self, downsample: Optional[nn.Module] = None):
43 | super().__init__()
44 | self.downsample = downsample
45 |
46 | def forward(self, input: Tensor) -> Tensor: # type: ignore
47 | identity = yield pop('identity')
48 | if self.downsample is not None:
49 | identity = self.downsample(identity)
50 | return input + identity
51 |
52 |
53 | def bottleneck(inplanes: int,
54 | planes: int,
55 | stride: int = 1,
56 | downsample: Optional[nn.Module] = None,
57 | inplace: bool = False,
58 | ) -> nn.Sequential:
59 | """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`."""
60 |
61 | layers: NamedModules = OrderedDict()
62 |
63 | ns = Namespace()
64 | layers['identity'] = Identity().isolate(ns) # type: ignore
65 |
66 | layers['conv1'] = conv1x1(inplanes, planes)
67 | layers['bn1'] = nn.BatchNorm2d(planes)
68 | layers['relu1'] = nn.ReLU(inplace=inplace)
69 |
70 | layers['conv2'] = conv3x3(planes, planes, stride)
71 | layers['bn2'] = nn.BatchNorm2d(planes)
72 | layers['relu2'] = nn.ReLU(inplace=inplace)
73 |
74 | layers['conv3'] = conv1x1(planes, planes * 4)
75 | layers['bn3'] = nn.BatchNorm2d(planes * 4)
76 | layers['residual'] = Residual(downsample).isolate(ns) # type: ignore
77 | layers['relu3'] = nn.ReLU(inplace=inplace)
78 |
79 | return nn.Sequential(layers)
80 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/resnet101/speed/resnet/flatten_sequential.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Iterator, Tuple
3 |
4 | from torch import nn
5 |
6 |
7 | def flatten_sequential(module: nn.Sequential) -> nn.Sequential:
8 |     """Flattens a nested sequential module."""
9 | if not isinstance(module, nn.Sequential):
10 | raise TypeError('not sequential')
11 |
12 | return nn.Sequential(OrderedDict(_flatten_sequential(module)))
13 |
14 |
15 | def _flatten_sequential(module: nn.Sequential) -> Iterator[Tuple[str, nn.Module]]:
16 | for name, child in module.named_children():
17 | # flatten_sequential child sequential layers only.
18 | if isinstance(child, nn.Sequential):
19 | for sub_name, sub_child in _flatten_sequential(child):
20 | yield (f'{name}_{sub_name}', sub_child)
21 | else:
22 | yield (name, child)
23 |
--------------------------------------------------------------------------------
/benchmark/torch/pp/resnet101/speed/vanila_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | import random
17 | import time
18 |
19 | import numpy as np
20 |
21 | import torch
22 |
23 | from torchvision import datasets, transforms
24 | from torchvision.models import resnet101
25 | from torch.profiler import profile, record_function, ProfilerActivity
26 |
27 | from tqdm import tqdm
28 |
29 |
30 | def seed(seed=42):
31 | # Set seed for PyTorch
32 | torch.manual_seed(seed)
33 | # torch.use_deterministic_algorithms(True)
34 | if torch.cuda.is_available():
35 | torch.cuda.manual_seed(seed)
36 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
37 | # Set seed for numpy
38 | np.random.seed(seed)
39 | # Set seed for built-in Python
40 | random.seed(seed)
41 |     # Make cuDNN deterministic for reproducibility
42 | torch.backends.cudnn.deterministic = True
43 | torch.backends.cudnn.benchmark = False
44 |
45 |
46 | criterion = torch.nn.CrossEntropyLoss()
47 |
48 |
49 | def train_step(input, label, model, opt):
50 | out = model(input)
51 | loss = criterion(out, label)
52 | loss.backward()
53 | opt.step()
54 | opt.zero_grad()
55 | return out, loss
56 |
57 |
58 | def test_main():
59 | seed(42)
60 |
61 | device = torch.device('cuda')
62 |
63 | module = resnet101().train().to(device)
64 | module.fc = torch.nn.Linear(2048, 10).to(device)
65 |
66 | opt = torch.optim.Adam(module.parameters(), foreach=True, capturable=True)
67 |
68 | dataset_size = 10000
69 | batch_size = 128
70 | train_dataloader = [(torch.randn(batch_size, 3, 224, 224), torch.randint(
71 | 0, 10, (batch_size, )))] * (dataset_size // batch_size)
72 |
73 | x_batch, y_batch = next(iter(train_dataloader))
74 | train_step(x_batch.to(device), y_batch.to(device), module, opt)
75 | epochs = 1
76 | for epoch in range(epochs):
77 | all_cnt, correct_cnt, loss_sum = 0, 0, 0
78 | time_start = time.time()
79 | for x_batch, y_batch in tqdm(train_dataloader, dynamic_ncols=True):
80 | x_batch = x_batch.to(device)
81 | y_batch = y_batch.to(device)
82 | if x_batch.size(0) != batch_size: # TODO need to solve this
83 | continue
84 | out, loss = train_step(x_batch, y_batch, module, opt)
85 | all_cnt += len(out)
86 | preds = out.argmax(-1)
87 | correct_cnt += (preds == y_batch).sum()
88 | loss_sum += loss.mean().item()
89 | print(
90 | f'epoch {epoch} train accuracy: {correct_cnt / all_cnt}, loss sum {loss_sum}, avg loss: {loss_sum / all_cnt} '
91 | f'time: {time.time() - time_start}'
92 | f'max memory: {torch.cuda.max_memory_allocated() / 1024 / 1024}mb')
93 |
94 |
95 | if __name__ == '__main__':
96 | test_main()
97 |
--------------------------------------------------------------------------------
/easydist/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import logging
16 |
17 | from . import platform
18 | import easydist.config as mdconfig
19 |
20 |
21 | def easydist_setup(backend, device="cpu", allow_tf32=True):
22 | mdconfig.easydist_device = device
23 |
24 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
25 | datefmt='%m/%d %H:%M:%S',
26 | level=mdconfig.log_level)
27 |
28 | if backend == "jax":
29 | from .jax import easydist_setup_jax
30 | easydist_setup_jax(device, allow_tf32)
31 |
32 | logging.getLogger("jax._src").setLevel(logging.INFO)
33 | elif backend == "torch":
34 | from .torch import easydist_setup_torch
35 | easydist_setup_torch(device, allow_tf32)
36 |
37 | logging.getLogger("torch._subclasses.fake_tensor").setLevel(logging.INFO)
38 |
39 | platform.init_backend(backend)
40 |
--------------------------------------------------------------------------------
/easydist/autoflow/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from .solver import AutoFlowSolver1D
16 |
17 | __all__ = ["AutoFlowSolver1D"]
18 |
--------------------------------------------------------------------------------
/easydist/jax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | import logging
17 |
18 | import jax
19 | from mpi4py import MPI
20 |
21 | from .api import easydist_shard, get_opt_strategy, set_device_mesh
22 | from .sharding_interpreter import EDJaxShardingAnn
23 | from .bridge import jax2md_bridge
24 |
25 | __all__ = [
26 | "easydist_shard", "get_opt_strategy", "set_device_mesh", "EDJaxShardingAnn", "jax2md_bridge"
27 | ]
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | def is_jax_distributed_initialized():
33 | return jax._src.distributed.global_state.client is not None
34 |
35 |
36 | def easydist_setup_jax(device, allow_tf32):
37 | os.environ["NVIDIA_TF32_OVERRIDE"] = "1" if allow_tf32 else "0"
38 | jax.config.update('jax_platforms', device)
39 |
40 | # setup distributed
41 | comm = MPI.COMM_WORLD
42 | size, rank = comm.Get_size(), comm.Get_rank()
43 |
44 | if not is_jax_distributed_initialized():
45 |
46 | jax.distributed.initialize(coordinator_address="localhost:19705",
47 | num_processes=size,
48 | process_id=rank,
49 | local_device_ids=rank)
50 |
51 | logging.info(
52 | f"[Rank {rank}], Global Devices: {jax.device_count()}, Local Devices: {jax.local_device_count()}"
53 | )
54 |
--------------------------------------------------------------------------------
/easydist/jax/device_mesh.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import functools
16 | import operator
17 |
18 | from easydist.metashard import metair
19 |
20 | JAX_DEVICE_MESH = None
21 |
22 | def set_device_mesh(device_mesh):
23 | global JAX_DEVICE_MESH
24 | JAX_DEVICE_MESH = device_mesh
25 |
26 | mesh_shape = device_mesh.device_ids.shape
27 |
28 | if len(mesh_shape) > 1:
29 | raise ValueError("Only support 1D mesh now")
30 |
31 |
32 | def get_device_mesh():
33 | global JAX_DEVICE_MESH
34 | return JAX_DEVICE_MESH
35 |
36 |
37 | def device_mesh_world_size(device_mesh=None):
38 | if device_mesh is None:
39 | device_mesh = get_device_mesh()
40 |
41 | if device_mesh is None:
42 | return None
43 |
44 | device_mesh_shape = device_mesh.device_ids.shape
45 |
46 | return functools.reduce(operator.mul, device_mesh_shape)
47 |
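
A minimal usage sketch for the helpers above (illustrative only; the SimpleNamespace mesh is a hypothetical stand-in for a real 1-D JAX device mesh exposing `device_ids`):

    import numpy as np
    from types import SimpleNamespace

    from easydist.jax.device_mesh import (
        set_device_mesh, get_device_mesh, device_mesh_world_size)

    mesh = SimpleNamespace(device_ids=np.arange(4))  # hypothetical 1-D mesh of 4 devices
    set_device_mesh(mesh)                            # meshes with more than one dim raise ValueError
    assert get_device_mesh() is mesh
    assert device_mesh_world_size() == 4             # product of the mesh shape, here (4,)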
--------------------------------------------------------------------------------
/easydist/jax/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | from contextlib import contextmanager
17 |
18 |
19 | @contextmanager
20 | def _sharding_ann_env():
21 |
22 | ori_tf32_override = os.environ.get("NVIDIA_TF32_OVERRIDE", None)
23 |
24 | os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
25 |
26 | try:
27 | yield
28 | finally:
29 |         if ori_tf32_override is None:
30 |             os.environ.pop("NVIDIA_TF32_OVERRIDE", None)
31 |         else:
32 |             os.environ["NVIDIA_TF32_OVERRIDE"] = ori_tf32_override
--------------------------------------------------------------------------------
/easydist/metashard/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from .metaop import MetaOp
16 | from .view_propagation import view_propagation
17 | from .annotation import ShardDim, ShardAnnotation
18 |
19 | __all__ = ["MetaOp", "view_propagation", "ShardDim", "ShardAnnotation"]
20 |
--------------------------------------------------------------------------------
/easydist/metashard/halo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from typing import List
16 |
17 | from easydist import platform
18 |
19 |
20 | class HaloInfo:
21 |
22 | def __init__(self, halowidth: int, dim: int) -> None:
23 | self.halowidth = halowidth
24 | self.dim = dim
25 |
26 | def __str__(self) -> str:
27 | return self.halowidth.__str__()
28 |
29 | def __repr__(self) -> str:
30 | return self.__str__()
31 |
32 |
33 | def halo_padding(tensor_list_: List[platform.Tensor], haloinfo: HaloInfo) -> List[platform.Tensor]:
34 | """add halo padding to tensor_list_"""
35 |
36 | if haloinfo is None or len(tensor_list_) < 2:
37 | return tensor_list_
38 |
39 | halo = haloinfo.halowidth
40 | dim = haloinfo.dim
41 |
42 | padded_tensor_list = []
43 | for idx in range(len(tensor_list_)):
44 | to_concatenate = [tensor_list_[idx]]
45 | if idx >= 1:
46 | dim_size = tensor_list_[idx - 1].shape[dim]
47 | if dim_size < halo:
48 |                 raise RuntimeError("Cannot apply halo padding: shard size along the padding dim is smaller than the halo width")
49 | to_concatenate.insert(
50 | 0, platform.narrow(tensor_list_[idx - 1], dim, dim_size - halo, halo))
51 | if idx <= len(tensor_list_) - 2:
52 | to_concatenate.append(platform.narrow(tensor_list_[idx + 1], dim, 0, halo))
53 | padded_tensor_list.append(platform.concatenate(to_concatenate, dim=dim))
54 |
55 | return padded_tensor_list
56 |
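
A minimal usage sketch for halo_padding (illustrative; assumes the torch backend has been bound via easydist.platform.init_backend, and the shapes are arbitrary):

    import torch
    from easydist import platform
    from easydist.metashard.halo import HaloInfo, halo_padding

    platform.init_backend("torch")

    # Two shards of a (4, 3) tensor, split along dim 0.
    shards = list(torch.arange(12.).reshape(4, 3).chunk(2, dim=0))

    # Each shard borrows one row from its neighbour along dim 0.
    padded = halo_padding(shards, HaloInfo(halowidth=1, dim=0))
    assert [p.shape[0] for p in padded] == [3, 3]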
--------------------------------------------------------------------------------
/easydist/platform/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | import logging
17 | import importlib
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | EASYDIST_BACKEND = None
22 |
23 | __all__ = [
24 | "add", "equal", "zeros_like", "min", "max", "allclose", "concatenate", "chunk", "narrow",
25 | "Tensor", "tree_flatten", "tree_unflatten", "clone", "from_numpy"
26 | ]
27 |
28 |
29 | def backend_valid(_backend):
30 | return _backend in {"torch", "jax", "tvm"}
31 |
32 |
33 | def init_backend(backend="torch"):
34 | assert backend_valid(backend)
35 | global EASYDIST_BACKEND
36 | EASYDIST_BACKEND = backend
37 | modules = importlib.import_module("." + backend, __name__)
38 | for val in __all__:
39 |         globals()[val] = getattr(modules, val)
40 | logger.info(f"========= EasyDist init with backend {backend}. =========")
41 |
42 |
43 | def get_backend():
44 | global EASYDIST_BACKEND
45 | return EASYDIST_BACKEND
46 |
47 |
48 | for val in __all__:
49 |     globals()[val] = None
50 |
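
A minimal usage sketch (illustrative): the symbols in __all__ stay None until init_backend binds them to a concrete backend, e.g. torch:

    import torch
    from easydist import platform

    platform.init_backend("torch")        # binds add/equal/narrow/... to easydist.platform.torch
    a = torch.ones(2, 3)
    b = platform.add(a, a)                # torch.Tensor.add underneath
    assert platform.equal(b, 2 * a)
    assert platform.get_backend() == "torch"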
--------------------------------------------------------------------------------
/easydist/platform/jax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from functools import partial
16 |
17 | import jax
18 |
19 | add = jax.numpy.add
20 | equal = jax.numpy.array_equal
21 | zeros_like = jax.numpy.zeros_like
22 | min = jax.numpy.minimum
23 | max = jax.numpy.maximum
24 | allclose = partial(jax.numpy.allclose, rtol=5e-3, atol=5e-03)
25 |
26 |
27 | def concatenate(tensors, dim=0):
28 | return jax.numpy.concatenate(tensors, axis=dim)
29 |
30 |
31 | def chunk(input, chunks, dim=0):
32 | return jax.numpy.array_split(input, chunks, axis=dim)
33 |
34 |
35 | def narrow(input, dim, start, length):
36 | indices = jax.numpy.asarray(range(start, start + length))
37 | return jax.numpy.take(input, indices, axis=dim)
38 |
39 |
40 | Tensor = jax.Array
41 |
42 | tree_flatten = jax.tree_util.tree_flatten
43 |
44 |
45 | def tree_unflatten(values, spec):
46 | return jax.tree_util.tree_unflatten(spec, values)
47 |
48 |
49 | clone = jax.numpy.copy
50 |
51 | from_numpy = jax.numpy.array
52 |
--------------------------------------------------------------------------------
/easydist/platform/torch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from functools import partial
16 |
17 | import torch
18 | import torch.utils._pytree as pytree
19 |
20 | add = torch.Tensor.add
21 | equal = torch.equal
22 | zeros_like = torch.zeros_like
23 | min = torch.min
24 | max = torch.max
25 | allclose = partial(torch.allclose, rtol=1e-3, atol=1e-07)
26 | concatenate = torch.concatenate
27 | chunk = torch.chunk
28 | narrow = torch.narrow
29 |
30 | Tensor = torch.Tensor
31 |
32 | tree_flatten = pytree.tree_flatten
33 | tree_unflatten = pytree.tree_unflatten
34 |
35 | clone = torch.clone
36 | from_numpy = torch.from_numpy
37 |
--------------------------------------------------------------------------------
/easydist/platform/tvm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import tvm
16 | import numpy
17 | import torch.utils._pytree as pytree
18 |
19 |
20 | def add(a, b):
21 | return tvm.nd.array(numpy.add(a.numpy(), b.numpy()))
22 |
23 |
24 | def equal(a, b):
25 | return numpy.array_equal(a.numpy(), b.numpy())
26 |
27 |
28 | def zeros_like(input_):
29 | return tvm.nd.array(numpy.zeros_like(input_.numpy()))
30 |
31 |
32 | def min(a, b):
33 | return tvm.nd.array(numpy.minimum(a.numpy(), b.numpy()))
34 |
35 |
36 | def max(a, b):
37 | return tvm.nd.array(numpy.maximum(a.numpy(), b.numpy()))
38 |
39 |
40 | def allclose(a, b):
41 | return numpy.allclose(a.numpy(), b.numpy(), rtol=5e-3, atol=5e-03)
42 |
43 |
44 | def concatenate(tensors, dim=0):
45 | return tvm.nd.array(numpy.concatenate([t.numpy() for t in tensors], axis=dim))
46 |
47 |
48 | def chunk(input, chunks, dim=0):
49 | return [tvm.nd.array(i) for i in numpy.array_split(input.numpy(), chunks, axis=dim)]
50 |
51 |
52 | def narrow(input, dim, start, length):
53 | indices = numpy.asarray(range(start, start + length))
54 | return tvm.nd.array(numpy.take(input.numpy(), indices, axis=dim))
55 |
56 |
57 | Tensor = tvm.nd.NDArray
58 |
59 | tree_flatten = pytree.tree_flatten
60 | tree_unflatten = pytree.tree_unflatten
61 |
62 |
63 | def clone(input_):
64 | return tvm.nd.array(numpy.copy(input_.numpy()))
65 |
66 |
67 | from_numpy = tvm.nd.array
68 |
--------------------------------------------------------------------------------
/easydist/torch/compile.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from easydist.torch.decomp_utils import EASYDIST_DECOMP_TABLE
3 | from easydist.torch.experimental.pp.split_utils import SplitPatcher
4 | import torch.utils._pytree as pytree
5 | from torch._subclasses.fake_tensor import FakeTensor
6 | from torch.fx.experimental.proxy_tensor import make_fx
7 | from functools import partial
8 | from easydist.torch.experimental.pp.split_utils import clear_pp_compile_states, get_updated_params_states
9 | from easydist.torch.experimental.pp.utils import save_graphviz_dot
10 | from easydist.torch.init_helper import SetParaInitHelper
11 | from easydist.torch.utils import _enable_compile, _rematerialize_optimizer
12 | from easydist.torch.scope_auto.build_scope_modules import build_scope_modules
13 |
14 |
15 | import torch
16 | from torch.nn.utils import stateless
17 |
18 |
19 | from contextlib import nullcontext
20 | from typing import cast
21 |
22 | from easydist.utils import rgetattr, rsetattr
23 |
24 |
25 | def stateless_func(func, module, opt, params, buffers, named_states, args, kwargs):
26 | clear_pp_compile_states()
27 | with stateless._reparametrize_module(
28 | cast(torch.nn.Module, module), {
29 | **params,
30 | **buffers
31 | }, tie_weights=True) if module else nullcontext(), _rematerialize_optimizer(
32 | opt, named_states, params) if opt else nullcontext():
33 | ret = func(*args, **kwargs)
34 | if (tup := get_updated_params_states()) != (None, None):
35 | params, named_states = tup
36 | grads = {k: v.grad for k, v in params.items()}
37 | return params, buffers, named_states, grads, ret
38 |
39 |
40 | def ed_compile_func(func, tracing_mode, init_helper, args, kwargs, schedule_cls, module, opt):
41 | params, buffers = {}, {}
42 | if module is not None:
43 | params = dict(module.named_parameters())
44 | buffers = dict(module.named_buffers())
45 |
46 | if isinstance(init_helper, SetParaInitHelper):
47 | init_helper.module = module
48 |
49 | named_states = {}
50 |
51 | if opt is not None:
52 | # assign grad and warm up optimizer
53 | for name in dict(module.named_parameters()):
54 | with torch.no_grad():
55 | rsetattr(module, name + ".grad", torch.zeros_like(rgetattr(module, name).data))
56 | if isinstance(rgetattr(module, name).data, FakeTensor):
57 | mode = rgetattr(module, name).data.fake_mode
58 |
59 | opt.step()
60 | opt.zero_grad(True)
61 |
62 | for n, p in params.items():
63 | if p in opt.state:
64 | named_states[n] = opt.state[p] # type: ignore[index]
65 | # if step in state, reduce one for warmup step.
66 | if 'step' in named_states[n]:
67 | named_states[n]['step'] -= 1
68 |
69 | flat_named_states, _ = pytree.tree_flatten(named_states)
70 |
71 |     # fix for SGD without momentum
72 | if all(state is None for state in flat_named_states):
73 | named_states = {}
74 | flat_named_states, _ = pytree.tree_flatten(named_states)
75 |
76 | state_tensor_num = len(params) + len(buffers) + len(flat_named_states)
77 |
78 | with _enable_compile(), SplitPatcher(module, opt) if schedule_cls else nullcontext():
79 | traced_graph = make_fx(partial(stateless_func, func, module, opt),
80 | tracing_mode=tracing_mode,
81 | decomposition_table=EASYDIST_DECOMP_TABLE,
82 | _allow_non_fake_inputs=False)(params, buffers, named_states, args,
83 | kwargs)
84 |
85 | if len(list(traced_graph.named_buffers())) != 0:
86 |         warnings.warn(f"The traced graph should contain no buffers; check that the model was traced correctly. Found: {dict(traced_graph.named_buffers())}")
87 |
88 | traced_graph = build_scope_modules(traced_graph)
89 | traced_graph.graph.eliminate_dead_code()
90 | traced_graph.recompile()
91 |
92 | save_graphviz_dot(traced_graph, 'traced_graph')
93 |
94 |     return params, buffers, named_states, state_tensor_num, traced_graph
95 |
--------------------------------------------------------------------------------
/easydist/torch/cuda/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/easydist/4ced4cd7ca7a39722670beae16332426d5430238/easydist/torch/cuda/__init__.py
--------------------------------------------------------------------------------
/easydist/torch/cuda/mem_allocator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import ctypes
16 |
17 | import easydist
18 | from easydist.torch.meta_allocator import profiling_allocator
19 |
20 |
21 | class _CUDAAllocator:
22 | def __init__(self, allocator):
23 | self._allocator = allocator
24 |
25 | def allocator(self):
26 | return self._allocator
27 |
28 |
29 | class EffectiveCUDAAllocator(_CUDAAllocator):
30 | def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
31 | allocator = ctypes.CDLL(path_to_so_file)
32 | alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
33 | free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
34 | assert alloc_fn is not None
35 | assert free_fn is not None
36 | self._allocator = profiling_allocator._cuda_customEffectiveAllocator(alloc_fn, free_fn)
37 |
38 |
39 | def change_current_allocator(allocator: _CUDAAllocator) -> None:
40 | profiling_allocator._cuda_changeCurrentAllocator(allocator.allocator())
41 |
42 | def init_meta_allocator():
43 | if not easydist.config.enable_memory_opt:
44 | return
45 | swap_to_profiling_allocator()
46 |
47 | def swap_to_profiling_allocator():
48 | # swap from caching allocator to profiling allocator
49 |
50 | profiling_allocator._compile_if_needed()
51 |
52 | path_to_profiling_allocator = profiling_allocator.module.__file__
53 | raw_allocator = ctypes.CDLL(path_to_profiling_allocator)
54 | init_fn = ctypes.cast(getattr(raw_allocator, 'init_fn'), ctypes.c_void_p).value
55 | new_alloc = EffectiveCUDAAllocator(
56 | path_to_profiling_allocator, 'meta_malloc', 'meta_free')
57 | profiling_allocator._save_back_allocator()
58 | change_current_allocator(new_alloc)
59 | new_alloc.allocator().set_init_fn(init_fn)
60 |
61 |
--------------------------------------------------------------------------------
/easydist/torch/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/easydist/4ced4cd7ca7a39722670beae16332426d5430238/easydist/torch/experimental/__init__.py
--------------------------------------------------------------------------------
/easydist/torch/experimental/pp/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | # TODO @botbw: dependency and circular import issues
--------------------------------------------------------------------------------
/easydist/torch/experimental/pp/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import logging
16 | import operator
17 | from copy import deepcopy
18 | from typing import Any, Dict
19 |
20 | import torch.fx as fx
21 | from torch.fx.passes.graph_drawer import FxGraphDrawer
22 |
23 | import easydist.config as mdconfig
24 |
25 |
26 | def ordered_gi_users(node: fx.Node):
27 | assert all(user.op == 'call_function' and user.target == operator.getitem
28 | for user in node.users), "All users of the node must be getitem"
29 | ret = [None for _ in range(len(node.users))]
30 | for user in node.users:
31 | ret[user.args[1]] = user
32 | return ret
33 |
34 |
35 | def save_graphviz_dot(gm, name):
36 | with open(f"./log/{name}.dot", "w") as f:
37 | f.write(str(FxGraphDrawer(gm, name).get_dot_graph()))
38 |
39 |
40 | def _to_tuple(x):
41 | if isinstance(x, tuple):
42 | return x
43 | return (x, )
44 |
45 |
46 | class OneToOneMap:
47 | def __init__(self):
48 | self._map: Dict[Any, Any]= {}
49 | self._inv: Dict[Any, Any] = {}
50 |
51 | def get(self, key: Any) -> Any:
52 | return self._map[key]
53 |
54 | def inv_get(self, key: Any) -> Any:
55 | return self._inv[key]
56 |
57 | def add(self, key: Any, value: Any) -> Any:
58 |         if key in self._map or value in self._inv:
59 |             raise RuntimeError(f"({key}, {value}) breaks the one-to-one mapping: key exists={key in self._map}, value exists={value in self._inv}")
60 | self._map[key] = value
61 | self._inv[value] = key
62 |
63 | def items(self):
64 | return self._map.items()
65 |
66 | def keys(self):
67 | return self._map.keys()
68 |
69 | def inv_items(self):
70 | return self._inv.items()
71 |
72 | def inv_keys(self):
73 | return self._inv.keys()
74 |
75 | def apply(self, other: "OneToOneMap") -> "OneToOneMap":
76 | mapping = OneToOneMap()
77 | for k, v in self.items():
78 | mapping.add(k, other.get(v))
79 | return mapping
80 |
81 | def map_dict_key(self, dictt: Dict) -> Dict:
82 | ret = {}
83 | for k, v in dictt.items():
84 | ret[self._map[k]] = v
85 | return ret
86 |
87 | def inverse(self) -> "OneToOneMap":
88 | inversed = deepcopy(self)
89 | inversed._map, inversed._inv = inversed._inv, inversed._map
90 | return inversed
91 |
92 | def __repr__(self):
93 | return f"{self._map=}\n{self._inv=}"
94 |
95 | @staticmethod
96 | def from_dict(dict: Dict) -> "OneToOneMap":
97 | mapping = OneToOneMap()
98 | for k, v in dict.items():
99 | mapping.add(k, v)
100 | return mapping
101 |
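
A minimal usage sketch for OneToOneMap (illustrative; the key/value names are hypothetical):

    from easydist.torch.experimental.pp.utils import OneToOneMap

    m = OneToOneMap.from_dict({"w": "w_0", "b": "b_0"})
    assert m.get("w") == "w_0" and m.inv_get("b_0") == "b"

    # Compose with a second mapping: "w" -> "w_0" -> 0, "b" -> "b_0" -> 1.
    composed = m.apply(OneToOneMap.from_dict({"w_0": 0, "b_0": 1}))
    assert composed.get("w") == 0

    assert m.inverse().get("w_0") == "w"   # inverse() swaps keys and values
    # m.add("w", "other") would raise RuntimeError: "w" is already mapped.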
--------------------------------------------------------------------------------
/easydist/torch/graph_profile_db.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | import pickle
17 | import logging
18 |
19 | import easydist.config as mdconfig
20 |
21 | _logger = logging.getLogger(__name__)
22 |
23 |
24 | class PerfDB:
25 |
26 | def __init__(self) -> None:
27 | self.db = dict()
28 | if os.path.exists(mdconfig.prof_db_path):
29 | self.db = pickle.load(open(mdconfig.prof_db_path, 'rb'))
30 | _logger.info(f"Load Perf DB from {mdconfig.prof_db_path}")
31 |
32 | def get_op_perf(self, key_l1, key_l2):
33 | if key_l1 in self.db:
34 | return self.db[key_l1].get(key_l2, None)
35 | return None
36 |
37 | def record_op_perf(self, key_l1, key_l2, value):
38 | if key_l1 not in self.db:
39 | self.db[key_l1] = dict()
40 | self.db[key_l1][key_l2] = value
41 |
42 | def persistent(self):
43 |         _logger.info(f"Persisting Perf DB to {mdconfig.prof_db_path}")
44 |
45 | if not os.path.exists(mdconfig.easydist_dir):
46 | os.makedirs(mdconfig.easydist_dir, exist_ok=True)
47 |
48 | pickle.dump(self.db, open(mdconfig.prof_db_path, 'wb'))
49 |
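
A minimal usage sketch for PerfDB (illustrative; the two-level keys below are hypothetical, the class does not prescribe their format):

    from easydist.torch.graph_profile_db import PerfDB

    db = PerfDB()                                   # loads mdconfig.prof_db_path if it already exists
    db.record_op_perf("aten.mm.default", "f32[1024,1024] x f32[1024,1024]", 0.12)
    assert db.get_op_perf("aten.mm.default", "f32[1024,1024] x f32[1024,1024]") == 0.12
    assert db.get_op_perf("aten.mm.default", "unseen-signature") is None
    db.persistent()                                 # pickles the table back to mdconfig.prof_db_path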
--------------------------------------------------------------------------------
/easydist/torch/meta_allocator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | import logging
17 |
18 | import pynvml
19 |
20 | from torch.utils.cpp_extension import load, _join_cuda_home
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | def _compile():
25 |
26 |     logger.info("compiling the profiling_allocator extension via torch.utils.cpp_extension.load")
27 |
28 | device_index = os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',')[0]
29 |
30 | # (NOTE) workaround of torch.cuda.get_cuda_arch_list() which will init allocator
31 | pynvml.nvmlInit()
32 | handle = pynvml.nvmlDeviceGetHandleByIndex(int(device_index))
33 | capability = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
34 |
35 | os.environ['TORCH_CUDA_ARCH_LIST'] = f'{capability[0]}.{capability[1]}+PTX'
36 |
37 | sources_files = [
38 | 'csrc/profiling_allocator.cpp',
39 | 'csrc/stream_tracer.cpp',
40 | 'csrc/cupti_callback_api.cpp',
41 | 'csrc/python_tracer_init.cpp',
42 | 'csrc/effective_cuda_allocator.cpp'
43 | ]
44 |
45 | profiling_allocator_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "profiler")
46 |     _compiled_module = load(name="profiling_allocator",
47 |                             extra_include_paths=[_join_cuda_home('extras', 'CUPTI', 'include')],
48 |                             sources=[os.path.join(profiling_allocator_dir, f) for f in sources_files],
49 |                             extra_cflags=['-DUSE_CUDA=1', '-D_GLIBCXX_USE_CXX11_ABI=0'],
50 |                             with_cuda=True, verbose=True)
51 |
52 |     logger.info(f"[profiling_allocator] compiled in {_compiled_module.__file__}")
53 |
54 |     return _compiled_module
55 |
56 |
57 | class LazyModule():
58 | def __init__(self):
59 | self.module = None
60 |
61 | def _compile_if_needed(self):
62 | if self.module is None:
63 | self.module = _compile()
64 |
65 | def __getattr__(self, name: str):
66 | self._compile_if_needed()
67 | return self.module.__getattribute__(name)
68 |
69 | profiling_allocator = LazyModule()
70 |
--------------------------------------------------------------------------------
/easydist/torch/passes/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from .fix_embedding import fix_embedding
16 | from .fix_bias import fix_addmm_bias, fix_convoluation_bias
17 | from .fix_meta_device import fix_meta_device
18 | from .eliminate_detach import eliminate_detach
19 | from .sharding import sharding_transform, sharding_transform_dtensor
20 | from .tile_comm import tile_comm
21 | from .comm_optimize import comm_optimize
22 | from .rule_override import rule_override_by_graph
23 | from .runtime_prof import runtime_prof
24 | from .edinfo_utils import create_edinfo, annotation_edinfo
25 | from .process_tag import process_tag
26 | from .allocator_profiler import AllocatorProfiler, ModuleProfilingInfo
27 | from .fix_view import decouple_view
28 | from .pp_passes import get_partition
29 | __all__ = [
30 | "fix_embedding", "fix_addmm_bias", "fix_convoluation_bias", "eliminate_detach",
31 | "sharding_transform", "sharding_transform_dtensor", "fix_meta_device", "tile_comm",
32 | "comm_optimize", "rule_override_by_graph", "runtime_prof", "create_edinfo",
33 | "AllocatorProfiler", "ModuleProfilingInfo", "annotation_edinfo", "process_tag",
34 | "decouple_view", "get_partition"
35 | ]
36 |
--------------------------------------------------------------------------------
/easydist/torch/passes/edinfo_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import operator
16 |
17 | import torch
18 | from torch.fx.node import Node, _get_qualified_name
19 | import torch.utils._pytree as pytree
20 |
21 | from easydist.torch.utils import EDInfo, EDNodeType, create_meta_from_node
22 | from easydist.torch.passes.sharding import CREATE_ATEN_OP, COMM_FUNCS
23 |
24 |
25 | def create_edinfo(fx_module: torch.fx.GraphModule, sharding_info, shape_info) -> torch.fx.GraphModule:
26 |
27 | for node in fx_module.graph.nodes:
28 |
29 | if node.op == 'call_function' and 'val' not in node.meta:
30 | node.meta = create_meta_from_node(node)
31 |
32 |
33 | if not hasattr(node, "ed_info"):
34 | node.ed_info = EDInfo(ori_meta=node.meta)
35 |
36 | if node.op == "call_function":
37 | op_name = _get_qualified_name(node.target)
38 |
39 | node_sharding_info = None
40 | if op_name in sharding_info:
41 |
42 | def _gen_meta(arg: Node):
43 | return torch.empty(shape_info[arg.name]["shape"],
44 | dtype=shape_info[arg.name]["dtype"],
45 | device="meta")
46 |
47 | args_meta = pytree.tree_map_only(Node, _gen_meta, node.args)
48 | args_meta = str(tuple(args_meta)) + ' | ' + str(node.kwargs)
49 | if args_meta in sharding_info[op_name]:
50 | node_sharding_info = sharding_info[op_name][args_meta]
51 |
52 | node.ed_info.spmd_annotation = node_sharding_info
53 |
54 | elif node.op in ["placeholder", "get_attr"]:
55 | if hasattr(node, "meta") and 'val' in node.meta:
56 | node_sharding_info = None
57 | if node.op in sharding_info:
58 | arg_meta_tensor = torch.empty(shape_info[node.name]["shape"],
59 | dtype=shape_info[node.name]["dtype"],
60 | device="meta")
61 | args_meta = str(arg_meta_tensor)
62 | if args_meta in sharding_info[node.op]:
63 | node_sharding_info = sharding_info[node.op][args_meta]
64 |
65 | node.ed_info.spmd_annotation = node_sharding_info
66 |
67 | return fx_module
68 |
69 |
70 | def annotation_edinfo(traced_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
71 | for node in traced_graph.graph.nodes:
72 | if not hasattr(node, "ed_info"):
73 | node.ed_info = EDInfo(ori_meta=node.meta)
74 |
75 | if node.op == 'placeholder':
76 | node.ed_info.node_type = EDNodeType.AUXILIARY
77 | elif node.op == 'call_function':
78 | # create meta for custom function
79 | if node.target not in CREATE_ATEN_OP:
80 | node.meta = create_meta_from_node(node)
81 | # annotate node type
82 | if node.target in COMM_FUNCS:
83 | node.ed_info.node_type = EDNodeType.COMMUNICATION
84 |             # TODO: hard-coded to avoid runtime-profiling torch.ops.aten._fused_adam.default
85 | elif node.target in [operator.getitem, torch.ops.aten._fused_adam.default]:
86 | node.ed_info.node_type = EDNodeType.AUXILIARY
87 | else:
88 | node.ed_info.node_type = EDNodeType.COMPUTATION
89 | elif node.op == 'output':
90 | node.ed_info.node_type = EDNodeType.AUXILIARY
91 |
92 | return traced_graph
--------------------------------------------------------------------------------
/easydist/torch/passes/eliminate_detach.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import torch.fx as fx
16 | from torch.fx.node import _get_qualified_name
17 |
18 |
19 | def eliminate_detach(fx_graph: fx.GraphModule):
20 | for node in fx_graph.graph.nodes:
21 | if node.op == 'call_function':
22 | if _get_qualified_name(node.target) == 'torch.ops.aten.detach.default':
23 | node.replace_all_uses_with(node.args[0])
24 |
25 | fx_graph.graph.eliminate_dead_code()
26 |
27 | return fx_graph
28 |
--------------------------------------------------------------------------------
/easydist/torch/passes/fix_bias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import copy
16 |
17 | import torch
18 | from torch.fx.node import _get_qualified_name
19 |
20 | from easydist.torch.utils import create_meta_from_node
21 |
22 | def fix_addmm_bias(fx_module: torch.fx.GraphModule):
23 |
24 | for node in fx_module.graph.nodes:
25 | if node.op == 'call_function':
26 | if "torch.ops.aten.addmm.default" in _get_qualified_name(node.target):
27 | node.target = torch.ops.aten.mm.default
28 | bias = node.args[0]
29 | node.args = (node.args[1], node.args[2])
30 |
31 | with fx_module.graph.inserting_after(node):
32 | add_bias_node = fx_module.graph.call_function(torch.ops.aten.add.Tensor,
33 | args=(node, bias))
34 |
35 | node.replace_all_uses_with(add_bias_node)
36 |
37 | add_bias_node.update_arg(0, node)
38 |
39 | add_bias_node.meta = create_meta_from_node(add_bias_node)
40 |
41 | fx_module.recompile()
42 |
43 | return fx_module
44 |
45 |
46 | def fix_convoluation_bias(fx_module: torch.fx.GraphModule):
47 |
48 | for node in fx_module.graph.nodes:
49 | if node.op == 'call_function':
50 | if "torch.ops.aten.convolution.default" in _get_qualified_name(node.target):
51 | if node.args[2] is not None:
52 | node.target = torch.ops.aten.convolution.default
53 | bias = node.args[2]
54 | node.args = (node.args[0], node.args[1], None, *node.args[3:])
55 |
56 | with fx_module.graph.inserting_after(node):
57 | bias_new = fx_module.graph.call_function(torch.ops.aten.view.default,
58 | args=(bias, [1, -1, 1, 1]))
59 |
60 | with fx_module.graph.inserting_after(bias_new):
61 | add_bias_node = fx_module.graph.call_function(torch.ops.aten.add.Tensor,
62 | args=(node, bias_new))
63 |
64 | node.replace_all_uses_with(add_bias_node)
65 |
66 | add_bias_node.update_arg(0, node)
67 |
68 | bias_new.meta = create_meta_from_node(bias_new)
69 | add_bias_node.meta = create_meta_from_node(add_bias_node)
70 |
71 | fx_module.recompile()
72 |
73 | return fx_module
74 |
--------------------------------------------------------------------------------
/easydist/torch/passes/fix_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import torch
16 | from torch.fx.node import _get_qualified_name
17 |
18 |
19 | def md_embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
20 | if int(torch.max(indices).item()) >= weight.shape[0]:
21 |         raise RuntimeError("embedding indices out of range")
22 | return torch.ops.aten.embedding.default(weight, indices, padding_idx, scale_grad_by_freq,
23 | sparse)
24 |
25 |
26 | def fix_embedding(fx_module: torch.fx.GraphModule, recover=False):
27 |
28 | for node in fx_module.graph.nodes:
29 | if node.op == 'call_function':
30 | if "torch.ops.aten.embedding.default" in _get_qualified_name(node.target):
31 | node.target = md_embedding
32 |
33 | if recover and "md_embedding" in _get_qualified_name(node.target):
34 | node.target = torch.ops.aten.embedding.default
35 |
36 | fx_module.recompile()
37 |
38 | return fx_module
39 |
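
A minimal usage sketch for md_embedding (illustrative; the weight and index values are arbitrary):

    import torch
    from easydist.torch.passes.fix_embedding import md_embedding

    weight = torch.randn(10, 4)
    indices = torch.tensor([0, 3, 9])
    out = md_embedding(weight, indices)   # same as aten.embedding, plus an explicit bounds check
    assert out.shape == (3, 4)
    # md_embedding(weight, torch.tensor([10])) would raise RuntimeError (indices out of range).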
--------------------------------------------------------------------------------
/easydist/torch/passes/fix_meta_device.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import os
16 | import copy
17 |
18 | import torch
19 |
20 | import easydist.config as mdconfig
21 |
22 |
23 | def fix_meta_device(fx_module: torch.fx.GraphModule):
24 |
25 | for node in fx_module.graph.nodes:
26 | if node.op == 'call_function':
27 | if "device" in node.kwargs:
28 | new_kwargs = dict(copy.deepcopy(node.kwargs))
29 | device = mdconfig.easydist_device
30 | new_kwargs["device"] = torch.device(device=device)
31 | assert isinstance(new_kwargs, dict)
32 | node.kwargs = new_kwargs
33 |
34 | fx_module.recompile()
35 |
36 | return fx_module
37 |
--------------------------------------------------------------------------------
/easydist/torch/passes/fix_view.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import functools
16 | import operator
17 | import torch
18 | import torch.fx as fx
19 |
20 | from easydist.torch.preset_propagation import VIEW_OPS
21 |
22 |
23 | def _fix_view_node(input_shape, output_shape):
24 | if -1 in output_shape:
25 | numel = functools.reduce(operator.mul, input_shape)
26 | dim_size = -1 * numel // functools.reduce(operator.mul, output_shape)
27 | output_shape[output_shape.index(-1)] = dim_size
28 |
29 | intermediate_shape = []
30 | i = j = 0
31 | while i < len(input_shape) and j < len(output_shape):
32 | accu_i = input_shape[i]
33 | accu_j = output_shape[j]
34 | while accu_i != accu_j:
35 | if accu_i < accu_j:
36 | i += 1
37 | accu_i *= input_shape[i]
38 | else:
39 | j += 1
40 | accu_j *= output_shape[j]
41 | intermediate_shape.append(accu_i)
42 | i += 1
43 | j += 1
44 | while i < len(input_shape):
45 | intermediate_shape.append(input_shape[i])
46 | i += 1
47 | while j < len(output_shape):
48 | intermediate_shape.append(output_shape[j])
49 | j += 1
50 | assert i == len(input_shape) and j == len(output_shape)
51 | return intermediate_shape
52 |
53 |
54 | def decouple_view(fx_module: fx.GraphModule):
55 | for node in fx_module.graph.nodes:
56 | if node.op == 'call_function' and node.target in VIEW_OPS:
57 | target_op = node.target
58 | input_shape = list(node.args[0].meta['val'].shape)
59 | output_shape = list(node.args[1])
60 | intermediate_shape = _fix_view_node(input_shape, output_shape)
61 | if input_shape != intermediate_shape and output_shape != intermediate_shape:
62 | node.args = (node.args[0], intermediate_shape)
63 | fake_mode = node.meta['val'].fake_mode
64 | node.meta['val'] = fake_mode.from_tensor(
65 | torch.zeros(intermediate_shape,
66 | dtype=node.meta['val'].dtype))
67 | with fx_module.graph.inserting_after(node):
68 | intermediate_view = fx_module.graph.call_function(
69 | target_op, args=(node, output_shape))
70 | intermediate_view.meta['val'] = fake_mode.from_tensor(
71 | torch.zeros(output_shape,
72 | dtype=node.meta['val'].dtype))
73 | node.replace_all_uses_with(intermediate_view, delete_user_cb=lambda x: x != intermediate_view)
74 |
75 | fx_module.recompile()
76 | return fx_module
77 |
78 | # def fix_sharded_view(fx_module: fx.GraphModule):
79 | # for node in fx_module.graph.nodes:
80 | # if (node.op == 'call_function' and node.target in VIEW_OPS and
81 | # node.args[0].op == 'call_function' and node.args[0].target == scatter_wrapper):
82 | # input_val = node.args[0].meta['val']
83 | # output_shape = list(node.args[1])
84 | # try:
85 | # _ = input_val.view(output_shape)
86 | # except Exception:
87 |
88 | # # global [768 (shard), 768] => view as [3, 256, 768]
89 | # # 0 [384, 768] => view as [1, 256, 768]
90 | # # 1 [384, 768] => view as [2, 256, 768]
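
A worked example of the shape arithmetic in _fix_view_node above (illustrative):

    from easydist.torch.passes.fix_view import _fix_view_node

    # [4, 6] -> [2, 12] is not a per-dimension split/merge, so the view is routed
    # through the fully merged shape [24]; decouple_view then emits two view nodes.
    assert _fix_view_node([4, 6], [2, 12]) == [24]

    # [2, 3, 4] -> [6, 4] only merges leading dims: the intermediate equals the
    # output shape, so decouple_view leaves the original view node untouched.
    assert _fix_view_node([2, 3, 4], [6, 4]) == [6, 4]

    # A -1 in the target shape is resolved from the element count first.
    assert _fix_view_node([2, 3, 4], [6, -1]) == [6, 4]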
--------------------------------------------------------------------------------
/easydist/torch/passes/pp_passes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from typing import List, Set
16 |
17 | import torch
18 | import torch.fx as fx
19 |
20 |
21 | def get_partition(fx_module: fx.GraphModule) -> List[Set[str]]:
22 | partitions = []
23 | cur_par = set()
24 | for node in fx_module.graph.nodes:
25 | cur_par.add(node.name)
26 | if node.op == 'call_function' and node.target in [
27 | torch.ops.easydist.fw_bw_split.default,
28 | torch.ops.easydist.step_split.default
29 | ]:
30 | partitions.append(cur_par)
31 | cur_par = set()
32 | partitions.append(cur_par)
33 | return partitions
34 |
--------------------------------------------------------------------------------
/easydist/torch/passes/process_tag.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | import torch
16 | import torch._custom_ops
17 |
18 | from easydist.torch.device_mesh import get_device_mesh
19 | from easydist.torch.passes.sharding import all_reduce_start, all_reduce_end
20 |
21 |
22 | @torch._custom_ops.custom_op("easydist::tag")
23 | def tag(input: torch.Tensor, tag: str) -> torch.Tensor:
24 | ...
25 |
26 | @torch._custom_ops.impl_abstract("easydist::tag")
27 | def tag_impl_abstract(input: torch.Tensor, tag: str) -> torch.Tensor:
28 | return torch.empty_like(input)
29 |
30 |
31 | @torch._custom_ops.impl("easydist::tag")
32 | def tag_impl(input: torch.Tensor, tag: str) -> torch.Tensor:
33 | return input
34 |
35 |
36 | def process_tag(traced_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
37 |
38 | tp_mesh = get_device_mesh('tp')
39 |
40 | for node in traced_graph.graph.nodes:
41 | if node.target == torch.ops.easydist.tag.default:
42 | if node.args[1] == "allreduce[sum]":
43 | reduceOp = "sum"
44 | ranks = tp_mesh.mesh.flatten().tolist()
45 | with traced_graph.graph.inserting_before(node):
46 | all_reduce_start_node = traced_graph.graph.call_function(all_reduce_start,
47 | args=(node.args[0],
48 | reduceOp,
49 | ranks))
50 | all_reduce_end_node = traced_graph.graph.call_function(
51 | all_reduce_end, args=(all_reduce_start_node, reduceOp, ranks))
52 |
53 | node.replace_all_uses_with(all_reduce_end_node)
54 |
55 | traced_graph.graph.eliminate_dead_code()
56 | traced_graph.recompile()
57 |
58 | return traced_graph
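
A minimal usage sketch (illustrative; exact eager behaviour depends on the custom-op registration above): importing this module registers easydist::tag, whose eager implementation is an identity, and process_tag later rewrites tagged values into all-reduce ops on the 'tp' mesh.

    import torch
    import easydist.torch.passes.process_tag  # noqa: F401  (registers the easydist::tag op)

    x = torch.randn(4, 4)
    y = torch.ops.easydist.tag(x, "allreduce[sum]")   # pass-through in eager mode
    assert torch.equal(x, y)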
--------------------------------------------------------------------------------
/easydist/torch/profiler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Alibaba Group;
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 |
15 | from .stream_tracer import (
16 | StreamTracer,
17 | )
18 |
19 | __all__ = [
20 | "StreamTracer",
21 | ]
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/easydist/torch/profiler/csrc/cupti_callback_api.h:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2024, Alibaba Group;
2 | Licensed under the Apache License, Version 2.0 (the "License");
3 | you may not use this file except in compliance with the License.
4 | You may obtain a copy of the License at
5 |
6 | http://www.apache.org/licenses/LICENSE-2.0
7 |
8 | Unless required by applicable law or agreed to in writing, software
9 | distributed under the License is distributed on an "AS IS" BASIS,
10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | See the License for the specific language governing permissions and
12 | limitations under the License.
13 | ==============================================================================*/
14 |
15 | #include