├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── pyproject.toml ├── symm_mem_all_reduce.py ├── torch_compile_all_gather_matmul.py ├── triton_all_gather_matmul.py ├── triton_barrier.py ├── triton_multimem_all_reduce.py ├── triton_one_shot_all_reduce.py ├── triton_utils.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = E203, E501, E731, W503 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.8.0 4 | hooks: 5 | - id: black 6 | 7 | - repo: https://github.com/pycqa/isort 8 | rev: 5.13.2 9 | hooks: 10 | - id: isort 11 | 12 | - repo: https://github.com/pycqa/flake8 13 | rev: 7.0.0 14 | hooks: 15 | - id: flake8 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # symm-mem-recipes 2 | 3 | This repository includes: 4 | - Usage and benchmarks of `SymmetricMemory`-based multi-GPU algorithms in PyTorch. 5 | - Examples and benchmarks of multi-GPU algorithms built with `SymmetricMemory` + Triton. 6 | 7 | --- 8 | ## symm_mem_all_reduce.py 9 | 10 | This script demonstrates the usage of `SymmetricMemory`-based NVLink all-reduce implementations and benchmarks their performance. The available variants are: 11 | - `multimem_all_reduce` (PyTorch op available in nightly) 12 | - `one_shot_all_reduce` (PyTorch op available in nightly) 13 | - `two_shot_all_reduce` (PyTorch op available in nightly) 14 | - `triton_multimem_all_reduce` (Triton kernel defined in this repo) 15 | - `triton_one_shot_all_reduce` (Triton kernel defined in this repo) 16 | 17 | Usage: 18 | ```bash 19 | torchrun \ 20 | --nnodes 1 --nproc-per-node 8 \ 21 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 22 | --no_python python3 symm_mem_all_reduce.py --impl multimem_all_reduce 23 | ``` 24 | 25 | Some benchmarks on 8xH100 with NVSwitch: 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | --- 37 | ## triton_all_gather_matmul.py 38 | 39 | This is a fused all-gather matmul example using Triton + `SymmetricMemory`, based on the `tma_persistent` Triton tutorial with slight modifications. 40 | 41 | This example requires PyTorch Nightly and Triton 3.0.0+ to run. 
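For reference, the unfused baseline that the fused kernel is compared against is simply an all-gather followed by a regular matmul. Below is a minimal sketch of that baseline (it mirrors the `all_gather_matmul` helper defined in `triton_all_gather_matmul.py`; the function name here is illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor


def unfused_all_gather_matmul(a_shard: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Gather the row-sharded A from all ranks along dim 0, then run a plain matmul.
    # The fused Triton kernel instead overlaps the gather with the matmul.
    a = all_gather_tensor(a_shard.contiguous(), gather_dim=0, group=dist.group.WORLD)
    return torch.matmul(a, b)
```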
42 | 43 | Usage: 44 | ```bash 45 | torchrun \ 46 | --nnodes 1 --nproc-per-node 8 \ 47 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 48 | --no_python python3 triton_all_gather_matmul.py \ 49 | --M 16384 --N 6656 --K 16384 --BLOCK_SIZE_M 128 --BLOCK_SIZE_N 256 --BLOCK_SIZE_K 64 50 | ``` 51 | 52 | Some benchmarks on 8xH100 (special version with HBM2e, at 650W) with NVSwitch: 53 |
54 | #### Llama 3 8B (N=1792, K=4096) 55 | | Problem Size<br>(M) | Config<sup>1</sup> | cuBLAS MM<br>Only (µs) | Triton MM<br>Only (µs) | cuBLAS +<br>NCCL (µs) | Triton<br>Fused (µs) | Speedup | 56 | |------------|------------|------------|------------|------------|------------|------------| 57 | | 4096 | 64,128,128,4 | 100 | 142 | 223 | 211 | 1.05x<sup>2</sup> | 58 | | 8192 | 128,128,64,6 | 186 | 198 | 393 | 293 | 1.34x | 59 | | 16384 | 128,256,64,3 | 363 | 363 | 748 | 485 | 1.54x | 60 |
61 | #### Llama 3 70B (N=3584, K=8192) 62 | | Problem Size<br>(M) | Config<sup>1</sup> | cuBLAS MM<br>Only (µs) | Triton MM<br>Only (µs) | cuBLAS +<br>NCCL (µs) | Triton<br>Fused (µs) | Speedup | 63 | |------------|------------|------------|------------|------------|------------|------------| 64 | | 4096 | 128,128,64,6 | 376 | 392 | 587 | 453 | 1.29x | 65 | | 8192 | 128,256,64,3 | 746 | 706 | 1168 | 821 | 1.42x | 66 | | 16384 | 128,256,64,3 | 1502 | 1403 | 2306 | 1566 | 1.47x | 67 |
68 | #### Llama 3 405B (N=6656, K=16384) 69 | | Problem Size<br>(M) | Config<sup>1</sup> | cuBLAS MM<br>Only (µs) | Triton MM<br>Only (µs) | cuBLAS +<br>NCCL (µs) | Triton<br>
Fused (µs) | Speedup | 70 | |------------|------------|------------|------------|------------|------------|------------| 71 | | 4096 | 128,256,64,3 | 1358 | 1425 | 1858 | 1615 | 1.15x | 72 | | 8192 | 128,256,64,3 | 2567 | 2656 | 3533 | 2907 | 1.22x | 73 | | 16384 | 128,256,64,3 | 5249 | 5375 | 6982 | 5814 | 1.20x | 74 | 75 | 1 Config refers to `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, and `num_stages`. 76 | 77 | 2 For this problem size, using multicast all-gather would be a more suitable optimization. 78 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py38'] 4 | 5 | [tool.isort] 6 | profile = "black" 7 | atomic = true 8 | combine_as_imports = true 9 | include_trailing_comma = true 10 | indent = 4 11 | line_length = 88 12 | lines_after_imports = 2 13 | multi_line_output = 3 14 | -------------------------------------------------------------------------------- /symm_mem_all_reduce.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | 4 | import click 5 | import torch 6 | import torch.distributed as dist 7 | import torch.distributed._symmetric_memory as symm_mem 8 | 9 | from utils import benchmark_with_profiler 10 | 11 | 12 | def multimem_all_reduce(msg): 13 | torch.ops.symm_mem.multimem_all_reduce_( 14 | msg, 15 | "sum", 16 | dist.group.WORLD.group_name, 17 | ) 18 | 19 | 20 | def one_shot_all_reduce(msg): 21 | torch.ops.symm_mem.one_shot_all_reduce( 22 | msg, 23 | "sum", 24 | dist.group.WORLD.group_name, 25 | ) 26 | 27 | 28 | def two_shot_all_reduce(msg): 29 | torch.ops.symm_mem.two_shot_all_reduce_( 30 | msg, 31 | "sum", 32 | dist.group.WORLD.group_name, 33 | ) 34 | 35 | 36 | def triton_multimem_all_reduce(msg): 37 | from triton_multimem_all_reduce import multimem_all_reduce 38 | 39 | multimem_all_reduce(msg) 40 | 41 | 42 | def triton_one_shot_all_reduce(msg): 43 | from triton_one_shot_all_reduce import one_shot_all_reduce 44 | 45 | one_shot_all_reduce(msg) 46 | 47 | 48 | def get_impl(impl: str): 49 | if impl == "multimem_all_reduce": 50 | return multimem_all_reduce 51 | elif impl == "one_shot_all_reduce": 52 | return one_shot_all_reduce 53 | elif impl == "two_shot_all_reduce": 54 | return two_shot_all_reduce 55 | elif impl == "triton_multimem_all_reduce": 56 | return triton_multimem_all_reduce 57 | elif impl == "triton_one_shot_all_reduce": 58 | return triton_one_shot_all_reduce 59 | else: 60 | raise NotImplementedError(impl) 61 | 62 | 63 | def benchmark(device: torch.device, impl: str, msg_sz_bytes: int): 64 | msg_numel = msg_sz_bytes // torch.bfloat16.itemsize 65 | msg = symm_mem.empty( 66 | msg_numel, 67 | dtype=torch.bfloat16, 68 | device=device, 69 | ) 70 | symm_mem.rendezvous(msg, dist.group.WORLD.group_name) 71 | 72 | target_fn = functools.partial(get_impl(impl), msg) 73 | baseline_fn = functools.partial(dist.all_reduce, msg) 74 | 75 | target_us = benchmark_with_profiler( 76 | target_fn, ".*all_reduce.*", benchmark_iters=200 77 | ) 78 | baseline_us = benchmark_with_profiler( 79 | baseline_fn, ".*AllReduce.*", benchmark_iters=200 80 | ) 81 | if dist.get_rank() == 0: 82 | print( 83 | f"msg_sz_bytes: {msg_sz_bytes}\t" 84 | f"nccl_ring: {baseline_us:.2f} us\t" 85 | f"{impl}: {target_us:.2f} us\t" 86 | ) 87 | 88 | 89 | @click.command() 90 | @click.option( 91 | "--impl", 92 | help="Valid options: multimem_all_reduce, 
one_shot_all_reduce, two_shot_all_reduce, triton_multimem_all_reduce, triton_one_shot_all_reduce", 93 | default="multimem_all_reduce", 94 | ) 95 | def main(impl: str): 96 | """ 97 | Benchmark for the symmetric memory-based all-reduce variants. 98 | NVSwitch is required. 99 | 100 | torchrun \ 101 | --nnodes 1 --nproc-per-node 8 \ 102 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 103 | --no_python python3 symm_mem_all_reduce.py 104 | """ 105 | local_rank = int(os.environ["LOCAL_RANK"]) 106 | 107 | device = torch.device(f"cuda:{local_rank}") 108 | torch.cuda.set_device(device) 109 | dist.init_process_group("nccl") 110 | torch.manual_seed(42 + local_rank) 111 | 112 | if dist.get_rank() == 0: 113 | print(f"Benchmarking {impl}...") 114 | 115 | msg_sizes = [2**exp for exp in range(12, 21)] 116 | for msg_sz_bytes in msg_sizes: 117 | benchmark(device, impl, msg_sz_bytes) 118 | 119 | dist.destroy_process_group() 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /torch_compile_all_gather_matmul.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from typing import List 4 | 5 | import click 6 | import torch 7 | import torch.distributed as dist 8 | import torch.distributed._symmetric_memory as symm_mem 9 | from torch.distributed._functional_collectives import all_gather_tensor 10 | 11 | from utils import benchmark_with_event 12 | 13 | 14 | def parse_csv(ctx, param, value): 15 | return [int(num) for num in value.split(",")] 16 | 17 | 18 | def all_gather_matmul(a_shard, bs, gather_dim, group_name): 19 | a = all_gather_tensor(a_shard.contiguous(), gather_dim=gather_dim, group=group_name) 20 | return [torch.matmul(a, b) for b in bs] 21 | 22 | 23 | compiled_all_gather_matmul = torch.compile( 24 | options={ 25 | "_micro_pipeline_tp": True, 26 | "keep_output_stride": False, 27 | }, 28 | fullgraph=True, 29 | )(all_gather_matmul) 30 | 31 | 32 | def scaled_matmul(a, b, a_scale, b_scale, **kwargs): 33 | leading_dims = a.shape[:-1] 34 | c = torch._scaled_mm(a.flatten(0, -2), b, a_scale, b_scale, **kwargs) 35 | return c.view(*leading_dims, -1) 36 | 37 | 38 | def all_gather_scaled_matmul(a_shard, bs, a_scale, b_scales, gather_dim, group_name): 39 | a = all_gather_tensor(a_shard.contiguous(), gather_dim=gather_dim, group=group_name) 40 | return [ 41 | scaled_matmul( 42 | a, b, a_scale, b_scale, out_dtype=torch.bfloat16, use_fast_accum=True 43 | ) 44 | for b, b_scale in zip(bs, b_scales) 45 | ] 46 | 47 | 48 | compiled_all_gather_scaled_matmul = torch.compile( 49 | options={ 50 | "_micro_pipeline_tp": True, 51 | "keep_output_stride": False, 52 | }, 53 | fullgraph=True, 54 | )(all_gather_scaled_matmul) 55 | 56 | 57 | @click.command() 58 | @click.option("--batch", default=1) 59 | @click.option("--M", default=8192) 60 | @click.option("--N", callback=parse_csv, default="3584") 61 | @click.option("--K", default=8192) 62 | @click.option("--dtype", default="bfloat16") 63 | @click.option("--gather-dim", default=0) 64 | @click.option("--scale-mode", default="tensor-wise") 65 | @click.option("--cuda-graph", is_flag=True, default=False) 66 | def main( 67 | batch: int, 68 | m: int, 69 | n: int, 70 | k: List[int], 71 | dtype: str, 72 | gather_dim: int, 73 | scale_mode: str, 74 | cuda_graph: bool, 75 | ): 76 | """ 77 | torchrun \ 78 | --nnodes 1 --nproc-per-node 8 \ 79 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 80 | --no_python python3 
torch_compile_all_gather_matmul.py --cuda-graph 81 | """ 82 | os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHE"] = "1" 83 | os.environ["TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP"] = "1" 84 | 85 | rank = int(os.environ["RANK"]) 86 | local_rank = int(os.environ["LOCAL_RANK"]) 87 | world_size = int(os.environ["WORLD_SIZE"]) 88 | 89 | if rank == 0: 90 | print(f"M={m}, N={n}, K={k}") 91 | 92 | device = torch.device(f"cuda:{local_rank}") 93 | torch.cuda.set_device(device) 94 | torch.manual_seed(42 + rank) 95 | 96 | dist.init_process_group("nccl") 97 | group_name = dist.group.WORLD.group_name 98 | symm_mem.enable_symm_mem_for_group(group_name) 99 | 100 | a_shard = torch.rand(batch, m // world_size, k, dtype=torch.bfloat16, device="cuda") 101 | bs = [torch.rand(N, k, dtype=torch.bfloat16, device="cuda").T for N in n] 102 | 103 | if dtype == "bfloat16": 104 | baseline = functools.partial( 105 | all_gather_matmul, a_shard, bs, gather_dim=gather_dim, group_name=group_name 106 | ) 107 | compiled = functools.partial( 108 | compiled_all_gather_matmul, 109 | symm_mem.restride_A_shard_for_fused_all_gather_matmul( 110 | a_shard, gather_dim=gather_dim 111 | ), 112 | bs, 113 | gather_dim=gather_dim, 114 | group_name=group_name, 115 | ) 116 | 117 | elif dtype == "float8": 118 | a_shard = a_shard.to(torch.float8_e4m3fn) 119 | bs = [B.to(torch.float8_e4m3fn) for B in bs] 120 | 121 | if scale_mode == "tensor-wise": 122 | A_scale = torch.tensor(0.1, device="cuda") 123 | B_scales = [torch.tensor(0.1, device="cuda") for _ in n] 124 | elif scale_mode == "row-wise": 125 | A_scale = torch.full((batch, m // world_size, 1), 0.1, device="cuda") 126 | B_scales = [torch.full((1, N), 0.1, device="cuda") for N in n] 127 | else: 128 | raise AssertionError(f"Invalid scale_mode: {scale_mode}") 129 | 130 | baseline = functools.partial( 131 | all_gather_scaled_matmul, 132 | a_shard, 133 | bs, 134 | A_scale, 135 | B_scales, 136 | gather_dim=gather_dim, 137 | group_name=group_name, 138 | ) 139 | compiled = functools.partial( 140 | compiled_all_gather_scaled_matmul, 141 | symm_mem.restride_A_shard_for_fused_all_gather_matmul( 142 | a_shard, gather_dim=gather_dim 143 | ), 144 | bs, 145 | A_scale, 146 | B_scales, 147 | gather_dim=gather_dim, 148 | group_name=group_name, 149 | ) 150 | 151 | else: 152 | raise AssertionError(f"Invalid dtype: {dtype}") 153 | 154 | torch.testing.assert_close(baseline(), compiled()) 155 | baseline_us = benchmark_with_event(baseline, flush_l2=True, cuda_graph=cuda_graph) 156 | compiled_us = benchmark_with_event(compiled, flush_l2=True, cuda_graph=cuda_graph) 157 | print(f"baseline: {baseline_us:.2f} us; compiled: {compiled_us:.2f} us") 158 | 159 | dist.destroy_process_group() 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /triton_all_gather_matmul.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | import torch 5 | import torch.distributed as dist 6 | import torch.distributed._symmetric_memory as symm_mem 7 | import triton 8 | import triton.language as tl 9 | import triton.tools.experimental_descriptor 10 | 11 | from triton_barrier import get_flat_tid 12 | from utils import benchmark_with_event, log_triton_kernel 13 | 14 | 15 | def all_gather_with_progress( 16 | output: torch.Tensor, 17 | inp: torch.Tensor, 18 | progress: torch.Tensor, 19 | splits_per_rank: int, 20 | ): 21 | assert inp.is_contiguous() 22 | 23 | symm_mem_hdl = 
symm_mem.rendezvous(inp, group=dist.group.WORLD) 24 | assert symm_mem_hdl is not None 25 | 26 | rank = symm_mem_hdl.rank 27 | world_size = symm_mem_hdl.world_size 28 | 29 | assert inp.numel() % splits_per_rank == 0 30 | assert progress.numel() == world_size * splits_per_rank 31 | 32 | output_shape = list(inp.shape) 33 | output_shape[0] *= world_size 34 | assert list(output.shape) == output_shape, (list(output.shape), output_shape) 35 | 36 | chunks = output.chunk(world_size * splits_per_rank) 37 | 38 | for step in range(0, world_size): 39 | src_rank = (rank + step + 1) % world_size 40 | for split_id in range(splits_per_rank): 41 | src_buf = symm_mem_hdl.get_buffer( 42 | src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id 43 | ) 44 | chunks[src_rank * splits_per_rank + split_id].copy_(src_buf) 45 | # cuStreamWriteValue32 issues a system level fence before the write 46 | symm_mem_hdl.stream_write_value32( 47 | progress, 48 | offset=src_rank * splits_per_rank + split_id, 49 | val=1, 50 | ) 51 | symm_mem_hdl.barrier() 52 | 53 | 54 | def _matmul_launch_metadata(grid, kernel, args): 55 | ret = {} 56 | M, N, K = args["M"], args["N"], args["K"] 57 | ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" 58 | ret["flops8"] = 2.0 * M * N * K 59 | if "c_ptr" in args: 60 | bytes_per_elem = args["c_ptr"].element_size() 61 | else: 62 | bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 63 | ret["bytes"] = bytes_per_elem * (M * K + N * K) 64 | return ret 65 | 66 | 67 | @triton.jit 68 | def wait_signal(addr, flat_tid): 69 | if flat_tid == 0: 70 | tl.inline_asm_elementwise( 71 | """ 72 | { 73 | .reg .pred %p<1>; 74 | 75 | wait_block: 76 | ld.global.relaxed.gpu.u32 $0, [$1]; 77 | setp.eq.u32 %p0, $0, 1; 78 | @!%p0 bra wait_block; 79 | } 80 | """, 81 | "=r, l", 82 | [addr], 83 | dtype=tl.int32, 84 | is_pure=False, 85 | pack=1, 86 | ) 87 | 88 | tl.inline_asm_elementwise( 89 | "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1 90 | ) 91 | 92 | 93 | @triton.jit(launch_metadata=_matmul_launch_metadata) 94 | def matmul_kernel_tma_persistent( 95 | a_shard_desc_ptr, 96 | a_desc_ptr, 97 | b_desc_ptr, 98 | c_desc_ptr, 99 | progress_ptr, 100 | M, 101 | N, 102 | K, 103 | BLOCK_SIZE_M: tl.constexpr, 104 | BLOCK_SIZE_N: tl.constexpr, 105 | BLOCK_SIZE_K: tl.constexpr, 106 | GROUP_SIZE_M: tl.constexpr, 107 | COMM_BLOCK_SIZE_M: tl.constexpr, 108 | RANK: tl.constexpr, 109 | WORLD_SIZE: tl.constexpr, 110 | FP8_OUTPUT: tl.constexpr, 111 | NUM_SMS: tl.constexpr, 112 | ): 113 | """ 114 | Slightly modified from the sm90 tma persistent Triton tutorial. 
115 | """ 116 | flat_tid = get_flat_tid() 117 | 118 | dtype = tl.float8e4nv if FP8_OUTPUT else tl.bfloat16 119 | start_pid = tl.program_id(axis=0) 120 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 121 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 122 | k_tiles = tl.cdiv(K, BLOCK_SIZE_K) 123 | num_tiles = num_pid_m * num_pid_n 124 | 125 | tiles_per_SM = num_tiles // NUM_SMS 126 | if start_pid < num_tiles % NUM_SMS: 127 | tiles_per_SM += 1 128 | 129 | tile_id = start_pid - NUM_SMS 130 | ki = -1 131 | 132 | pid_m = 0 133 | pid_n = 0 134 | offs_am_src = 0 135 | offs_bn = 0 136 | a_ptr = a_desc_ptr 137 | 138 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 139 | 140 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 141 | 142 | for _ in range(0, k_tiles * tiles_per_SM): 143 | ki = tl.where(ki == k_tiles - 1, 0, ki + 1) 144 | if ki == 0: 145 | tile_id += NUM_SMS 146 | group_id = tile_id // num_pid_in_group 147 | first_pid_m = group_id * GROUP_SIZE_M 148 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 149 | pid_m = first_pid_m + (tile_id % group_size_m) 150 | pid_n = (tile_id % num_pid_in_group) // group_size_m 151 | 152 | NUM_COMM_BLOCKS = M // COMM_BLOCK_SIZE_M 153 | NUM_COMM_BLOCKS_PER_RANK = NUM_COMM_BLOCKS // WORLD_SIZE 154 | NUM_PID_M_PER_COMM_BLOCK = COMM_BLOCK_SIZE_M // BLOCK_SIZE_M 155 | 156 | # Pivot tile_id so that M tiles are processed in their ready order. 157 | # This pivot preserves the prior swizzling. 158 | pid_m = (pid_m + NUM_PID_M_PER_COMM_BLOCK * RANK) % num_pid_m 159 | 160 | comm_block_id = pid_m // NUM_PID_M_PER_COMM_BLOCK 161 | if comm_block_id // NUM_COMM_BLOCKS_PER_RANK == RANK: 162 | # Read from the local a_shard 163 | offs_am_src = (pid_m * BLOCK_SIZE_M) % COMM_BLOCK_SIZE_M 164 | a_ptr = a_shard_desc_ptr 165 | else: 166 | # Wait for and read from a_shard copied from remote ranks 167 | wait_signal((progress_ptr + comm_block_id).to(tl.uint64), flat_tid) 168 | offs_am_src = pid_m * BLOCK_SIZE_M 169 | a_ptr = a_desc_ptr 170 | 171 | offs_bn = pid_n * BLOCK_SIZE_N 172 | offs_k = ki * BLOCK_SIZE_K 173 | 174 | a = tl._experimental_descriptor_load( 175 | a_ptr, [offs_am_src, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype 176 | ) 177 | b = tl._experimental_descriptor_load( 178 | b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype 179 | ) 180 | accumulator = tl.dot(a, b.T, accumulator) 181 | 182 | if ki == k_tiles - 1: 183 | c = accumulator.to(dtype) 184 | 185 | tl._experimental_descriptor_store( 186 | c_desc_ptr, c, [pid_m * BLOCK_SIZE_M, offs_bn] 187 | ) 188 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 189 | 190 | 191 | _tma_desc_cache = {} 192 | 193 | 194 | def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): 195 | global _tma_desc_cache 196 | key = (ptr, dim1, dim0, block_dim1, block_dim0, element_size) 197 | if key in _tma_desc_cache: 198 | return _tma_desc_cache[key] 199 | desc = triton.tools.experimental_descriptor.create_2d_tma_descriptor( 200 | ptr, 201 | dim1, 202 | dim0, 203 | block_dim1, 204 | block_dim0, 205 | element_size, 206 | ) 207 | _tma_desc_cache[key] = desc 208 | return desc 209 | 210 | 211 | def all_gather_matmul_tma_persistent( 212 | a_shard, b, a_out, c_out, configs, mm_only: bool = False 213 | ): 214 | if mm_only: 215 | rank = 0 216 | world_size = int(os.environ.get("WORLD_SIZE", "8")) 217 | else: 218 | symm_mem_hdl = symm_mem.rendezvous(a_shard, group=dist.group.WORLD) 219 | assert symm_mem_hdl is not None, "a_shard must be allocated via SymmetricMemory" 220 | rank 
= symm_mem_hdl.rank 221 | world_size = symm_mem_hdl.world_size 222 | 223 | dtype = a_shard.dtype 224 | M = a_shard.shape[0] * world_size 225 | N = b.shape[0] 226 | K = a_shard.shape[1] 227 | 228 | assert b.shape[1] == K 229 | assert a_out.shape[0] == M 230 | assert a_out.shape[1] == K 231 | assert c_out.shape[0] == M 232 | assert c_out.shape[1] == N 233 | 234 | SPLITS_PER_RANK = 1 235 | COMM_BLOCK_SIZE_M = M // world_size // SPLITS_PER_RANK 236 | assert COMM_BLOCK_SIZE_M % (configs["BLOCK_SIZE_M"] * configs["GROUP_SIZE_M"]) == 0 237 | 238 | backend_stream = symm_mem._get_backend_stream(priority=-1) 239 | if mm_only: 240 | progress = torch.ones(world_size, dtype=torch.uint32, device="cuda") 241 | else: 242 | progress = torch.zeros(world_size, dtype=torch.uint32, device="cuda") 243 | symm_mem_hdl.barrier(0) 244 | backend_stream.wait_stream(torch.cuda.current_stream()) 245 | with torch.cuda.stream(backend_stream): 246 | all_gather_with_progress(a_out, a_shard, progress, SPLITS_PER_RANK) 247 | 248 | desc_a_shard = create_2d_tma_descriptor( 249 | a_shard.data_ptr(), 250 | a_shard.shape[0], 251 | K, 252 | configs["BLOCK_SIZE_M"], 253 | configs["BLOCK_SIZE_K"], 254 | a_shard.element_size(), 255 | ) 256 | desc_a = create_2d_tma_descriptor( 257 | a_out.data_ptr(), 258 | M, 259 | K, 260 | configs["BLOCK_SIZE_M"], 261 | configs["BLOCK_SIZE_K"], 262 | a_out.element_size(), 263 | ) 264 | desc_b = create_2d_tma_descriptor( 265 | b.data_ptr(), 266 | N, 267 | K, 268 | configs["BLOCK_SIZE_N"], 269 | configs["BLOCK_SIZE_K"], 270 | b.element_size(), 271 | ) 272 | desc_c = create_2d_tma_descriptor( 273 | c_out.data_ptr(), 274 | M, 275 | N, 276 | configs["BLOCK_SIZE_M"], 277 | configs["BLOCK_SIZE_N"], 278 | c_out.element_size(), 279 | ) 280 | NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count 281 | 282 | grid = lambda META: ( 283 | min( 284 | NUM_SMS, 285 | triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), 286 | ), 287 | ) 288 | kernel = matmul_kernel_tma_persistent[grid]( 289 | desc_a_shard, 290 | desc_a, 291 | desc_b, 292 | desc_c, 293 | progress, 294 | M, 295 | N, 296 | K, 297 | BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], 298 | BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], 299 | BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], 300 | GROUP_SIZE_M=configs["GROUP_SIZE_M"], 301 | COMM_BLOCK_SIZE_M=COMM_BLOCK_SIZE_M, 302 | RANK=rank, 303 | WORLD_SIZE=world_size, 304 | FP8_OUTPUT=dtype == torch.float8_e4m3fn, 305 | NUM_SMS=NUM_SMS, 306 | num_stages=configs["num_stages"], 307 | num_warps=configs["num_warps"], 308 | ) 309 | log_triton_kernel(kernel) 310 | torch.cuda.current_stream().wait_stream(backend_stream) 311 | return c_out 312 | 313 | 314 | def all_gather_matmul(a_shard, b): 315 | from torch.distributed._functional_collectives import all_gather_tensor 316 | 317 | a = all_gather_tensor(a_shard, 0, "0") 318 | return torch.matmul(a, b) 319 | 320 | 321 | @click.command() 322 | @click.option("--M", default=4096) 323 | @click.option("--N", default=6656) 324 | @click.option("--K", default=16384) 325 | @click.option("--BLOCK_SIZE_M", default=128) 326 | @click.option("--BLOCK_SIZE_N", default=256) 327 | @click.option("--BLOCK_SIZE_K", default=64) 328 | @click.option("--GROUP_SIZE_M", default=4) 329 | @click.option("--num_stages", default=3) 330 | @click.option("--num_warps", default=8) 331 | def main( 332 | m: int, 333 | n: int, 334 | k: int, 335 | block_size_m: int, 336 | block_size_n: int, 337 | block_size_k: int, 338 | group_size_m: int, 339 | num_stages: int, 340 | num_warps: int, 341 | ): 
342 | """ 343 | torchrun \ 344 | --nnodes 1 --nproc-per-node 8 \ 345 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 346 | --no_python python3 triton_all_gather_matmul.py 347 | """ 348 | rank = int(os.environ["RANK"]) 349 | local_rank = int(os.environ["LOCAL_RANK"]) 350 | world_size = int(os.environ["WORLD_SIZE"]) 351 | 352 | device = torch.device(f"cuda:{local_rank}") 353 | torch.cuda.set_device(device) 354 | torch.manual_seed(42 + rank) 355 | dist.init_process_group("nccl") 356 | 357 | a_shard = symm_mem.empty( 358 | m // world_size, k, dtype=torch.bfloat16, device=device 359 | ).normal_() 360 | a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) 361 | b = torch.randn((k, n), device="cuda", dtype=torch.bfloat16).T.contiguous() 362 | c = torch.randn((m, n), device="cuda", dtype=torch.bfloat16) 363 | 364 | # Autotuner does not work with TMA. Use manual config. 365 | configs = { 366 | "BLOCK_SIZE_M": block_size_m, 367 | "BLOCK_SIZE_N": block_size_n, 368 | "BLOCK_SIZE_K": block_size_k, 369 | "GROUP_SIZE_M": group_size_m, 370 | "num_stages": num_stages, 371 | "num_warps": num_warps, 372 | } 373 | 374 | c0 = all_gather_matmul(a_shard, b.T) 375 | c1 = all_gather_matmul_tma_persistent(a_shard, b, a, c, configs) 376 | assert torch.allclose(c0, c1) 377 | 378 | def rank_0_print(msg): 379 | if rank == 0: 380 | print(msg) 381 | 382 | lat_cublas_mm = benchmark_with_event( 383 | lambda: torch.matmul(a, b.T, out=c), flush_l2=True 384 | ) 385 | rank_0_print(f"cublas mm only:\t{round(lat_cublas_mm)} us") 386 | 387 | lat_triton_mm = benchmark_with_event( 388 | lambda: all_gather_matmul_tma_persistent( 389 | a_shard, b, a, c, configs, mm_only=True 390 | ), 391 | flush_l2=True, 392 | ) 393 | rank_0_print(f"triton mm only:\t{round(lat_triton_mm)} us") 394 | 395 | lat_cublas_nccl = benchmark_with_event( 396 | lambda: all_gather_matmul(a_shard, b.T), flush_l2=True 397 | ) 398 | rank_0_print(f"cublas + nccl:\t{round(lat_cublas_nccl)} us") 399 | 400 | lat_triton_fused = benchmark_with_event( 401 | lambda: all_gather_matmul_tma_persistent(a_shard, b, a, c, configs), 402 | flush_l2=True, 403 | ) 404 | rank_0_print(f"triton fused:\t{round(lat_triton_fused)} us") 405 | rank_0_print(f"speedup:\t{lat_cublas_nccl / lat_triton_fused:.02f}x") 406 | 407 | dist.destroy_process_group() 408 | 409 | 410 | if __name__ == "__main__": 411 | main() 412 | -------------------------------------------------------------------------------- /triton_barrier.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.distributed._symmetric_memory as symm_mem 6 | import triton 7 | import triton.language as tl 8 | 9 | from triton_utils import get_flat_bid, get_flat_tid, sync_threads 10 | from utils import log_triton_kernel 11 | 12 | 13 | @triton.jit 14 | def send_signal(addrs, sem: tl.constexpr): 15 | if sem == "relaxed": 16 | tl.inline_asm_elementwise( 17 | """ 18 | { 19 | .reg .u32 %tmp32_<1>; 20 | .reg .pred %p<1>; 21 | 22 | send_signal: 23 | atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1; 24 | setp.eq.u32 %p0, %tmp32_0, 0; 25 | @!%p0 bra send_signal; 26 | } 27 | """, 28 | "=r, l", 29 | [addrs], 30 | dtype=tl.int32, 31 | is_pure=False, 32 | pack=1, 33 | ) 34 | elif sem == "acq_rel": 35 | tl.inline_asm_elementwise( 36 | """ 37 | { 38 | .reg .u32 %tmp32_<1>; 39 | .reg .pred %p<1>; 40 | 41 | send_signal: 42 | atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1; 43 | setp.eq.u32 %p0, %tmp32_0, 0; 44 | @!%p0 bra 
send_signal; 45 | } 46 | """, 47 | "=r, l", 48 | [addrs], 49 | dtype=tl.int32, 50 | is_pure=False, 51 | pack=1, 52 | ) 53 | else: 54 | raise RuntimeError(f"Unrecognized sem: {sem}") 55 | 56 | 57 | @triton.jit 58 | def wait_signal(addrs, sem: tl.constexpr): 59 | if sem == "relaxed": 60 | tl.inline_asm_elementwise( 61 | """ 62 | { 63 | .reg .u32 %tmp32_<1>; 64 | .reg .pred %p<1>; 65 | 66 | wait_signal: 67 | atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0; 68 | setp.eq.u32 %p0, %tmp32_0, 1; 69 | @!%p0 bra wait_signal; 70 | } 71 | """, 72 | "=r, l", 73 | [addrs], 74 | dtype=tl.int32, 75 | is_pure=False, 76 | pack=1, 77 | ) 78 | elif sem == "acq_rel": 79 | tl.inline_asm_elementwise( 80 | """ 81 | { 82 | .reg .u32 %tmp32_<1>; 83 | .reg .pred %p<1>; 84 | 85 | wait_signal: 86 | atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0; 87 | setp.eq.u32 %p0, %tmp32_0, 1; 88 | @!%p0 bra wait_signal; 89 | } 90 | """, 91 | "=r, l", 92 | [addrs], 93 | dtype=tl.int32, 94 | is_pure=False, 95 | pack=1, 96 | ) 97 | else: 98 | raise RuntimeError(f"Unrecognized sem: {sem}") 99 | 100 | 101 | @triton.jit 102 | def blockwise_barrier( 103 | signal_pad_ptrs, 104 | block_id, 105 | rank: tl.constexpr, 106 | world_size: tl.constexpr, 107 | sem: tl.constexpr, 108 | ): 109 | """ 110 | Synchronizes blocks with matching block_id across participating devices. 111 | 112 | Note: the function itself is not a system level barrier/fence. It is a 113 | building block for expressing different synchronization patterns. 114 | 115 | Pattern 0: Ensures that all writes to symm_mem buffers from previous 116 | kernels across all devices are visible to the current kernel: 117 | 118 | blockwise_barrier(..., sem="relaxed") 119 | sync_threads() 120 | 121 | Pattern 1: Ensures that all writes to symm_mem buffers from the current 122 | block are visible to all remote blocks with matching blockIdx: 123 | 124 | sync_threads() 125 | blockwise_barrier(..., sem="acq_rel") 126 | sync_threads() 127 | 128 | Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe 129 | for writing by subsequent kernels across all devices. 130 | 131 | sync_threads() 132 | blockwise_barrier(..., sem="relaxed") 133 | 134 | CUDA graph friendliness: 135 | 136 | This barrier operates through atomic operations on a zero-filled signal 137 | pad, which resets to a zero-filled state after each successful 138 | synchronization. This design eliminates the need for incrementing a 139 | flag from host. 
140 | """ 141 | if block_id is None: 142 | block_id = get_flat_bid() 143 | flat_tid = get_flat_tid() 144 | 145 | remote_ranks = tl.arange(0, world_size) 146 | signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64)) 147 | remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to( 148 | tl.pointer_type(tl.uint32) 149 | ) 150 | send_addrs = remote_signal_pad_addrs + block_id * world_size + rank 151 | 152 | local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to( 153 | tl.pointer_type(tl.uint32) 154 | ) 155 | wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks 156 | 157 | if flat_tid < world_size: 158 | send_signal(send_addrs, sem) 159 | wait_signal(wait_addrs, sem) 160 | 161 | 162 | @triton.jit 163 | def barrier_test_kernel( 164 | signal_pad_ptrs, 165 | rank: tl.constexpr, 166 | world_size: tl.constexpr, 167 | ): 168 | blockwise_barrier(signal_pad_ptrs, None, rank, world_size, "relaxed") 169 | sync_threads() 170 | 171 | 172 | def barrier_test(t: torch.Tensor) -> None: 173 | symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) 174 | 175 | kernel = barrier_test_kernel[(32, 1, 1)]( 176 | symm_mem_hdl.signal_pad_ptrs_dev, 177 | rank=symm_mem_hdl.rank, 178 | world_size=symm_mem_hdl.world_size, 179 | ) 180 | log_triton_kernel(kernel) 181 | 182 | signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank) 183 | assert signal_pad.eq(0).all().item() 184 | 185 | 186 | if __name__ == "__main__": 187 | """ 188 | torchrun \ 189 | --nnodes 1 --nproc-per-node 8 \ 190 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 191 | --no_python python3 triton_barrier.py 192 | """ 193 | rank = int(os.environ["RANK"]) 194 | local_rank = int(os.environ["LOCAL_RANK"]) 195 | world_size = int(os.environ["WORLD_SIZE"]) 196 | 197 | device = torch.device(f"cuda:{local_rank}") 198 | torch.cuda.set_device(device) 199 | dist.init_process_group("nccl") 200 | 201 | t = symm_mem.empty(4096, device=device) 202 | barrier_test(t) 203 | -------------------------------------------------------------------------------- /triton_multimem_all_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch.distributed._symmetric_memory as symm_mem 4 | import triton 5 | import triton.language as tl 6 | 7 | from triton_barrier import blockwise_barrier 8 | from triton_utils import get_flat_tid, sync_threads 9 | from utils import log_triton_kernel 10 | 11 | 12 | @triton.jit 13 | def multimem_ld_reduce_128(multicast_ptrs, mask): 14 | return tl.inline_asm_elementwise( 15 | """ 16 | { 17 | .reg .pred %p0; 18 | setp.eq.s32 %p0, $5, 1; 19 | @!%p0 bra end; 20 | multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {$0, $1, $2, $3}, [$4]; 21 | end: 22 | } 23 | """, 24 | "=r,=r,=r,=r,l,r", 25 | args=[multicast_ptrs, mask.to(tl.int32)], 26 | dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32), 27 | is_pure=True, 28 | pack=1, 29 | ) 30 | 31 | 32 | @triton.jit 33 | def multimem_st_128(multicast_ptrs, x, y, z, w, mask): 34 | return tl.inline_asm_elementwise( 35 | """ 36 | { 37 | .reg .pred %p0; 38 | setp.eq.s32 %p0, $6, 1; 39 | @!%p0 bra end; 40 | multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5}; 41 | end: 42 | } 43 | """, 44 | "=r,l,r,r,r,r,r", 45 | args=[multicast_ptrs, x, y, z, w, mask.to(tl.int32)], 46 | dtype=(tl.uint32), 47 | is_pure=False, 48 | pack=1, 49 | ) 50 | 51 | 52 | @triton.jit 53 | def multimem_all_reduce_kernel( 54 | multicast_ptr, 55 | signal_pad_ptrs, 56 | numel, 57 | BLOCK_SIZE: 
tl.constexpr, 58 | NUMEL_PER_THREAD: tl.constexpr, 59 | RANK: tl.constexpr, 60 | WORLD_SIZE: tl.constexpr, 61 | ): 62 | blockwise_barrier(signal_pad_ptrs, None, RANK, WORLD_SIZE, sem="relaxed") 63 | sync_threads() 64 | 65 | pid = tl.program_id(axis=0) 66 | tid = get_flat_tid() 67 | 68 | # From this point on, we pretend each element is 128-bit 69 | numel = numel // NUMEL_PER_THREAD 70 | numel_per_rank = tl.cdiv(numel, WORLD_SIZE) 71 | block_start = pid * BLOCK_SIZE 72 | 73 | while block_start < numel_per_rank: 74 | offsets = block_start + tid 75 | mask = offsets < numel_per_rank 76 | 77 | # Each pointer points to a 128-bit bit pack 78 | ptrs = ( 79 | multicast_ptr.to(tl.pointer_type(tl.uint64)) 80 | + (RANK * numel_per_rank + offsets) * 2 81 | ) 82 | (x, y, z, w) = multimem_ld_reduce_128(ptrs, mask=mask) 83 | multimem_st_128(ptrs, x, y, z, w, mask=mask) 84 | 85 | block_start += tl.num_programs(axis=0) * BLOCK_SIZE 86 | 87 | sync_threads() 88 | blockwise_barrier(signal_pad_ptrs, None, RANK, WORLD_SIZE, sem="acq_rel") 89 | 90 | 91 | def multimem_all_reduce(tensor: torch.Tensor): 92 | WARP_SIZE = 32 93 | MAX_NUM_BLOCKS = 4 94 | MAX_BLOCK_SIZE = 1024 95 | BYTES_PER_THREAD = 16 96 | 97 | symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD) 98 | 99 | assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." 100 | numel_per_thread = BYTES_PER_THREAD // tensor.element_size() 101 | 102 | assert ( 103 | tensor.numel() % numel_per_thread == 0 104 | ), "The number of elements must be 128-bit aligned." 105 | 106 | num_threads = triton.cdiv( 107 | tensor.numel() // numel_per_thread, symm_mem_hdl.world_size 108 | ) 109 | if num_threads < MAX_BLOCK_SIZE: 110 | block_size = 1 111 | while block_size < num_threads: 112 | block_size *= 2 113 | num_warps = block_size // WARP_SIZE 114 | num_blocks = 1 115 | else: 116 | block_size = MAX_BLOCK_SIZE 117 | num_warps = MAX_BLOCK_SIZE // WARP_SIZE 118 | num_blocks = min( 119 | triton.cdiv(num_threads, MAX_BLOCK_SIZE), 120 | MAX_NUM_BLOCKS, 121 | ) 122 | 123 | kernel = multimem_all_reduce_kernel[(num_blocks, 1, 1)]( 124 | symm_mem_hdl.multicast_ptr, 125 | symm_mem_hdl.signal_pad_ptrs_dev, 126 | numel=tensor.numel(), 127 | BLOCK_SIZE=block_size, 128 | NUMEL_PER_THREAD=numel_per_thread, 129 | RANK=symm_mem_hdl.rank, 130 | WORLD_SIZE=symm_mem_hdl.world_size, 131 | num_warps=num_warps, 132 | ) 133 | log_triton_kernel(kernel) 134 | return tensor 135 | 136 | 137 | if __name__ == "__main__": 138 | """ 139 | torchrun \ 140 | --nnodes 1 --nproc-per-node 8 \ 141 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 142 | --no_python python3 triton_multimem_all_reduce.py 143 | """ 144 | from symm_mem_all_reduce import main 145 | 146 | main(["--impl", "triton_multimem_all_reduce"]) 147 | -------------------------------------------------------------------------------- /triton_one_shot_all_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch.distributed._symmetric_memory as symm_mem 4 | import triton 5 | import triton.language as tl 6 | 7 | from triton_barrier import blockwise_barrier 8 | from triton_utils import sync_threads 9 | from utils import log_triton_kernel 10 | 11 | 12 | @triton.jit 13 | def load_128(addrs, mask): 14 | return tl.inline_asm_elementwise( 15 | """ 16 | { 17 | .reg .pred %p0; 18 | setp.eq.s32 %p0, $3, 1; 19 | @%p0 ld.global.v2.u64 {$0, $1}, [$2]; 20 | } 21 | """, 22 | "=l,=l,l,r", 23 | args=[addrs, mask.to(tl.int32)], 24 | 
dtype=(tl.uint64, tl.uint64), 25 | is_pure=True, 26 | pack=1, 27 | ) 28 | 29 | 30 | @triton.jit 31 | def add_v8_bf16(a_hi, a_lo, b_hi, b_lo): 32 | return tl.inline_asm_elementwise( 33 | """ 34 | { 35 | .reg .v4 .b32 %acc, %tmp; 36 | mov.v4.b32 %acc, 0; 37 | mov.b64 {%acc.x, %acc.y}, $2; 38 | mov.b64 {%acc.z, %acc.w}, $3; 39 | mov.b64 {%tmp.x, %tmp.y}, $4; 40 | mov.b64 {%tmp.z, %tmp.w}, $5; 41 | add.bf16x2 %acc.x, %acc.x, %tmp.x; 42 | add.bf16x2 %acc.y, %acc.y, %tmp.y; 43 | add.bf16x2 %acc.z, %acc.z, %tmp.z; 44 | add.bf16x2 %acc.w, %acc.w, %tmp.w; 45 | mov.b64 $0, {%acc.x, %acc.y}; 46 | mov.b64 $1, {%acc.z, %acc.w}; 47 | } 48 | """, 49 | "=l,=l,l,l,l,l", 50 | args=[a_hi, a_lo, b_hi, b_lo], 51 | dtype=(tl.uint64, tl.uint64), 52 | is_pure=True, 53 | pack=1, 54 | ) 55 | 56 | 57 | @triton.jit 58 | def one_shot_all_reduce_kernel( 59 | buffer_ptrs, 60 | signal_pad_ptrs, 61 | output_ptr, 62 | numel: tl.constexpr, 63 | rank: tl.constexpr, 64 | world_size: tl.constexpr, 65 | BLOCK_SIZE: tl.constexpr, 66 | NUMEL_PER_THREAD: tl.constexpr, 67 | ): 68 | blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") 69 | sync_threads() 70 | 71 | pid = tl.program_id(axis=0) 72 | 73 | buffer_ptrs = buffer_ptrs.to(tl.pointer_type(tl.uint64)) 74 | output_ptr = output_ptr.to(tl.pointer_type(tl.uint64)) 75 | block_start = pid * BLOCK_SIZE 76 | 77 | while block_start < (numel // NUMEL_PER_THREAD): 78 | # Each thread processes 128 bits. Since Triton doesn't yet natively 79 | # support 128-bit dtypes, we achieve this by having each thread process 80 | # two 64-bit elements. 81 | offsets = (block_start + tl.arange(0, BLOCK_SIZE)) * 2 82 | mask = block_start + tl.arange(0, BLOCK_SIZE) < numel // NUMEL_PER_THREAD 83 | 84 | acc_hi = tl.zeros((BLOCK_SIZE,), tl.uint64) 85 | acc_lo = tl.zeros((BLOCK_SIZE,), tl.uint64) 86 | for i in range(world_size): 87 | buffer_ptr = tl.load(buffer_ptrs + i).to(tl.pointer_type(tl.uint64)) 88 | (hi, lo) = load_128(buffer_ptr + offsets, mask=mask) 89 | (acc_hi, acc_lo) = add_v8_bf16(acc_hi, acc_lo, hi, lo) 90 | 91 | tl.store(output_ptr + offsets + 0, acc_hi, mask=mask) 92 | tl.store(output_ptr + offsets + 1, acc_lo, mask=mask) 93 | block_start += tl.num_programs(axis=0) * BLOCK_SIZE 94 | 95 | sync_threads() 96 | blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") 97 | 98 | 99 | def one_shot_all_reduce(tensor: torch.Tensor): 100 | MAX_NUM_BLOCKS = 24 101 | NUM_WARPS = 16 102 | BLOCK_SIZE = NUM_WARPS * 32 103 | NUMEL_PER_THREAD = 8 104 | 105 | assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." 106 | assert ( 107 | tensor.numel() % NUMEL_PER_THREAD == 0 108 | ), "The number of elements must be 128-bit aligned." 
109 | num_blocks = min( 110 | triton.cdiv(triton.cdiv(tensor.numel(), NUMEL_PER_THREAD), BLOCK_SIZE), 111 | MAX_NUM_BLOCKS, 112 | ) 113 | 114 | symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD) 115 | output = torch.empty_like(tensor) 116 | 117 | kernel = one_shot_all_reduce_kernel[(num_blocks, 1, 1)]( 118 | symm_mem_hdl.buffer_ptrs_dev, 119 | symm_mem_hdl.signal_pad_ptrs_dev, 120 | output, 121 | numel=tensor.numel(), 122 | rank=symm_mem_hdl.rank, 123 | world_size=symm_mem_hdl.world_size, 124 | BLOCK_SIZE=BLOCK_SIZE, 125 | NUMEL_PER_THREAD=NUMEL_PER_THREAD, 126 | num_warps=NUM_WARPS, 127 | ) 128 | log_triton_kernel(kernel) 129 | return output 130 | 131 | 132 | if __name__ == "__main__": 133 | """ 134 | torchrun \ 135 | --nnodes 1 --nproc-per-node 8 \ 136 | --rdzv-backend c10d --rdzv-endpoint localhost:0 \ 137 | --no_python python3 triton_one_shot_all_reduce.py 138 | """ 139 | from symm_mem_all_reduce import main 140 | 141 | main(["--impl", "triton_one_shot_all_reduce"]) 142 | -------------------------------------------------------------------------------- /triton_utils.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | 5 | @triton.jit 6 | def get_tid(): 7 | return tl.inline_asm_elementwise( 8 | """ 9 | mov.u32 $0, %tid.x; 10 | mov.u32 $1, %tid.y; 11 | mov.u32 $2, %tid.z; 12 | """, 13 | "=r,=r,=r", 14 | [], 15 | dtype=(tl.uint32, tl.uint32, tl.uint32), 16 | is_pure=True, 17 | pack=1, 18 | ) 19 | 20 | 21 | @triton.jit 22 | def get_ntid(): 23 | return tl.inline_asm_elementwise( 24 | """ 25 | mov.u32 $0, %ntid.x; 26 | mov.u32 $1, %ntid.y; 27 | mov.u32 $2, %ntid.z; 28 | """, 29 | "=r,=r,=r", 30 | [], 31 | dtype=(tl.uint32, tl.uint32, tl.uint32), 32 | is_pure=True, 33 | pack=1, 34 | ) 35 | 36 | 37 | @triton.jit 38 | def get_flat_tid(): 39 | tid_x, tid_y, tid_z = get_tid() 40 | ntid_x, ntid_y, _ = get_ntid() 41 | return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x 42 | 43 | 44 | @triton.jit 45 | def get_flat_bid(): 46 | return ( 47 | tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0) 48 | + tl.program_id(1) * tl.num_programs(0) 49 | + tl.program_id(0) 50 | ) 51 | 52 | 53 | @triton.jit 54 | def sync_threads(): 55 | tl.inline_asm_elementwise( 56 | "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1 57 | ) 58 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from contextlib import nullcontext 4 | from typing import Callable, List, Optional 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | 10 | def benchmark_with_profiler( 11 | target_fn: Callable[[None], None], 12 | event_key_regex: str, 13 | warmup_iters: int = 200, 14 | benchmark_iters: int = 25, 15 | profile_ranks: Optional[List[int]] = None, 16 | flush_l2: bool = False, 17 | ) -> float: 18 | """ 19 | Benchmark the target function with PyTorch profiler. 20 | 21 | Args: 22 | target_fn: The target function to benchmark. 23 | event_key_regex: The regex pattern to identify the profiler event 24 | associated with the target function. 25 | profile_ranks: The ranks to profile. 26 | warmup_iters: The number of warmup iterations. 27 | benchmark_iters: The number of benchmark iterations. 28 | flush_l2: Whether to flush the L2 cache before each invocation of the 29 | target function. 30 | 31 | Returns: 32 | The measured median latency in microseconds. 
33 | """ 34 | if "BENCHMARK_ITERS" in os.environ: 35 | benchmark_iters = int(os.environ["BENCHMARK_ITERS"]) 36 | 37 | rank = dist.get_rank() if dist.is_initialized() else 0 38 | profile_ranks = profile_ranks or [0] 39 | 40 | if flush_l2: 41 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") 42 | 43 | if rank in profile_ranks: 44 | try: 45 | from trace_handler import trace_handler 46 | except ImportError: 47 | trace_handler = None 48 | 49 | if "NO_TRACE" in os.environ: 50 | trace_handler = None 51 | 52 | prof = torch.profiler.profile( 53 | activities=[ 54 | torch.profiler.ProfilerActivity.CPU, 55 | torch.profiler.ProfilerActivity.CUDA, 56 | ], 57 | on_trace_ready=trace_handler, 58 | ) 59 | else: 60 | prof = nullcontext() 61 | 62 | for _ in range(warmup_iters): 63 | target_fn() 64 | 65 | if dist.is_initialized(): 66 | dist.barrier(device_ids=[torch.cuda.current_device()]) 67 | torch.cuda.synchronize() 68 | 69 | with prof: 70 | torch.cuda._sleep(int(2e7)) 71 | for i in range(benchmark_iters): 72 | if flush_l2: 73 | cache.zero_() 74 | target_fn() 75 | torch.cuda.synchronize() 76 | 77 | if rank not in profile_ranks: 78 | return 0 79 | 80 | latencies_us = [] 81 | for event in prof.events(): 82 | if re.match(event_key_regex, event.key): 83 | latencies_us.append(event.device_time) 84 | 85 | if len(latencies_us) == 0: 86 | return 0 87 | 88 | return torch.tensor(latencies_us).median().item() 89 | 90 | 91 | def benchmark_with_event( 92 | target_fn: Callable[[None], None], 93 | warmup_iters: int = 200, 94 | benchmark_iters: int = 25, 95 | profile_ranks: Optional[List[int]] = None, 96 | flush_l2: bool = False, 97 | cuda_graph: bool = False, 98 | ) -> float: 99 | if cuda_graph: 100 | target_fn() 101 | g = torch.cuda.CUDAGraph() 102 | with torch.cuda.graph(g): 103 | target_fn() 104 | target_fn = lambda: g.replay() 105 | 106 | if "BENCHMARK_ITERS" in os.environ: 107 | benchmark_iters = int(os.environ["BENCHMARK_ITERS"]) 108 | 109 | rank = dist.get_rank() if dist.is_initialized() else 0 110 | profile_ranks = profile_ranks or [0] 111 | 112 | if flush_l2: 113 | cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") 114 | 115 | for _ in range(warmup_iters): 116 | target_fn() 117 | 118 | if dist.is_initialized(): 119 | dist.barrier(device_ids=[torch.cuda.current_device()]) 120 | torch.cuda.synchronize() 121 | 122 | begin_events = [ 123 | torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters) 124 | ] 125 | end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)] 126 | 127 | if rank in profile_ranks: 128 | try: 129 | from trace_handler import trace_handler 130 | except ImportError: 131 | trace_handler = None 132 | 133 | if "NO_TRACE" in os.environ: 134 | trace_handler = None 135 | 136 | prof = torch.profiler.profile( 137 | activities=[ 138 | torch.profiler.ProfilerActivity.CPU, 139 | torch.profiler.ProfilerActivity.CUDA, 140 | ], 141 | on_trace_ready=trace_handler, 142 | ) 143 | else: 144 | prof = nullcontext() 145 | 146 | with prof: 147 | torch.cuda._sleep(int(2e7)) 148 | for i in range(benchmark_iters): 149 | if flush_l2: 150 | cache.zero_() 151 | begin_events[i].record() 152 | target_fn() 153 | end_events[i].record() 154 | torch.cuda.synchronize() 155 | 156 | latencies = [b.elapsed_time(e) for b, e in zip(begin_events, end_events)] 157 | return torch.tensor(latencies).median().item() * 1000 158 | 159 | 160 | triton_kernels = {} 161 | 162 | 163 | def log_triton_kernel(kernel): 164 | import atexit 165 | import tempfile 166 | 167 | if 
dist.is_initialized() and dist.get_rank() != 0: 168 | return 169 | 170 | def on_exit(): 171 | print("PTX files:") 172 | for kernel in triton_kernels: 173 | f = tempfile.NamedTemporaryFile(dir="/tmp", delete=False) 174 | f.write(kernel.asm["ptx"].encode("utf-8")) 175 | print(f"+- {kernel.name}: {f.name}") 176 | 177 | if len(triton_kernels) == 0: 178 | atexit.register(on_exit) 179 | 180 | if kernel not in triton_kernels: 181 | triton_kernels[kernel] = None 182 | --------------------------------------------------------------------------------
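
A minimal usage sketch for the timing helper in `utils.py` (single process, no distributed group; the matmul and sizes are illustrative):

```python
import torch

from utils import benchmark_with_event

# Time a plain bf16 matmul with CUDA events, flushing the L2 cache between
# iterations, the same way the benchmark scripts above time their kernels.
a = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
lat_us = benchmark_with_event(lambda: torch.matmul(a, b), flush_l2=True)
print(f"matmul: {lat_us:.2f} us")
```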