├── .flake8
├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── pyproject.toml
├── symm_mem_all_reduce.py
├── torch_compile_all_gather_matmul.py
├── triton_all_gather_matmul.py
├── triton_barrier.py
├── triton_multimem_all_reduce.py
├── triton_one_shot_all_reduce.py
├── triton_utils.py
└── utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | ignore = E203, E501, E731, W503
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 24.8.0
4 | hooks:
5 | - id: black
6 |
7 | - repo: https://github.com/pycqa/isort
8 | rev: 5.13.2
9 | hooks:
10 | - id: isort
11 |
12 | - repo: https://github.com/pycqa/flake8
13 | rev: 7.0.0
14 | hooks:
15 | - id: flake8
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # symm-mem-recipes
2 |
3 | This repository includes:
4 | - Usage and benchmarks of `SymmetricMemory`-based multi-GPU algorithms in PyTorch.
5 | - Examples and benchmarks of multi-GPU algorithms built with `SymmetricMemory` + Triton.
6 |
7 | ---
8 | ## symm_mem_all_reduce.py
9 |
10 | This script demonstrates the usage of `SymmetricMemory`-based NVLink all-reduce implementations and benchmarks their performance. The available variants are:
11 | - `multimem_all_reduce` (PyTorch op available in nightly)
12 | - `one_shot_all_reduce` (PyTorch op available in nightly)
13 | - `two_shot_all_reduce` (PyTorch op available in nightly)
14 | - `triton_multimem_all_reduce` (Triton kernel defined in this repo)
15 | - `triton_one_shot_all_reduce` (Triton kernel defined in this repo)
16 |
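Each variant operates on a tensor allocated with `symm_mem.empty` and rendezvoused over the process group before the collective is invoked. A minimal calling sketch (launched with `torchrun`, as in the command below; the message size is illustrative):

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)

# Symmetric-memory tensors must be allocated via symm_mem.empty and
# rendezvoused across the group before any variant can use them.
msg = symm_mem.empty(4096, dtype=torch.bfloat16, device=device).fill_(1)
symm_mem.rendezvous(msg, dist.group.WORLD.group_name)

# In-place multimem all-reduce over NVLink; NVSwitch is required.
torch.ops.symm_mem.multimem_all_reduce_(msg, "sum", dist.group.WORLD.group_name)

dist.destroy_process_group()
```

The Triton variants take the same rendezvoused tensor: `triton_multimem_all_reduce.multimem_all_reduce(msg)` reduces in place, while `triton_one_shot_all_reduce.one_shot_all_reduce(msg)` returns a new output tensor.
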
17 | Usage:
18 | ```bash
19 | torchrun \
20 | --nnodes 1 --nproc-per-node 8 \
21 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
22 | --no_python python3 symm_mem_all_reduce.py --impl multimem_all_reduce
23 | ```
24 |
25 | Some benchmarks on 8xH100 with NVSwitch:
26 |
36 | ---
37 | ## triton_all_gather_matmul.py
38 |
39 | This is a fused all-gather matmul example using Triton + `SymmetricMemory`, based on the `tma_persistent` Triton tutorial with slight modifications.
40 |
41 | This example requires PyTorch Nightly and Triton 3.0.0+ to run.
42 |
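Internally, a separate backend stream copies each remote shard into the gathered `A` buffer and sets a per-chunk progress flag (`all_gather_with_progress`), while the persistent TMA matmul kernel waits on the flag for a communication block before consuming its rows; blocks that map to the local rank read the local shard directly. A minimal sketch of invoking the fused op programmatically (shapes and config values are illustrative; run it under `torchrun` as shown below):

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from triton_all_gather_matmul import all_gather_matmul_tma_persistent

dist.init_process_group("nccl")
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
world_size = dist.get_world_size()

M, N, K = 16384, 6656, 16384
# A is sharded along M; the shard must live in symmetric memory
# (the fused op performs the rendezvous itself).
a_shard = symm_mem.empty(
    M // world_size, K, dtype=torch.bfloat16, device=device
).normal_()
b = torch.randn((K, N), dtype=torch.bfloat16, device=device).T.contiguous()  # (N, K)
a_out = torch.empty((M, K), dtype=torch.bfloat16, device=device)  # gathered A
c_out = torch.empty((M, N), dtype=torch.bfloat16, device=device)  # A @ B.T

configs = {
    "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64,
    "GROUP_SIZE_M": 4, "num_stages": 3, "num_warps": 8,
}
c_out = all_gather_matmul_tma_persistent(a_shard, b, a_out, c_out, configs)

dist.destroy_process_group()
```
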
43 | Usage:
44 | ```bash
45 | torchrun \
46 | --nnodes 1 --nproc-per-node 8 \
47 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
48 | --no_python python3 triton_all_gather_matmul.py \
49 | --M 16384 --N 6656 --K 16384 --BLOCK_SIZE_M 128 --BLOCK_SIZE_N 256 --BLOCK_SIZE_K 64
50 | ```
51 |
52 | Some benchmarks on 8xH100 (special version with HBM2e, at 650W) with NVSwitch:
53 |
54 | #### Llama 3 8B (N=1792, K=4096)
55 | | Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
56 | |------------|------------|------------|------------|------------|------------|------------|
57 | | 4096 | 64,128,128,4 | 100 | 142 | 223 | 211 | 1.05x² |
58 | | 8192 | 128,128,64,6 | 186 | 198 | 393 | 293 | 1.34x |
59 | | 16384 | 128,256,64,3 | 363 | 363 | 748 | 485 | 1.54x |
60 |
61 | #### Llama 3 70B (N=3584, K=8192)
62 | | Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
63 | |------------|------------|------------|------------|------------|------------|------------|
64 | | 4096 | 128,128,64,6 | 376 | 392 | 587 | 453 | 1.29x |
65 | | 8192 | 128,256,64,3 | 746 | 706 | 1168 | 821 | 1.42x |
66 | | 16384 | 128,256,64,3 | 1502 | 1403 | 2306 | 1566 | 1.47x |
67 |
68 | #### Llama 3.1 405B (N=6656, K=16384)
69 | | Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
70 | |------------|------------|------------|------------|------------|------------|------------|
71 | | 4096 | 128,256,64,3 | 1358 | 1425 | 1858 | 1615 | 1.15x |
72 | | 8192 | 128,256,64,3 | 2567 | 2656 | 3533 | 2907 | 1.22x |
73 | | 16384 | 128,256,64,3 | 5249 | 5375 | 6982 | 5814 | 1.20x |
74 |
75 | ¹ Config refers to `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, and `num_stages`.
76 |
77 | ² For this problem size, using multicast all-gather would be a more suitable optimization.
78 |
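---
## torch_compile_all_gather_matmul.py

This script benchmarks an eager all-gather + matmul (`all_gather_tensor` followed by `torch.matmul`, or `torch._scaled_mm` for float8) against the same function compiled with `torch.compile` and Inductor's `_micro_pipeline_tp` option (with `TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP=1`), which enables the micro-pipelined ("async TP") fusion of the all-gather with the matmul. Both bfloat16 and float8 paths (tensor-wise or row-wise scaling) are supported, optionally under CUDA graphs.

Usage:
```bash
torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 torch_compile_all_gather_matmul.py --cuda-graph
```
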
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 88
3 | target-version = ['py38']
4 |
5 | [tool.isort]
6 | profile = "black"
7 | atomic = true
8 | combine_as_imports = true
9 | include_trailing_comma = true
10 | indent = 4
11 | line_length = 88
12 | lines_after_imports = 2
13 | multi_line_output = 3
14 |
--------------------------------------------------------------------------------
/symm_mem_all_reduce.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 |
4 | import click
5 | import torch
6 | import torch.distributed as dist
7 | import torch.distributed._symmetric_memory as symm_mem
8 |
9 | from utils import benchmark_with_profiler
10 |
11 |
12 | def multimem_all_reduce(msg):
13 | torch.ops.symm_mem.multimem_all_reduce_(
14 | msg,
15 | "sum",
16 | dist.group.WORLD.group_name,
17 | )
18 |
19 |
20 | def one_shot_all_reduce(msg):
21 | torch.ops.symm_mem.one_shot_all_reduce(
22 | msg,
23 | "sum",
24 | dist.group.WORLD.group_name,
25 | )
26 |
27 |
28 | def two_shot_all_reduce(msg):
29 | torch.ops.symm_mem.two_shot_all_reduce_(
30 | msg,
31 | "sum",
32 | dist.group.WORLD.group_name,
33 | )
34 |
35 |
36 | def triton_multimem_all_reduce(msg):
37 | from triton_multimem_all_reduce import multimem_all_reduce
38 |
39 | multimem_all_reduce(msg)
40 |
41 |
42 | def triton_one_shot_all_reduce(msg):
43 | from triton_one_shot_all_reduce import one_shot_all_reduce
44 |
45 | one_shot_all_reduce(msg)
46 |
47 |
48 | def get_impl(impl: str):
49 | if impl == "multimem_all_reduce":
50 | return multimem_all_reduce
51 | elif impl == "one_shot_all_reduce":
52 | return one_shot_all_reduce
53 | elif impl == "two_shot_all_reduce":
54 | return two_shot_all_reduce
55 | elif impl == "triton_multimem_all_reduce":
56 | return triton_multimem_all_reduce
57 | elif impl == "triton_one_shot_all_reduce":
58 | return triton_one_shot_all_reduce
59 | else:
60 | raise NotImplementedError(impl)
61 |
62 |
63 | def benchmark(device: torch.device, impl: str, msg_sz_bytes: int):
64 | msg_numel = msg_sz_bytes // torch.bfloat16.itemsize
65 | msg = symm_mem.empty(
66 | msg_numel,
67 | dtype=torch.bfloat16,
68 | device=device,
69 | )
70 | symm_mem.rendezvous(msg, dist.group.WORLD.group_name)
71 |
72 | target_fn = functools.partial(get_impl(impl), msg)
73 | baseline_fn = functools.partial(dist.all_reduce, msg)
74 |
75 | target_us = benchmark_with_profiler(
76 | target_fn, ".*all_reduce.*", benchmark_iters=200
77 | )
78 | baseline_us = benchmark_with_profiler(
79 | baseline_fn, ".*AllReduce.*", benchmark_iters=200
80 | )
81 | if dist.get_rank() == 0:
82 | print(
83 | f"msg_sz_bytes: {msg_sz_bytes}\t"
84 | f"nccl_ring: {baseline_us:.2f} us\t"
85 | f"{impl}: {target_us:.2f} us\t"
86 | )
87 |
88 |
89 | @click.command()
90 | @click.option(
91 | "--impl",
92 | help="Valid options: multimem_all_reduce, one_shot_all_reduce, two_shot_all_reduce, triton_multimem_all_reduce, triton_one_shot_all_reduce",
93 | default="multimem_all_reduce",
94 | )
95 | def main(impl: str):
96 | """
97 | Benchmark for the symmetric memory-based all-reduce variants.
98 | NVSwitch is required.
99 |
100 | torchrun \
101 | --nnodes 1 --nproc-per-node 8 \
102 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
103 | --no_python python3 symm_mem_all_reduce.py
104 | """
105 | local_rank = int(os.environ["LOCAL_RANK"])
106 |
107 | device = torch.device(f"cuda:{local_rank}")
108 | torch.cuda.set_device(device)
109 | dist.init_process_group("nccl")
110 | torch.manual_seed(42 + local_rank)
111 |
112 | if dist.get_rank() == 0:
113 | print(f"Benchmarking {impl}...")
114 |
115 | msg_sizes = [2**exp for exp in range(12, 21)]
116 | for msg_sz_bytes in msg_sizes:
117 | benchmark(device, impl, msg_sz_bytes)
118 |
119 | dist.destroy_process_group()
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
/torch_compile_all_gather_matmul.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | from typing import List
4 |
5 | import click
6 | import torch
7 | import torch.distributed as dist
8 | import torch.distributed._symmetric_memory as symm_mem
9 | from torch.distributed._functional_collectives import all_gather_tensor
10 |
11 | from utils import benchmark_with_event
12 |
13 |
14 | def parse_csv(ctx, param, value):
15 | return [int(num) for num in value.split(",")]
16 |
17 |
18 | def all_gather_matmul(a_shard, bs, gather_dim, group_name):
19 | a = all_gather_tensor(a_shard.contiguous(), gather_dim=gather_dim, group=group_name)
20 | return [torch.matmul(a, b) for b in bs]
21 |
22 |
23 | compiled_all_gather_matmul = torch.compile(
24 | options={
25 | "_micro_pipeline_tp": True,
26 | "keep_output_stride": False,
27 | },
28 | fullgraph=True,
29 | )(all_gather_matmul)
30 |
31 |
32 | def scaled_matmul(a, b, a_scale, b_scale, **kwargs):
33 | leading_dims = a.shape[:-1]
34 | c = torch._scaled_mm(a.flatten(0, -2), b, a_scale, b_scale, **kwargs)
35 | return c.view(*leading_dims, -1)
36 |
37 |
38 | def all_gather_scaled_matmul(a_shard, bs, a_scale, b_scales, gather_dim, group_name):
39 | a = all_gather_tensor(a_shard.contiguous(), gather_dim=gather_dim, group=group_name)
40 | return [
41 | scaled_matmul(
42 | a, b, a_scale, b_scale, out_dtype=torch.bfloat16, use_fast_accum=True
43 | )
44 | for b, b_scale in zip(bs, b_scales)
45 | ]
46 |
47 |
48 | compiled_all_gather_scaled_matmul = torch.compile(
49 | options={
50 | "_micro_pipeline_tp": True,
51 | "keep_output_stride": False,
52 | },
53 | fullgraph=True,
54 | )(all_gather_scaled_matmul)
55 |
56 |
57 | @click.command()
58 | @click.option("--batch", default=1)
59 | @click.option("--M", default=8192)
60 | @click.option("--N", callback=parse_csv, default="3584")
61 | @click.option("--K", default=8192)
62 | @click.option("--dtype", default="bfloat16")
63 | @click.option("--gather-dim", default=0)
64 | @click.option("--scale-mode", default="tensor-wise")
65 | @click.option("--cuda-graph", is_flag=True, default=False)
66 | def main(
67 | batch: int,
68 | m: int,
69 |     n: List[int],
70 |     k: int,
71 | dtype: str,
72 | gather_dim: int,
73 | scale_mode: str,
74 | cuda_graph: bool,
75 | ):
76 | """
77 | torchrun \
78 | --nnodes 1 --nproc-per-node 8 \
79 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
80 | --no_python python3 torch_compile_all_gather_matmul.py --cuda-graph
81 | """
82 | os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHE"] = "1"
83 | os.environ["TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP"] = "1"
84 |
85 | rank = int(os.environ["RANK"])
86 | local_rank = int(os.environ["LOCAL_RANK"])
87 | world_size = int(os.environ["WORLD_SIZE"])
88 |
89 | if rank == 0:
90 | print(f"M={m}, N={n}, K={k}")
91 |
92 | device = torch.device(f"cuda:{local_rank}")
93 | torch.cuda.set_device(device)
94 | torch.manual_seed(42 + rank)
95 |
96 | dist.init_process_group("nccl")
97 | group_name = dist.group.WORLD.group_name
98 | symm_mem.enable_symm_mem_for_group(group_name)
99 |
100 | a_shard = torch.rand(batch, m // world_size, k, dtype=torch.bfloat16, device="cuda")
101 | bs = [torch.rand(N, k, dtype=torch.bfloat16, device="cuda").T for N in n]
102 |
103 | if dtype == "bfloat16":
104 | baseline = functools.partial(
105 | all_gather_matmul, a_shard, bs, gather_dim=gather_dim, group_name=group_name
106 | )
107 | compiled = functools.partial(
108 | compiled_all_gather_matmul,
109 | symm_mem.restride_A_shard_for_fused_all_gather_matmul(
110 | a_shard, gather_dim=gather_dim
111 | ),
112 | bs,
113 | gather_dim=gather_dim,
114 | group_name=group_name,
115 | )
116 |
117 | elif dtype == "float8":
118 | a_shard = a_shard.to(torch.float8_e4m3fn)
119 | bs = [B.to(torch.float8_e4m3fn) for B in bs]
120 |
121 | if scale_mode == "tensor-wise":
122 | A_scale = torch.tensor(0.1, device="cuda")
123 | B_scales = [torch.tensor(0.1, device="cuda") for _ in n]
124 | elif scale_mode == "row-wise":
125 | A_scale = torch.full((batch, m // world_size, 1), 0.1, device="cuda")
126 | B_scales = [torch.full((1, N), 0.1, device="cuda") for N in n]
127 | else:
128 | raise AssertionError(f"Invalid scale_mode: {scale_mode}")
129 |
130 | baseline = functools.partial(
131 | all_gather_scaled_matmul,
132 | a_shard,
133 | bs,
134 | A_scale,
135 | B_scales,
136 | gather_dim=gather_dim,
137 | group_name=group_name,
138 | )
139 | compiled = functools.partial(
140 | compiled_all_gather_scaled_matmul,
141 | symm_mem.restride_A_shard_for_fused_all_gather_matmul(
142 | a_shard, gather_dim=gather_dim
143 | ),
144 | bs,
145 | A_scale,
146 | B_scales,
147 | gather_dim=gather_dim,
148 | group_name=group_name,
149 | )
150 |
151 | else:
152 | raise AssertionError(f"Invalid dtype: {dtype}")
153 |
154 | torch.testing.assert_close(baseline(), compiled())
155 | baseline_us = benchmark_with_event(baseline, flush_l2=True, cuda_graph=cuda_graph)
156 | compiled_us = benchmark_with_event(compiled, flush_l2=True, cuda_graph=cuda_graph)
157 | print(f"baseline: {baseline_us:.2f} us; compiled: {compiled_us:.2f} us")
158 |
159 | dist.destroy_process_group()
160 |
161 |
162 | if __name__ == "__main__":
163 | main()
164 |
--------------------------------------------------------------------------------
/triton_all_gather_matmul.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import click
4 | import torch
5 | import torch.distributed as dist
6 | import torch.distributed._symmetric_memory as symm_mem
7 | import triton
8 | import triton.language as tl
9 | import triton.tools.experimental_descriptor
10 |
11 | from triton_barrier import get_flat_tid
12 | from utils import benchmark_with_event, log_triton_kernel
13 |
14 |
15 | def all_gather_with_progress(
16 | output: torch.Tensor,
17 | inp: torch.Tensor,
18 | progress: torch.Tensor,
19 | splits_per_rank: int,
20 | ):
21 | assert inp.is_contiguous()
22 |
23 | symm_mem_hdl = symm_mem.rendezvous(inp, group=dist.group.WORLD)
24 | assert symm_mem_hdl is not None
25 |
26 | rank = symm_mem_hdl.rank
27 | world_size = symm_mem_hdl.world_size
28 |
29 | assert inp.numel() % splits_per_rank == 0
30 | assert progress.numel() == world_size * splits_per_rank
31 |
32 | output_shape = list(inp.shape)
33 | output_shape[0] *= world_size
34 | assert list(output.shape) == output_shape, (list(output.shape), output_shape)
35 |
36 | chunks = output.chunk(world_size * splits_per_rank)
37 |
38 | for step in range(0, world_size):
39 | src_rank = (rank + step + 1) % world_size
40 | for split_id in range(splits_per_rank):
41 | src_buf = symm_mem_hdl.get_buffer(
42 | src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
43 | )
44 | chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
45 | # cuStreamWriteValue32 issues a system level fence before the write
46 | symm_mem_hdl.stream_write_value32(
47 | progress,
48 | offset=src_rank * splits_per_rank + split_id,
49 | val=1,
50 | )
51 | symm_mem_hdl.barrier()
52 |
53 |
54 | def _matmul_launch_metadata(grid, kernel, args):
55 | ret = {}
56 | M, N, K = args["M"], args["N"], args["K"]
57 | ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
58 | ret["flops8"] = 2.0 * M * N * K
59 | if "c_ptr" in args:
60 | bytes_per_elem = args["c_ptr"].element_size()
61 | else:
62 | bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
63 | ret["bytes"] = bytes_per_elem * (M * K + N * K)
64 | return ret
65 |
66 |
67 | @triton.jit
68 | def wait_signal(addr, flat_tid):
69 | if flat_tid == 0:
70 | tl.inline_asm_elementwise(
71 | """
72 | {
73 | .reg .pred %p<1>;
74 |
75 | wait_block:
76 | ld.global.relaxed.gpu.u32 $0, [$1];
77 | setp.eq.u32 %p0, $0, 1;
78 | @!%p0 bra wait_block;
79 | }
80 | """,
81 | "=r, l",
82 | [addr],
83 | dtype=tl.int32,
84 | is_pure=False,
85 | pack=1,
86 | )
87 |
88 | tl.inline_asm_elementwise(
89 | "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
90 | )
91 |
92 |
93 | @triton.jit(launch_metadata=_matmul_launch_metadata)
94 | def matmul_kernel_tma_persistent(
95 | a_shard_desc_ptr,
96 | a_desc_ptr,
97 | b_desc_ptr,
98 | c_desc_ptr,
99 | progress_ptr,
100 | M,
101 | N,
102 | K,
103 | BLOCK_SIZE_M: tl.constexpr,
104 | BLOCK_SIZE_N: tl.constexpr,
105 | BLOCK_SIZE_K: tl.constexpr,
106 | GROUP_SIZE_M: tl.constexpr,
107 | COMM_BLOCK_SIZE_M: tl.constexpr,
108 | RANK: tl.constexpr,
109 | WORLD_SIZE: tl.constexpr,
110 | FP8_OUTPUT: tl.constexpr,
111 | NUM_SMS: tl.constexpr,
112 | ):
113 | """
114 | Slightly modified from the sm90 tma persistent Triton tutorial.
115 | """
116 | flat_tid = get_flat_tid()
117 |
118 | dtype = tl.float8e4nv if FP8_OUTPUT else tl.bfloat16
119 | start_pid = tl.program_id(axis=0)
120 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
121 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
122 | k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
123 | num_tiles = num_pid_m * num_pid_n
124 |
125 | tiles_per_SM = num_tiles // NUM_SMS
126 | if start_pid < num_tiles % NUM_SMS:
127 | tiles_per_SM += 1
128 |
129 | tile_id = start_pid - NUM_SMS
130 | ki = -1
131 |
132 | pid_m = 0
133 | pid_n = 0
134 | offs_am_src = 0
135 | offs_bn = 0
136 | a_ptr = a_desc_ptr
137 |
138 | num_pid_in_group = GROUP_SIZE_M * num_pid_n
139 |
140 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
141 |
142 | for _ in range(0, k_tiles * tiles_per_SM):
143 | ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
144 | if ki == 0:
145 | tile_id += NUM_SMS
146 | group_id = tile_id // num_pid_in_group
147 | first_pid_m = group_id * GROUP_SIZE_M
148 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
149 | pid_m = first_pid_m + (tile_id % group_size_m)
150 | pid_n = (tile_id % num_pid_in_group) // group_size_m
151 |
152 | NUM_COMM_BLOCKS = M // COMM_BLOCK_SIZE_M
153 | NUM_COMM_BLOCKS_PER_RANK = NUM_COMM_BLOCKS // WORLD_SIZE
154 | NUM_PID_M_PER_COMM_BLOCK = COMM_BLOCK_SIZE_M // BLOCK_SIZE_M
155 |
156 | # Pivot tile_id so that M tiles are processed in their ready order.
157 | # This pivot preserves the prior swizzling.
158 | pid_m = (pid_m + NUM_PID_M_PER_COMM_BLOCK * RANK) % num_pid_m
159 |
160 | comm_block_id = pid_m // NUM_PID_M_PER_COMM_BLOCK
161 | if comm_block_id // NUM_COMM_BLOCKS_PER_RANK == RANK:
162 | # Read from the local a_shard
163 | offs_am_src = (pid_m * BLOCK_SIZE_M) % COMM_BLOCK_SIZE_M
164 | a_ptr = a_shard_desc_ptr
165 | else:
166 | # Wait for and read from a_shard copied from remote ranks
167 | wait_signal((progress_ptr + comm_block_id).to(tl.uint64), flat_tid)
168 | offs_am_src = pid_m * BLOCK_SIZE_M
169 | a_ptr = a_desc_ptr
170 |
171 | offs_bn = pid_n * BLOCK_SIZE_N
172 | offs_k = ki * BLOCK_SIZE_K
173 |
174 | a = tl._experimental_descriptor_load(
175 | a_ptr, [offs_am_src, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype
176 | )
177 | b = tl._experimental_descriptor_load(
178 | b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype
179 | )
180 | accumulator = tl.dot(a, b.T, accumulator)
181 |
182 | if ki == k_tiles - 1:
183 | c = accumulator.to(dtype)
184 |
185 | tl._experimental_descriptor_store(
186 | c_desc_ptr, c, [pid_m * BLOCK_SIZE_M, offs_bn]
187 | )
188 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
189 |
190 |
191 | _tma_desc_cache = {}
192 |
193 |
194 | def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size):
195 | global _tma_desc_cache
196 | key = (ptr, dim1, dim0, block_dim1, block_dim0, element_size)
197 | if key in _tma_desc_cache:
198 | return _tma_desc_cache[key]
199 | desc = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
200 | ptr,
201 | dim1,
202 | dim0,
203 | block_dim1,
204 | block_dim0,
205 | element_size,
206 | )
207 | _tma_desc_cache[key] = desc
208 | return desc
209 |
210 |
211 | def all_gather_matmul_tma_persistent(
212 | a_shard, b, a_out, c_out, configs, mm_only: bool = False
213 | ):
214 | if mm_only:
215 | rank = 0
216 | world_size = int(os.environ.get("WORLD_SIZE", "8"))
217 | else:
218 | symm_mem_hdl = symm_mem.rendezvous(a_shard, group=dist.group.WORLD)
219 | assert symm_mem_hdl is not None, "a_shard must be allocated via SymmetricMemory"
220 | rank = symm_mem_hdl.rank
221 | world_size = symm_mem_hdl.world_size
222 |
223 | dtype = a_shard.dtype
224 | M = a_shard.shape[0] * world_size
225 | N = b.shape[0]
226 | K = a_shard.shape[1]
227 |
228 | assert b.shape[1] == K
229 | assert a_out.shape[0] == M
230 | assert a_out.shape[1] == K
231 | assert c_out.shape[0] == M
232 | assert c_out.shape[1] == N
233 |
234 | SPLITS_PER_RANK = 1
235 | COMM_BLOCK_SIZE_M = M // world_size // SPLITS_PER_RANK
236 | assert COMM_BLOCK_SIZE_M % (configs["BLOCK_SIZE_M"] * configs["GROUP_SIZE_M"]) == 0
237 |
238 | backend_stream = symm_mem._get_backend_stream(priority=-1)
239 | if mm_only:
240 | progress = torch.ones(world_size, dtype=torch.uint32, device="cuda")
241 | else:
242 | progress = torch.zeros(world_size, dtype=torch.uint32, device="cuda")
243 | symm_mem_hdl.barrier(0)
244 | backend_stream.wait_stream(torch.cuda.current_stream())
245 | with torch.cuda.stream(backend_stream):
246 | all_gather_with_progress(a_out, a_shard, progress, SPLITS_PER_RANK)
247 |
248 | desc_a_shard = create_2d_tma_descriptor(
249 | a_shard.data_ptr(),
250 | a_shard.shape[0],
251 | K,
252 | configs["BLOCK_SIZE_M"],
253 | configs["BLOCK_SIZE_K"],
254 | a_shard.element_size(),
255 | )
256 | desc_a = create_2d_tma_descriptor(
257 | a_out.data_ptr(),
258 | M,
259 | K,
260 | configs["BLOCK_SIZE_M"],
261 | configs["BLOCK_SIZE_K"],
262 | a_out.element_size(),
263 | )
264 | desc_b = create_2d_tma_descriptor(
265 | b.data_ptr(),
266 | N,
267 | K,
268 | configs["BLOCK_SIZE_N"],
269 | configs["BLOCK_SIZE_K"],
270 | b.element_size(),
271 | )
272 | desc_c = create_2d_tma_descriptor(
273 | c_out.data_ptr(),
274 | M,
275 | N,
276 | configs["BLOCK_SIZE_M"],
277 | configs["BLOCK_SIZE_N"],
278 | c_out.element_size(),
279 | )
280 | NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
281 |
282 | grid = lambda META: (
283 | min(
284 | NUM_SMS,
285 | triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
286 | ),
287 | )
288 | kernel = matmul_kernel_tma_persistent[grid](
289 | desc_a_shard,
290 | desc_a,
291 | desc_b,
292 | desc_c,
293 | progress,
294 | M,
295 | N,
296 | K,
297 | BLOCK_SIZE_M=configs["BLOCK_SIZE_M"],
298 | BLOCK_SIZE_N=configs["BLOCK_SIZE_N"],
299 | BLOCK_SIZE_K=configs["BLOCK_SIZE_K"],
300 | GROUP_SIZE_M=configs["GROUP_SIZE_M"],
301 | COMM_BLOCK_SIZE_M=COMM_BLOCK_SIZE_M,
302 | RANK=rank,
303 | WORLD_SIZE=world_size,
304 | FP8_OUTPUT=dtype == torch.float8_e4m3fn,
305 | NUM_SMS=NUM_SMS,
306 | num_stages=configs["num_stages"],
307 | num_warps=configs["num_warps"],
308 | )
309 | log_triton_kernel(kernel)
310 | torch.cuda.current_stream().wait_stream(backend_stream)
311 | return c_out
312 |
313 |
314 | def all_gather_matmul(a_shard, b):
315 | from torch.distributed._functional_collectives import all_gather_tensor
316 |
317 | a = all_gather_tensor(a_shard, 0, "0")
318 | return torch.matmul(a, b)
319 |
320 |
321 | @click.command()
322 | @click.option("--M", default=4096)
323 | @click.option("--N", default=6656)
324 | @click.option("--K", default=16384)
325 | @click.option("--BLOCK_SIZE_M", default=128)
326 | @click.option("--BLOCK_SIZE_N", default=256)
327 | @click.option("--BLOCK_SIZE_K", default=64)
328 | @click.option("--GROUP_SIZE_M", default=4)
329 | @click.option("--num_stages", default=3)
330 | @click.option("--num_warps", default=8)
331 | def main(
332 | m: int,
333 | n: int,
334 | k: int,
335 | block_size_m: int,
336 | block_size_n: int,
337 | block_size_k: int,
338 | group_size_m: int,
339 | num_stages: int,
340 | num_warps: int,
341 | ):
342 | """
343 | torchrun \
344 | --nnodes 1 --nproc-per-node 8 \
345 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
346 | --no_python python3 triton_all_gather_matmul.py
347 | """
348 | rank = int(os.environ["RANK"])
349 | local_rank = int(os.environ["LOCAL_RANK"])
350 | world_size = int(os.environ["WORLD_SIZE"])
351 |
352 | device = torch.device(f"cuda:{local_rank}")
353 | torch.cuda.set_device(device)
354 | torch.manual_seed(42 + rank)
355 | dist.init_process_group("nccl")
356 |
357 | a_shard = symm_mem.empty(
358 | m // world_size, k, dtype=torch.bfloat16, device=device
359 | ).normal_()
360 | a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
361 | b = torch.randn((k, n), device="cuda", dtype=torch.bfloat16).T.contiguous()
362 | c = torch.randn((m, n), device="cuda", dtype=torch.bfloat16)
363 |
364 | # Autotuner does not work with TMA. Use manual config.
365 | configs = {
366 | "BLOCK_SIZE_M": block_size_m,
367 | "BLOCK_SIZE_N": block_size_n,
368 | "BLOCK_SIZE_K": block_size_k,
369 | "GROUP_SIZE_M": group_size_m,
370 | "num_stages": num_stages,
371 | "num_warps": num_warps,
372 | }
373 |
374 | c0 = all_gather_matmul(a_shard, b.T)
375 | c1 = all_gather_matmul_tma_persistent(a_shard, b, a, c, configs)
376 | assert torch.allclose(c0, c1)
377 |
378 | def rank_0_print(msg):
379 | if rank == 0:
380 | print(msg)
381 |
382 | lat_cublas_mm = benchmark_with_event(
383 | lambda: torch.matmul(a, b.T, out=c), flush_l2=True
384 | )
385 | rank_0_print(f"cublas mm only:\t{round(lat_cublas_mm)} us")
386 |
387 | lat_triton_mm = benchmark_with_event(
388 | lambda: all_gather_matmul_tma_persistent(
389 | a_shard, b, a, c, configs, mm_only=True
390 | ),
391 | flush_l2=True,
392 | )
393 | rank_0_print(f"triton mm only:\t{round(lat_triton_mm)} us")
394 |
395 | lat_cublas_nccl = benchmark_with_event(
396 | lambda: all_gather_matmul(a_shard, b.T), flush_l2=True
397 | )
398 | rank_0_print(f"cublas + nccl:\t{round(lat_cublas_nccl)} us")
399 |
400 | lat_triton_fused = benchmark_with_event(
401 | lambda: all_gather_matmul_tma_persistent(a_shard, b, a, c, configs),
402 | flush_l2=True,
403 | )
404 | rank_0_print(f"triton fused:\t{round(lat_triton_fused)} us")
405 | rank_0_print(f"speedup:\t{lat_cublas_nccl / lat_triton_fused:.02f}x")
406 |
407 | dist.destroy_process_group()
408 |
409 |
410 | if __name__ == "__main__":
411 | main()
412 |
--------------------------------------------------------------------------------
/triton_barrier.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import torch.distributed._symmetric_memory as symm_mem
6 | import triton
7 | import triton.language as tl
8 |
9 | from triton_utils import get_flat_bid, get_flat_tid, sync_threads
10 | from utils import log_triton_kernel
11 |
12 |
13 | @triton.jit
14 | def send_signal(addrs, sem: tl.constexpr):
15 | if sem == "relaxed":
16 | tl.inline_asm_elementwise(
17 | """
18 | {
19 | .reg .u32 %tmp32_<1>;
20 | .reg .pred %p<1>;
21 |
22 | send_signal:
23 | atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
24 | setp.eq.u32 %p0, %tmp32_0, 0;
25 | @!%p0 bra send_signal;
26 | }
27 | """,
28 | "=r, l",
29 | [addrs],
30 | dtype=tl.int32,
31 | is_pure=False,
32 | pack=1,
33 | )
34 | elif sem == "acq_rel":
35 | tl.inline_asm_elementwise(
36 | """
37 | {
38 | .reg .u32 %tmp32_<1>;
39 | .reg .pred %p<1>;
40 |
41 | send_signal:
42 | atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
43 | setp.eq.u32 %p0, %tmp32_0, 0;
44 | @!%p0 bra send_signal;
45 | }
46 | """,
47 | "=r, l",
48 | [addrs],
49 | dtype=tl.int32,
50 | is_pure=False,
51 | pack=1,
52 | )
53 | else:
54 | raise RuntimeError(f"Unrecognized sem: {sem}")
55 |
56 |
57 | @triton.jit
58 | def wait_signal(addrs, sem: tl.constexpr):
59 | if sem == "relaxed":
60 | tl.inline_asm_elementwise(
61 | """
62 | {
63 | .reg .u32 %tmp32_<1>;
64 | .reg .pred %p<1>;
65 |
66 | wait_signal:
67 | atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
68 | setp.eq.u32 %p0, %tmp32_0, 1;
69 | @!%p0 bra wait_signal;
70 | }
71 | """,
72 | "=r, l",
73 | [addrs],
74 | dtype=tl.int32,
75 | is_pure=False,
76 | pack=1,
77 | )
78 | elif sem == "acq_rel":
79 | tl.inline_asm_elementwise(
80 | """
81 | {
82 | .reg .u32 %tmp32_<1>;
83 | .reg .pred %p<1>;
84 |
85 | wait_signal:
86 | atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
87 | setp.eq.u32 %p0, %tmp32_0, 1;
88 | @!%p0 bra wait_signal;
89 | }
90 | """,
91 | "=r, l",
92 | [addrs],
93 | dtype=tl.int32,
94 | is_pure=False,
95 | pack=1,
96 | )
97 | else:
98 | raise RuntimeError(f"Unrecognized sem: {sem}")
99 |
100 |
101 | @triton.jit
102 | def blockwise_barrier(
103 | signal_pad_ptrs,
104 | block_id,
105 | rank: tl.constexpr,
106 | world_size: tl.constexpr,
107 | sem: tl.constexpr,
108 | ):
109 | """
110 | Synchronizes blocks with matching block_id across participating devices.
111 |
112 | Note: the function itself is not a system level barrier/fence. It is a
113 | building block for expressing different synchronization patterns.
114 |
115 | Pattern 0: Ensures that all writes to symm_mem buffers from previous
116 | kernels across all devices are visible to the current kernel:
117 |
118 | blockwise_barrier(..., sem="relaxed")
119 | sync_threads()
120 |
121 | Pattern 1: Ensures that all writes to symm_mem buffers from the current
122 | block are visible to all remote blocks with matching blockIdx:
123 |
124 | sync_threads()
125 | blockwise_barrier(..., sem="acq_rel")
126 | sync_threads()
127 |
128 | Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
129 | for writing by subsequent kernels across all devices.
130 |
131 | sync_threads()
132 | blockwise_barrier(..., sem="relaxed")
133 |
134 | CUDA graph friendliness:
135 |
136 | This barrier operates through atomic operations on a zero-filled signal
137 | pad, which resets to a zero-filled state after each successful
138 |     synchronization. This design eliminates the need to increment a
139 |     flag from the host.
140 | """
141 | if block_id is None:
142 | block_id = get_flat_bid()
143 | flat_tid = get_flat_tid()
144 |
145 | remote_ranks = tl.arange(0, world_size)
146 | signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
147 | remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
148 | tl.pointer_type(tl.uint32)
149 | )
150 | send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
151 |
152 | local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
153 | tl.pointer_type(tl.uint32)
154 | )
155 | wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
156 |
157 | if flat_tid < world_size:
158 | send_signal(send_addrs, sem)
159 | wait_signal(wait_addrs, sem)
160 |
161 |
162 | @triton.jit
163 | def barrier_test_kernel(
164 | signal_pad_ptrs,
165 | rank: tl.constexpr,
166 | world_size: tl.constexpr,
167 | ):
168 | blockwise_barrier(signal_pad_ptrs, None, rank, world_size, "relaxed")
169 | sync_threads()
170 |
171 |
172 | def barrier_test(t: torch.Tensor) -> None:
173 | symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
174 |
175 | kernel = barrier_test_kernel[(32, 1, 1)](
176 | symm_mem_hdl.signal_pad_ptrs_dev,
177 | rank=symm_mem_hdl.rank,
178 | world_size=symm_mem_hdl.world_size,
179 | )
180 | log_triton_kernel(kernel)
181 |
182 | signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
183 | assert signal_pad.eq(0).all().item()
184 |
185 |
186 | if __name__ == "__main__":
187 | """
188 | torchrun \
189 | --nnodes 1 --nproc-per-node 8 \
190 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
191 | --no_python python3 triton_barrier.py
192 | """
193 | rank = int(os.environ["RANK"])
194 | local_rank = int(os.environ["LOCAL_RANK"])
195 | world_size = int(os.environ["WORLD_SIZE"])
196 |
197 | device = torch.device(f"cuda:{local_rank}")
198 | torch.cuda.set_device(device)
199 | dist.init_process_group("nccl")
200 |
201 | t = symm_mem.empty(4096, device=device)
202 | barrier_test(t)
203 |
--------------------------------------------------------------------------------
/triton_multimem_all_reduce.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import torch.distributed._symmetric_memory as symm_mem
4 | import triton
5 | import triton.language as tl
6 |
7 | from triton_barrier import blockwise_barrier
8 | from triton_utils import get_flat_tid, sync_threads
9 | from utils import log_triton_kernel
10 |
11 |
12 | @triton.jit
13 | def multimem_ld_reduce_128(multicast_ptrs, mask):
14 | return tl.inline_asm_elementwise(
15 | """
16 | {
17 | .reg .pred %p0;
18 | setp.eq.s32 %p0, $5, 1;
19 | @!%p0 bra end;
20 | multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {$0, $1, $2, $3}, [$4];
21 | end:
22 | }
23 | """,
24 | "=r,=r,=r,=r,l,r",
25 | args=[multicast_ptrs, mask.to(tl.int32)],
26 | dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
27 | is_pure=True,
28 | pack=1,
29 | )
30 |
31 |
32 | @triton.jit
33 | def multimem_st_128(multicast_ptrs, x, y, z, w, mask):
34 | return tl.inline_asm_elementwise(
35 | """
36 | {
37 | .reg .pred %p0;
38 | setp.eq.s32 %p0, $6, 1;
39 | @!%p0 bra end;
40 | multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};
41 | end:
42 | }
43 | """,
44 | "=r,l,r,r,r,r,r",
45 | args=[multicast_ptrs, x, y, z, w, mask.to(tl.int32)],
46 | dtype=(tl.uint32),
47 | is_pure=False,
48 | pack=1,
49 | )
50 |
51 |
52 | @triton.jit
53 | def multimem_all_reduce_kernel(
54 | multicast_ptr,
55 | signal_pad_ptrs,
56 | numel,
57 | BLOCK_SIZE: tl.constexpr,
58 | NUMEL_PER_THREAD: tl.constexpr,
59 | RANK: tl.constexpr,
60 | WORLD_SIZE: tl.constexpr,
61 | ):
62 | blockwise_barrier(signal_pad_ptrs, None, RANK, WORLD_SIZE, sem="relaxed")
63 | sync_threads()
64 |
65 | pid = tl.program_id(axis=0)
66 | tid = get_flat_tid()
67 |
68 | # From this point on, we pretend each element is 128-bit
69 | numel = numel // NUMEL_PER_THREAD
70 | numel_per_rank = tl.cdiv(numel, WORLD_SIZE)
71 | block_start = pid * BLOCK_SIZE
72 |
73 | while block_start < numel_per_rank:
74 | offsets = block_start + tid
75 | mask = offsets < numel_per_rank
76 |
77 | # Each pointer points to a 128-bit bit pack
78 | ptrs = (
79 | multicast_ptr.to(tl.pointer_type(tl.uint64))
80 | + (RANK * numel_per_rank + offsets) * 2
81 | )
82 | (x, y, z, w) = multimem_ld_reduce_128(ptrs, mask=mask)
83 | multimem_st_128(ptrs, x, y, z, w, mask=mask)
84 |
85 | block_start += tl.num_programs(axis=0) * BLOCK_SIZE
86 |
87 | sync_threads()
88 | blockwise_barrier(signal_pad_ptrs, None, RANK, WORLD_SIZE, sem="acq_rel")
89 |
90 |
91 | def multimem_all_reduce(tensor: torch.Tensor):
92 | WARP_SIZE = 32
93 | MAX_NUM_BLOCKS = 4
94 | MAX_BLOCK_SIZE = 1024
95 | BYTES_PER_THREAD = 16
96 |
97 | symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD)
98 |
99 | assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now."
100 | numel_per_thread = BYTES_PER_THREAD // tensor.element_size()
101 |
102 | assert (
103 | tensor.numel() % numel_per_thread == 0
104 | ), "The number of elements must be 128-bit aligned."
105 |
106 | num_threads = triton.cdiv(
107 | tensor.numel() // numel_per_thread, symm_mem_hdl.world_size
108 | )
109 | if num_threads < MAX_BLOCK_SIZE:
110 | block_size = 1
111 | while block_size < num_threads:
112 | block_size *= 2
113 |         num_warps = max(block_size // WARP_SIZE, 1)
114 | num_blocks = 1
115 | else:
116 | block_size = MAX_BLOCK_SIZE
117 | num_warps = MAX_BLOCK_SIZE // WARP_SIZE
118 | num_blocks = min(
119 | triton.cdiv(num_threads, MAX_BLOCK_SIZE),
120 | MAX_NUM_BLOCKS,
121 | )
122 |
123 | kernel = multimem_all_reduce_kernel[(num_blocks, 1, 1)](
124 | symm_mem_hdl.multicast_ptr,
125 | symm_mem_hdl.signal_pad_ptrs_dev,
126 | numel=tensor.numel(),
127 | BLOCK_SIZE=block_size,
128 | NUMEL_PER_THREAD=numel_per_thread,
129 | RANK=symm_mem_hdl.rank,
130 | WORLD_SIZE=symm_mem_hdl.world_size,
131 | num_warps=num_warps,
132 | )
133 | log_triton_kernel(kernel)
134 | return tensor
135 |
136 |
137 | if __name__ == "__main__":
138 | """
139 | torchrun \
140 | --nnodes 1 --nproc-per-node 8 \
141 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
142 | --no_python python3 triton_multimem_all_reduce.py
143 | """
144 | from symm_mem_all_reduce import main
145 |
146 | main(["--impl", "triton_multimem_all_reduce"])
147 |
--------------------------------------------------------------------------------
/triton_one_shot_all_reduce.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import torch.distributed._symmetric_memory as symm_mem
4 | import triton
5 | import triton.language as tl
6 |
7 | from triton_barrier import blockwise_barrier
8 | from triton_utils import sync_threads
9 | from utils import log_triton_kernel
10 |
11 |
12 | @triton.jit
13 | def load_128(addrs, mask):
14 | return tl.inline_asm_elementwise(
15 | """
16 | {
17 | .reg .pred %p0;
18 | setp.eq.s32 %p0, $3, 1;
19 | @%p0 ld.global.v2.u64 {$0, $1}, [$2];
20 | }
21 | """,
22 | "=l,=l,l,r",
23 | args=[addrs, mask.to(tl.int32)],
24 | dtype=(tl.uint64, tl.uint64),
25 | is_pure=True,
26 | pack=1,
27 | )
28 |
29 |
30 | @triton.jit
31 | def add_v8_bf16(a_hi, a_lo, b_hi, b_lo):
32 | return tl.inline_asm_elementwise(
33 | """
34 | {
35 | .reg .v4 .b32 %acc, %tmp;
36 | mov.v4.b32 %acc, 0;
37 | mov.b64 {%acc.x, %acc.y}, $2;
38 | mov.b64 {%acc.z, %acc.w}, $3;
39 | mov.b64 {%tmp.x, %tmp.y}, $4;
40 | mov.b64 {%tmp.z, %tmp.w}, $5;
41 | add.bf16x2 %acc.x, %acc.x, %tmp.x;
42 | add.bf16x2 %acc.y, %acc.y, %tmp.y;
43 | add.bf16x2 %acc.z, %acc.z, %tmp.z;
44 | add.bf16x2 %acc.w, %acc.w, %tmp.w;
45 | mov.b64 $0, {%acc.x, %acc.y};
46 | mov.b64 $1, {%acc.z, %acc.w};
47 | }
48 | """,
49 | "=l,=l,l,l,l,l",
50 | args=[a_hi, a_lo, b_hi, b_lo],
51 | dtype=(tl.uint64, tl.uint64),
52 | is_pure=True,
53 | pack=1,
54 | )
55 |
56 |
57 | @triton.jit
58 | def one_shot_all_reduce_kernel(
59 | buffer_ptrs,
60 | signal_pad_ptrs,
61 | output_ptr,
62 | numel: tl.constexpr,
63 | rank: tl.constexpr,
64 | world_size: tl.constexpr,
65 | BLOCK_SIZE: tl.constexpr,
66 | NUMEL_PER_THREAD: tl.constexpr,
67 | ):
68 | blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
69 | sync_threads()
70 |
71 | pid = tl.program_id(axis=0)
72 |
73 | buffer_ptrs = buffer_ptrs.to(tl.pointer_type(tl.uint64))
74 | output_ptr = output_ptr.to(tl.pointer_type(tl.uint64))
75 | block_start = pid * BLOCK_SIZE
76 |
77 | while block_start < (numel // NUMEL_PER_THREAD):
78 | # Each thread processes 128 bits. Since Triton doesn't yet natively
79 | # support 128-bit dtypes, we achieve this by having each thread process
80 | # two 64-bit elements.
81 | offsets = (block_start + tl.arange(0, BLOCK_SIZE)) * 2
82 | mask = block_start + tl.arange(0, BLOCK_SIZE) < numel // NUMEL_PER_THREAD
83 |
84 | acc_hi = tl.zeros((BLOCK_SIZE,), tl.uint64)
85 | acc_lo = tl.zeros((BLOCK_SIZE,), tl.uint64)
86 | for i in range(world_size):
87 | buffer_ptr = tl.load(buffer_ptrs + i).to(tl.pointer_type(tl.uint64))
88 | (hi, lo) = load_128(buffer_ptr + offsets, mask=mask)
89 | (acc_hi, acc_lo) = add_v8_bf16(acc_hi, acc_lo, hi, lo)
90 |
91 | tl.store(output_ptr + offsets + 0, acc_hi, mask=mask)
92 | tl.store(output_ptr + offsets + 1, acc_lo, mask=mask)
93 | block_start += tl.num_programs(axis=0) * BLOCK_SIZE
94 |
95 | sync_threads()
96 | blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
97 |
98 |
99 | def one_shot_all_reduce(tensor: torch.Tensor):
100 | MAX_NUM_BLOCKS = 24
101 | NUM_WARPS = 16
102 | BLOCK_SIZE = NUM_WARPS * 32
103 | NUMEL_PER_THREAD = 8
104 |
105 | assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now."
106 | assert (
107 | tensor.numel() % NUMEL_PER_THREAD == 0
108 | ), "The number of elements must be 128-bit aligned."
109 | num_blocks = min(
110 | triton.cdiv(triton.cdiv(tensor.numel(), NUMEL_PER_THREAD), BLOCK_SIZE),
111 | MAX_NUM_BLOCKS,
112 | )
113 |
114 | symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD)
115 | output = torch.empty_like(tensor)
116 |
117 | kernel = one_shot_all_reduce_kernel[(num_blocks, 1, 1)](
118 | symm_mem_hdl.buffer_ptrs_dev,
119 | symm_mem_hdl.signal_pad_ptrs_dev,
120 | output,
121 | numel=tensor.numel(),
122 | rank=symm_mem_hdl.rank,
123 | world_size=symm_mem_hdl.world_size,
124 | BLOCK_SIZE=BLOCK_SIZE,
125 | NUMEL_PER_THREAD=NUMEL_PER_THREAD,
126 | num_warps=NUM_WARPS,
127 | )
128 | log_triton_kernel(kernel)
129 | return output
130 |
131 |
132 | if __name__ == "__main__":
133 | """
134 | torchrun \
135 | --nnodes 1 --nproc-per-node 8 \
136 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \
137 | --no_python python3 triton_one_shot_all_reduce.py
138 | """
139 | from symm_mem_all_reduce import main
140 |
141 | main(["--impl", "triton_one_shot_all_reduce"])
142 |
--------------------------------------------------------------------------------
/triton_utils.py:
--------------------------------------------------------------------------------
1 | import triton
2 | import triton.language as tl
3 |
4 |
5 | @triton.jit
6 | def get_tid():
7 | return tl.inline_asm_elementwise(
8 | """
9 | mov.u32 $0, %tid.x;
10 | mov.u32 $1, %tid.y;
11 | mov.u32 $2, %tid.z;
12 | """,
13 | "=r,=r,=r",
14 | [],
15 | dtype=(tl.uint32, tl.uint32, tl.uint32),
16 | is_pure=True,
17 | pack=1,
18 | )
19 |
20 |
21 | @triton.jit
22 | def get_ntid():
23 | return tl.inline_asm_elementwise(
24 | """
25 | mov.u32 $0, %ntid.x;
26 | mov.u32 $1, %ntid.y;
27 | mov.u32 $2, %ntid.z;
28 | """,
29 | "=r,=r,=r",
30 | [],
31 | dtype=(tl.uint32, tl.uint32, tl.uint32),
32 | is_pure=True,
33 | pack=1,
34 | )
35 |
36 |
37 | @triton.jit
38 | def get_flat_tid():
39 | tid_x, tid_y, tid_z = get_tid()
40 | ntid_x, ntid_y, _ = get_ntid()
41 | return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
42 |
43 |
44 | @triton.jit
45 | def get_flat_bid():
46 | return (
47 | tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
48 | + tl.program_id(1) * tl.num_programs(0)
49 | + tl.program_id(0)
50 | )
51 |
52 |
53 | @triton.jit
54 | def sync_threads():
55 | tl.inline_asm_elementwise(
56 | "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
57 | )
58 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from contextlib import nullcontext
4 | from typing import Callable, List, Optional
5 |
6 | import torch
7 | import torch.distributed as dist
8 |
9 |
10 | def benchmark_with_profiler(
11 |     target_fn: Callable[[], None],
12 | event_key_regex: str,
13 | warmup_iters: int = 200,
14 | benchmark_iters: int = 25,
15 | profile_ranks: Optional[List[int]] = None,
16 | flush_l2: bool = False,
17 | ) -> float:
18 | """
19 | Benchmark the target function with PyTorch profiler.
20 |
21 | Args:
22 | target_fn: The target function to benchmark.
23 | event_key_regex: The regex pattern to identify the profiler event
24 | associated with the target function.
25 |         warmup_iters: The number of warmup iterations.
26 |         benchmark_iters: The number of benchmark iterations.
27 |         profile_ranks: The ranks to profile.
28 | flush_l2: Whether to flush the L2 cache before each invocation of the
29 | target function.
30 |
31 | Returns:
32 | The measured median latency in microseconds.
33 | """
34 | if "BENCHMARK_ITERS" in os.environ:
35 | benchmark_iters = int(os.environ["BENCHMARK_ITERS"])
36 |
37 | rank = dist.get_rank() if dist.is_initialized() else 0
38 | profile_ranks = profile_ranks or [0]
39 |
40 | if flush_l2:
41 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
42 |
43 | if rank in profile_ranks:
44 | try:
45 | from trace_handler import trace_handler
46 | except ImportError:
47 | trace_handler = None
48 |
49 | if "NO_TRACE" in os.environ:
50 | trace_handler = None
51 |
52 | prof = torch.profiler.profile(
53 | activities=[
54 | torch.profiler.ProfilerActivity.CPU,
55 | torch.profiler.ProfilerActivity.CUDA,
56 | ],
57 | on_trace_ready=trace_handler,
58 | )
59 | else:
60 | prof = nullcontext()
61 |
62 | for _ in range(warmup_iters):
63 | target_fn()
64 |
65 | if dist.is_initialized():
66 | dist.barrier(device_ids=[torch.cuda.current_device()])
67 | torch.cuda.synchronize()
68 |
69 | with prof:
70 | torch.cuda._sleep(int(2e7))
71 | for i in range(benchmark_iters):
72 | if flush_l2:
73 | cache.zero_()
74 | target_fn()
75 | torch.cuda.synchronize()
76 |
77 | if rank not in profile_ranks:
78 | return 0
79 |
80 | latencies_us = []
81 | for event in prof.events():
82 | if re.match(event_key_regex, event.key):
83 | latencies_us.append(event.device_time)
84 |
85 | if len(latencies_us) == 0:
86 | return 0
87 |
88 | return torch.tensor(latencies_us).median().item()
89 |
90 |
91 | def benchmark_with_event(
92 |     target_fn: Callable[[], None],
93 | warmup_iters: int = 200,
94 | benchmark_iters: int = 25,
95 | profile_ranks: Optional[List[int]] = None,
96 | flush_l2: bool = False,
97 | cuda_graph: bool = False,
98 | ) -> float:
99 | if cuda_graph:
100 | target_fn()
101 | g = torch.cuda.CUDAGraph()
102 | with torch.cuda.graph(g):
103 | target_fn()
104 | target_fn = lambda: g.replay()
105 |
106 | if "BENCHMARK_ITERS" in os.environ:
107 | benchmark_iters = int(os.environ["BENCHMARK_ITERS"])
108 |
109 | rank = dist.get_rank() if dist.is_initialized() else 0
110 | profile_ranks = profile_ranks or [0]
111 |
112 | if flush_l2:
113 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
114 |
115 | for _ in range(warmup_iters):
116 | target_fn()
117 |
118 | if dist.is_initialized():
119 | dist.barrier(device_ids=[torch.cuda.current_device()])
120 | torch.cuda.synchronize()
121 |
122 | begin_events = [
123 | torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)
124 | ]
125 | end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)]
126 |
127 | if rank in profile_ranks:
128 | try:
129 | from trace_handler import trace_handler
130 | except ImportError:
131 | trace_handler = None
132 |
133 | if "NO_TRACE" in os.environ:
134 | trace_handler = None
135 |
136 | prof = torch.profiler.profile(
137 | activities=[
138 | torch.profiler.ProfilerActivity.CPU,
139 | torch.profiler.ProfilerActivity.CUDA,
140 | ],
141 | on_trace_ready=trace_handler,
142 | )
143 | else:
144 | prof = nullcontext()
145 |
146 | with prof:
147 | torch.cuda._sleep(int(2e7))
148 | for i in range(benchmark_iters):
149 | if flush_l2:
150 | cache.zero_()
151 | begin_events[i].record()
152 | target_fn()
153 | end_events[i].record()
154 | torch.cuda.synchronize()
155 |
156 | latencies = [b.elapsed_time(e) for b, e in zip(begin_events, end_events)]
157 | return torch.tensor(latencies).median().item() * 1000
158 |
159 |
160 | triton_kernels = {}
161 |
162 |
163 | def log_triton_kernel(kernel):
164 | import atexit
165 | import tempfile
166 |
167 | if dist.is_initialized() and dist.get_rank() != 0:
168 | return
169 |
170 | def on_exit():
171 | print("PTX files:")
172 | for kernel in triton_kernels:
173 | f = tempfile.NamedTemporaryFile(dir="/tmp", delete=False)
174 | f.write(kernel.asm["ptx"].encode("utf-8"))
175 | print(f"+- {kernel.name}: {f.name}")
176 |
177 | if len(triton_kernels) == 0:
178 | atexit.register(on_exit)
179 |
180 | if kernel not in triton_kernels:
181 | triton_kernels[kernel] = None
182 |
--------------------------------------------------------------------------------