├── yunchang ├── comm │ ├── __init__.py │ ├── extract_local.py │ └── all_to_all.py ├── ulysses │ ├── __init__.py │ └── attn_layer.py ├── __init__.py ├── hybrid │ ├── __init__.py │ ├── utils.py │ ├── async_attn_layer.py │ └── attn_layer.py ├── ring │ ├── __init__.py │ ├── triton_utils.py │ ├── utils.py │ ├── ring_pytorch_attn.py │ ├── ring_npu_flash_attn.py │ ├── ring_flashinfer_attn.py │ ├── ring_flash_attn.py │ ├── ring_flash_attn_varlen.py │ ├── zigzag_ring_flash_attn.py │ └── stripe_flash_attn.py ├── globals.py └── kernels │ └── __init__.py ├── media ├── gqa.png ├── loss.png ├── ring.png ├── usp.png ├── ulysses.png ├── usp_fa.png ├── yun_chang.jpg ├── long_ctx_h2.png ├── long_ctx_h8.png ├── pcie_machine.jpg └── benchmark_results.png ├── scripts ├── run_npu.sh ├── run_hybrid_npu.sh ├── run_dit.sh ├── run_gqa.sh └── run_qkvpack_compare.sh ├── .pre-commit-config.yaml ├── pyproject.toml ├── .github └── workflows │ └── python-publish.yml ├── docs └── install_amd.md ├── .gitignore ├── test ├── test_ulysses_attn_npu.py ├── test_ulysses_attn.py ├── test_utils.py ├── test_hybrid_qkvpacked_attn.py └── test_hybrid_attn_npu.py ├── benchmark ├── benchmark_longctx_qkvpacked.py └── benchmark_longctx.py ├── LICENSE.txt ├── README.md └── patches └── Megatron-DeepSpeed.patch /yunchang/comm/__init__.py: -------------------------------------------------------------------------------- 1 | from .all_to_all import * 2 | from .extract_local import * 3 | -------------------------------------------------------------------------------- /media/gqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/gqa.png -------------------------------------------------------------------------------- /media/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/loss.png -------------------------------------------------------------------------------- /media/ring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/ring.png -------------------------------------------------------------------------------- /media/usp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/usp.png -------------------------------------------------------------------------------- /media/ulysses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/ulysses.png -------------------------------------------------------------------------------- /media/usp_fa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/usp_fa.png -------------------------------------------------------------------------------- /media/yun_chang.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/yun_chang.jpg -------------------------------------------------------------------------------- /media/long_ctx_h2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/long_ctx_h2.png -------------------------------------------------------------------------------- /media/long_ctx_h8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/long_ctx_h8.png -------------------------------------------------------------------------------- /media/pcie_machine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/pcie_machine.jpg -------------------------------------------------------------------------------- /yunchang/ulysses/__init__.py: -------------------------------------------------------------------------------- 1 | from .attn_layer import UlyssesAttention 2 | 3 | __all__ = ['UlyssesAttention'] 4 | -------------------------------------------------------------------------------- /media/benchmark_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/benchmark_results.png -------------------------------------------------------------------------------- /scripts/run_npu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash 3 | export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 4 | 5 | export MASTER_ADDR=127.0.0.1 6 | export MASTER_PORT=21289 7 | export HCCL_IF_BASE_PORT=64199 8 | 9 | torchrun --master_port=29600 --nproc_per_node 4 test/test_ulysses_attn_npu.py 10 | -------------------------------------------------------------------------------- /yunchang/__init__.py: -------------------------------------------------------------------------------- 1 | from .hybrid import * 2 | from .ring import * 3 | from .ulysses import * 4 | from .globals import set_seq_parallel_pg 5 | from .comm.extract_local import ( 6 | stripe_extract_local, 7 | basic_extract_local, 8 | zigzag_extract_local, 9 | EXTRACT_FUNC_DICT, 10 | ) 11 | 12 | __version__ = "0.6.3.post1" 13 | -------------------------------------------------------------------------------- /scripts/run_hybrid_npu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash 3 | export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 4 | 5 | export MASTER_ADDR=127.0.0.1 6 | export MASTER_PORT=21289 7 | export HCCL_IF_BASE_PORT=64199 8 | 9 | torchrun --master_port=29600 --nproc_per_node 4 test/test_hybrid_attn_npu.py --use_bwd 10 | -------------------------------------------------------------------------------- /yunchang/hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .attn_layer import LongContextAttention, LongContextAttentionQKVPacked 2 | from .async_attn_layer import AsyncLongContextAttention 3 | 4 | from .utils import RING_IMPL_QKVPACKED_DICT 5 | 6 | __all__ = [ 7 | "LongContextAttention", 8 | "LongContextAttentionQKVPacked", 9 | "RING_IMPL_QKVPACKED_DICT", 10 | "AsyncLongContextAttention", 11 | ] 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Using this mirror lets us use mypyc-compiled black, 
which is about 2x faster 3 | - repo: https://github.com/psf/black-pre-commit-mirror 4 | rev: 24.2.0 5 | hooks: 6 | - id: black 7 | # It is recommended to specify the latest version of Python 8 | # supported by your project here, or alternatively use 9 | # pre-commit's default_language_version, see 10 | # https://pre-commit.com/#top_level-default_language_version 11 | language_version: python3.10 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "yunchang" 7 | version = "0.6.3.post1" 8 | authors = [ 9 | { name="Jiarui Fang", email="fangjiarui123@gmail.com" }, 10 | ] 11 | description = "a package for long context attention" 12 | readme = "README.md" 13 | requires-python = ">=3.7" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Operating System :: OS Independent", 18 | ] 19 | dependencies = [] 20 | 21 | [project.optional-dependencies] 22 | flash = ["flash-attn>=2.6.0"] 23 | 24 | [project.urls] 25 | "Homepage" = "https://github.com/feifeibear/long-context-attention" 26 | "Bug Tracker" = "https://github.com/feifeibear/long-context-attention/issues" 27 | 28 | [tool.setuptools] 29 | packages = {find = {where = ["."], exclude = ["tests*", "benchmark*", "media*", "docs*", "patches*"]}} 30 | -------------------------------------------------------------------------------- /yunchang/hybrid/utils.py: -------------------------------------------------------------------------------- 1 | from yunchang.ring import ( 2 | ring_flash_attn_func, 3 | ring_flash_attn_qkvpacked_func, 4 | zigzag_ring_flash_attn_func, 5 | zigzag_ring_flash_attn_qkvpacked_func, 6 | stripe_flash_attn_func, 7 | stripe_flash_attn_qkvpacked_func, 8 | ring_pytorch_attn_func, 9 | ring_flashinfer_attn_func, 10 | ring_flashinfer_attn_qkvpacked_func, 11 | ring_npu_flash_attn_func 12 | ) 13 | 14 | RING_IMPL_DICT = { 15 | "basic": ring_flash_attn_func, 16 | "zigzag": zigzag_ring_flash_attn_func, 17 | "strip": stripe_flash_attn_func, 18 | "basic_pytorch": ring_pytorch_attn_func, 19 | "basic_flashinfer": ring_flashinfer_attn_func, 20 | "basic_npu": ring_npu_flash_attn_func 21 | } 22 | 23 | RING_IMPL_QKVPACKED_DICT = { 24 | "basic": ring_flash_attn_qkvpacked_func, 25 | "zigzag": zigzag_ring_flash_attn_qkvpacked_func, 26 | "strip": stripe_flash_attn_qkvpacked_func, 27 | "basic_flashinfer": ring_flashinfer_attn_qkvpacked_func, 28 | } 29 | -------------------------------------------------------------------------------- /yunchang/ring/__init__.py: -------------------------------------------------------------------------------- 1 | from .ring_flash_attn import ( 2 | ring_flash_attn_func, 3 | ring_flash_attn_kvpacked_func, 4 | ring_flash_attn_qkvpacked_func, 5 | ) 6 | from .ring_flash_attn_varlen import ( 7 | ring_flash_attn_varlen_func, 8 | ring_flash_attn_varlen_kvpacked_func, 9 | ring_flash_attn_varlen_qkvpacked_func, 10 | ) 11 | from .zigzag_ring_flash_attn import ( 12 | zigzag_ring_flash_attn_func, 13 | zigzag_ring_flash_attn_kvpacked_func, 14 | zigzag_ring_flash_attn_qkvpacked_func, 15 | ) 16 | from .zigzag_ring_flash_attn_varlen import ( 17 | zigzag_ring_flash_attn_varlen_func, 18 | zigzag_ring_flash_attn_varlen_qkvpacked_func, 19 | zigzag_ring_flash_attn_varlen_qkvpacked_func, 20 | ) 21 | from 
.stripe_flash_attn import ( 22 | stripe_flash_attn_func, 23 | stripe_flash_attn_kvpacked_func, 24 | stripe_flash_attn_qkvpacked_func, 25 | ) 26 | 27 | from .ring_pytorch_attn import ( 28 | ring_pytorch_attn_func, 29 | ) 30 | 31 | from .ring_flashinfer_attn import ( 32 | ring_flashinfer_attn_func, 33 | ring_flashinfer_attn_kvpacked_func, 34 | ring_flashinfer_attn_qkvpacked_func, 35 | ) 36 | 37 | from .ring_npu_flash_attn import ( 38 | ring_npu_flash_attn_func, 39 | ) -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build setuptools wheel 33 | - name: Build package 34 | run: | 35 | python -m build 36 | - name: Publish package 37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /scripts/run_dit.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PWD:$PYTHONPATH# export NCCL_PXN_DISABLE=1 2 | # export NCCL_DEBUG=INFO 3 | # export NCCL_SOCKET_IFNAME=eth0 4 | # export NCCL_IB_GID_INDEX=3 5 | # export NCCL_IB_DISABLE=0 6 | # export NCCL_NET_GDR_LEVEL=2 7 | # export NCCL_IB_QPS_PER_CONNECTION=4 8 | # export NCCL_IB_TC=160 9 | # export NCCL_IB_TIMEOUT=22 10 | 11 | export CUDA_DEVICE_MAX_CONNECTIONS=1 12 | 13 | # nccl settings 14 | #export NCCL_DEBUG=INFO 15 | export NCCL_SOCKET_IFNAME=eth0 16 | export NCCL_IB_GID_INDEX=3 17 | export NCCL_IB_DISABLE=0 18 | export NCCL_NET_GDR_LEVEL=2 19 | export NCCL_IB_QPS_PER_CONNECTION=4 20 | export NCCL_IB_TC=160 21 | export NCCL_IB_TIMEOUT=22 22 | 23 | export GLOO_SOCKET_IFNAME=eth0 24 | 25 | 26 | # comment this line fwd+bwd 27 | # FWD_FLAG="--fwd_only" 28 | 29 | NHEADS=24 30 | SEQLEN=1024 31 | GROUP_NUM=1 32 | GPU_NUM=2 33 | HEAD_SIZE=128 34 | ULYSSES_DEGREE=8 35 | 36 | NRANK=${NRANK:-0} 37 | # RING_IMPL_TYPE="zigzag" 38 | 39 | # make sure NHEADS // GROUP_NUM % ULYSSES_DEGREE == 0 40 | for attn_type in "torch" "fa" "fa3"; do 41 | for ULYSSES_DEGREE in 2; do 42 | for RING_IMPL_TYPE in "basic"; do 43 | torchrun --nproc_per_node $GPU_NUM --node_rank $NRANK benchmark/benchmark_longctx.py \ 44 | --nheads $NHEADS --group_num $GROUP_NUM --batch_size 1 $FWD_FLAG --seq_len $SEQLEN --head_size $HEAD_SIZE \ 45 | --ulysses_degree $ULYSSES_DEGREE --ring_impl_type $RING_IMPL_TYPE --no_causal --attn_type $attn_type --use_ulysses 46 | done 47 | done 48 | done 49 | 
-------------------------------------------------------------------------------- /scripts/run_gqa.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PWD:$PYTHONPATH 2 | # export NCCL_PXN_DISABLE=1 3 | # export NCCL_DEBUG=INFO 4 | # export NCCL_SOCKET_IFNAME=eth0 5 | # export NCCL_IB_GID_INDEX=3 6 | # export NCCL_IB_DISABLE=0 7 | # export NCCL_NET_GDR_LEVEL=2 8 | # export NCCL_IB_QPS_PER_CONNECTION=4 9 | # export NCCL_IB_TC=160 10 | # export NCCL_IB_TIMEOUT=22 11 | 12 | export CUDA_DEVICE_MAX_CONNECTIONS=1 13 | 14 | # nccl settings 15 | #export NCCL_DEBUG=INFO 16 | export NCCL_SOCKET_IFNAME=eth0 17 | export NCCL_IB_GID_INDEX=3 18 | export NCCL_IB_DISABLE=0 19 | export NCCL_NET_GDR_LEVEL=2 20 | export NCCL_IB_QPS_PER_CONNECTION=4 21 | export NCCL_IB_TC=160 22 | export NCCL_IB_TIMEOUT=22 23 | 24 | export GLOO_SOCKET_IFNAME=eth0 25 | 26 | 27 | # comment this line fwd+bwd 28 | # FWD_FLAG="--fwd_only" 29 | 30 | NHEADS=64 31 | SEQLEN=131072 32 | GROUP_NUM=8 33 | GPU_NUM=8 34 | ULYSSES_DEGREE=1 35 | 36 | NRANK=${NRANK:-0} 37 | # RING_IMPL_TYPE="zigzag" 38 | 39 | # make sure NHEADS // GROUP_NUM % ULYSSES_DEGREE == 0 40 | for ULYSSES_DEGREE in 8 4 2 1; do 41 | for RING_IMPL_TYPE in "zigzag"; do 42 | torchrun --nproc_per_node $GPU_NUM --node_rank $NRANK benchmark/benchmark_longctx.py --nheads $NHEADS --group_num $GROUP_NUM --batch_size 1 $FWD_FLAG --seq_len $SEQLEN --ulysses_degree $ULYSSES_DEGREE --ring_impl_type $RING_IMPL_TYPE 43 | done 44 | done 45 | 46 | torchrun --nproc_per_node $GPU_NUM --node_rank $NRANK benchmark/benchmark_ring_func.py --nheads $NHEADS --group_num $GROUP_NUM --batch_size 1 $FWD_FLAG --seq_len $SEQLEN 47 | 48 | -------------------------------------------------------------------------------- /scripts/run_qkvpack_compare.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PWD:$PYTHONPATH 2 | 3 | export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" 4 | # export NCCL_PXN_DISABLE=1 5 | # export NCCL_DEBUG=INFO 6 | # export NCCL_SOCKET_IFNAME=eth0 7 | # export NCCL_IB_GID_INDEX=3 8 | # export NCCL_IB_DISABLE=0 9 | # export NCCL_NET_GDR_LEVEL=2 10 | # export NCCL_IB_QPS_PER_CONNECTION=4 11 | # export NCCL_IB_TC=160 12 | # export NCCL_IB_TIMEOUT=22 13 | # export NCCL_P2P=0 14 | 15 | # torchrun --nproc_per_node 8 test/test_hybrid_attn.py 16 | 17 | FWD_FLAG="--fwd_only" 18 | 19 | 20 | 21 | # SEQLEN=512 22 | # SEQLEN=1024 23 | # SEQLEN=4096 24 | # SEQLEN=512 25 | # SEQLEN=16384 26 | SEQLEN=32768 #128K 27 | 28 | NHEADS=32 29 | 30 | # HEAD_SIZE=128 31 | HEAD_SIZE=32 32 | GROUP_NUM=4 33 | BS=1 34 | 35 | GPU_NUM=8 36 | 37 | # USE_PROFILE="--use_profiler" 38 | 39 | # NHEADS // GROUP_NUM > ulysses_degree 40 | 41 | for RING_IMPL_TYPE in "basic" "zigzag" "strip"; do 42 | for ULYSSES_DEGREE in 8 4 2 1; do 43 | 44 | torchrun --nproc_per_node $GPU_NUM benchmark/benchmark_longctx_qkvpacked.py \ 45 | --nheads $NHEADS \ 46 | --batch_size $BS \ 47 | --seq_len $SEQLEN \ 48 | --head_size $HEAD_SIZE \ 49 | --ulysses_degree $ULYSSES_DEGREE \ 50 | --ring_impl_type $RING_IMPL_TYPE \ 51 | $FWD_FLAG 52 | 53 | torchrun --nproc_per_node $GPU_NUM benchmark/benchmark_longctx.py \ 54 | --nheads $NHEADS \ 55 | --group_num $GROUP_NUM \ 56 | --batch_size $BS \ 57 | --seq_len $SEQLEN \ 58 | --head_size $HEAD_SIZE \ 59 | --ulysses_degree $ULYSSES_DEGREE \ 60 | --ring_impl_type $RING_IMPL_TYPE \ 61 | $FWD_FLAG 62 | 63 | done 64 | done 65 | 66 | 67 | 
-------------------------------------------------------------------------------- /yunchang/comm/extract_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from yunchang.globals import PROCESS_GROUP 5 | 6 | 7 | def stripe_extract_local(value, rank, world_size, rd, ud, *args, **kwargs): 8 | # ud at the highest dim 9 | input_dim = value.dim() 10 | assert input_dim >= 2 11 | 12 | batch_size, seqlen, *rest = value.shape 13 | 14 | assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd 15 | assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud 16 | 17 | value = value.reshape(batch_size, seqlen // rd, rd, -1).contiguous() 18 | value = value.transpose(1, 2).reshape(batch_size, seqlen, -1).contiguous() 19 | value = value.chunk(world_size, dim=1)[rank] 20 | 21 | new_shape = [batch_size, seqlen // world_size] + rest 22 | return value.reshape(new_shape) 23 | 24 | 25 | def basic_extract_local(value, rank, world_size, *args, **kwargs): 26 | return value.chunk(world_size, dim=1)[rank].detach().clone() 27 | 28 | 29 | def zigzag_extract_local(value, rank, world_size, rd, ud, dim=1, *args, **kwargs): 30 | """ 31 | value is a tensor of shape (bs, seqlen, ...) 32 | """ 33 | input_dim = value.dim() 34 | assert input_dim >= 2 35 | batch_size, seqlen, *rest = value.shape 36 | 37 | value_chunks = value.chunk(2 * rd, dim=dim) 38 | r_rank = dist.get_rank(group=PROCESS_GROUP.RING_PG) 39 | u_rank = dist.get_rank(group=PROCESS_GROUP.ULYSSES_PG) 40 | 41 | assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd 42 | assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud 43 | 44 | local_value = torch.cat( 45 | [value_chunks[r_rank], value_chunks[2 * rd - r_rank - 1]], dim=dim 46 | ).chunk(ud, dim=dim)[u_rank] 47 | 48 | new_shape = [batch_size, seqlen // world_size] + rest 49 | return local_value.reshape(new_shape).contiguous() 50 | 51 | 52 | 53 | EXTRACT_FUNC_DICT = { 54 | "basic": basic_extract_local, 55 | "strip": stripe_extract_local, 56 | "zigzag": zigzag_extract_local, 57 | "basic_pytorch": basic_extract_local, 58 | "basic_flashinfer": basic_extract_local, 59 | "basic_npu": basic_extract_local, 60 | } 61 | -------------------------------------------------------------------------------- /docs/install_amd.md: -------------------------------------------------------------------------------- 1 | ## Install for AMD GPU 2 | 3 | Supported GPU: MI300X, MI308X 4 | 5 | GPU arch: gfx942 6 | 7 | Step 1: prepare the docker environment 8 | 9 | Two recommended docker containers to start with: 10 | 11 | - rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 : hosted on dockerhub, no conda 12 | - [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04) : Customized Dockerfile with conda virtual env and development kit support 13 | 14 | An example of creating a docker container: 15 | 16 | ```bash 17 | # create docker container 18 | IMG=rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 19 | tag=py310-rocm6.2-distattn-dev 20 | 21 | docker_args=$(echo -it --privileged \ 22 | --name $tag \ 23 | --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \ 24 | --device=/dev/kfd --device=/dev/dri \ 25 | --ipc=host \ 26 | --security-opt seccomp=unconfined \ 27 | --shm-size 16G \ 28 | --group-add video \ 29 | -v $(readlink -f `pwd`):/workspace \ 30 | --workdir /workspace \ 31 | --cpus=$((`nproc` / 2 - 1)) \ 32 | $IMG 33 | ) 34 | 35 | docker_args=($docker_args) 36 | 37 | docker container create "${docker_args[@]}" 38 | 39 | # start it 40 | docker start -a -i $tag 41 | ``` 42 | 43 | Update the ROCm SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh): 44 | 45 | ```bash 46 | # e.g.: 47 | ROCM_VERSION=6.3 bash rocm/update_sdk.sh 48 | ``` 49 | 50 | Step 2: build from source locally. 51 | 52 | Install flash_attn from source: 53 | 54 | ```bash 55 | pip install flash_attn@git+https://git@github.com/Dao-AILab/flash-attention.git 56 | ``` 57 | 58 | Then install yunchang: 59 | 60 | > MAX_JOBS=$(nproc) pip install . --verbose 61 | 62 | **Features:** 63 | 64 | 1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms. 65 | 66 | 2. Covers the Capabilities of Both Ulysses and Ring: By setting the ulysses_degree to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the ulysses_degree to 1 mirrors the functionality of Ring. 67 | 68 | 3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations. 69 | 70 | 4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing. 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ -------------------------------------------------------------------------------- /yunchang/ring/triton_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def flatten_kernel( 8 | # pointers to matrices 9 | OUT, 10 | LSE, 11 | CU_SEQLENS, 12 | # strides 13 | stride_out_nheads, 14 | stride_out_seqlen, 15 | stride_lse_batch, 16 | stride_lse_nheads, 17 | stride_lse_seqlen, 18 | # meta-parameters 19 | BLOCK_M: tl.constexpr, 20 | ): 21 | pid_m = tl.program_id(axis=0) 22 | pid_batch = tl.program_id(axis=1) 23 | pid_head = tl.program_id(axis=2) 24 | 25 | start_idx = tl.load(CU_SEQLENS + pid_batch) 26 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 27 | LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads 28 | OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen 29 | 30 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 31 | 32 | LSE = LSE + rm[:, None] * stride_lse_seqlen 33 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 34 | 35 | OUT = OUT + rm[:, None] * stride_out_seqlen 36 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 37 | 38 | 39 | def flatten_varlen_lse(lse, cu_seqlens): 40 | """ 41 | Arguments: 42 | lse: (batch_size, nheads, max_seqlen) 43 | cu_seqlens: (batch_size + 1,) 44 | Return: 45 | flatten_lse: (nheads, total_seqlen) 46 | """ 47 | total_seqlen = cu_seqlens[-1] 48 | batch_size, nheads, max_seqlen = lse.shape 49 | output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) 50 | 51 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) 52 | BLOCK_M = 4 53 | 54 | with torch.cuda.device(lse.device.index): 55 | flatten_kernel[grid]( 56 | output, 57 | lse, 58 | cu_seqlens, 59 | # strides 60 | output.stride(0), 61 | output.stride(1), 62 | lse.stride(0), 63 | lse.stride(1), 64 | lse.stride(2), 65 | BLOCK_M, 66 | ) 67 | return output 68 | 69 | 70 | @triton.jit 71 | def unflatten_kernel( 72 | # pointers to matrices 73 | OUT, 74 | LSE, 75 | CU_SEQLENS, 76 | # strides 77 | stride_out_batch, 78 | stride_out_nheads, 79 | stride_out_seqlen, 80 | stride_lse_seqlen, 81 | stride_lse_nheads, 82 | # meta-parameters 83 | BLOCK_M: tl.constexpr, 84 | ): 85 | pid_m = tl.program_id(axis=0) 86 | pid_batch = tl.program_id(axis=1) 87 | pid_head = tl.program_id(axis=2) 88 | 89 | start_idx = tl.load(CU_SEQLENS + pid_batch) 90 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 91 | LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen 92 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads 93 | 94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 95 | 96 | LSE = LSE + rm[:, None] * stride_lse_seqlen 97 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 98 | 99 | OUT = OUT + rm[:, None] * stride_out_seqlen 100 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 101 | 102 | 103 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 104 | """ 105 | Arguments: 106 | lse: (total_seqlen, nheads, 1) 107 | cu_seqlens: (batch_size + 1,) 108 | max_seqlen: int 109 | Return: 110 | unflatten_lse: (batch_size, nheads, max_seqlen) 111 | """ 112 | lse = lse.unsqueeze(dim=-1) 113 | batch_size = len(cu_seqlens) - 1 114 | nheads = lse.shape[1] 115 | output = torch.empty( 116 | (batch_size, nheads, max_seqlen), 117 | dtype=lse.dtype, 118 | device=lse.device, 119 | ) 120 | 121 | grid = lambda META: (triton.cdiv(max_seqlen, 
META["BLOCK_M"]), batch_size, nheads) 122 | BLOCK_M = 4 123 | 124 | with torch.cuda.device(lse.device.index): 125 | unflatten_kernel[grid]( 126 | output, 127 | lse, 128 | cu_seqlens, 129 | # strides 130 | output.stride(0), 131 | output.stride(1), 132 | output.stride(2), 133 | lse.stride(0), 134 | lse.stride(1), 135 | BLOCK_M, 136 | ) 137 | return output 138 | -------------------------------------------------------------------------------- /yunchang/ring/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | 7 | __all__ = ["update_out_and_lse", "RingComm"] 8 | 9 | @torch.jit.script 10 | def _update_out_and_lse( 11 | out: torch.Tensor, 12 | lse: torch.Tensor, 13 | block_out: torch.Tensor, 14 | block_lse: torch.Tensor, 15 | ) -> Tuple[torch.Tensor, torch.Tensor]: 16 | 17 | block_out = block_out.to(torch.float32) 18 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 19 | 20 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) 21 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out 22 | # For additional context and discussion, please refer to: 23 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 24 | out = out - F.sigmoid(block_lse - lse) * (out - block_out) 25 | lse = lse - F.logsigmoid(lse - block_lse) 26 | 27 | return out, lse 28 | 29 | 30 | def update_out_and_lse( 31 | out: Optional[torch.Tensor], 32 | lse: Optional[torch.Tensor], 33 | block_out: torch.Tensor, 34 | block_lse: torch.Tensor, 35 | slice_=None, 36 | ) -> Tuple[torch.Tensor, torch.Tensor]: 37 | if out is None: 38 | if slice_ is not None: 39 | raise RuntimeError("first update_out_and_lse should not pass slice_ args") 40 | out = block_out.to(torch.float32) 41 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 42 | elif slice_ is not None: 43 | slice_out, slice_lse = out[slice_], lse[slice_] 44 | slice_out, slice_lse = _update_out_and_lse( 45 | slice_out, slice_lse, block_out, block_lse 46 | ) 47 | out[slice_], lse[slice_] = slice_out, slice_lse 48 | else: 49 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse) 50 | return out, lse 51 | 52 | 53 | @torch.jit.script 54 | def flatten_varlen_lse(lse, cu_seqlens): 55 | new_lse = [] 56 | for i in range(len(cu_seqlens) - 1): 57 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 58 | new_lse.append(lse[i, :, : end - start]) 59 | return torch.cat(new_lse, dim=1) 60 | 61 | 62 | @torch.jit.script 63 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 64 | num_seq = len(cu_seqlens) - 1 65 | num_head = lse.shape[-2] 66 | new_lse = torch.empty( 67 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device 68 | ) 69 | for i in range(num_seq): 70 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 71 | new_lse[i, : end - start] = lse[start:end] 72 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() 73 | 74 | 75 | class RingComm: 76 | def __init__(self, process_group: dist.ProcessGroup): 77 | self._process_group = process_group 78 | self._ops = [] 79 | self.rank = dist.get_rank(self._process_group) 80 | self.world_size = dist.get_world_size(self._process_group) 81 | self._reqs = None 82 | 83 | self.send_rank = (self.rank + 1) % self.world_size 84 | self.recv_rank = (self.rank - 1) % self.world_size 85 | 86 | if process_group is not None: 87 | self.send_rank = dist.get_global_rank(self._process_group, 
self.send_rank) 88 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) 89 | 90 | def send_recv( 91 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None 92 | ) -> torch.Tensor: 93 | if recv_tensor is None: 94 | res = torch.empty_like(to_send) 95 | # print(f"send_recv: empty_like {to_send.shape}") 96 | else: 97 | res = recv_tensor 98 | 99 | send_op = dist.P2POp( 100 | dist.isend, to_send, self.send_rank, group=self._process_group 101 | ) 102 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) 103 | self._ops.append(send_op) 104 | self._ops.append(recv_op) 105 | return res 106 | 107 | def commit(self): 108 | if self._reqs is not None: 109 | raise RuntimeError("commit called twice") 110 | self._reqs = dist.batch_isend_irecv(self._ops) 111 | 112 | def wait(self): 113 | if self._reqs is None: 114 | raise RuntimeError("wait called before commit") 115 | for req in self._reqs: 116 | req.wait() 117 | self._reqs = None 118 | self._ops = [] -------------------------------------------------------------------------------- /test/test_ulysses_attn_npu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch 3 | import torch.distributed as dist 4 | from yunchang import UlyssesAttention 5 | 6 | try: 7 | import torch_npu 8 | except ImportError: 9 | from flash_attn import flash_attn_func 10 | from yunchang.kernels import AttnType 11 | 12 | def log(msg, a, rank0_only=False): 13 | world_size = dist.get_world_size() 14 | rank = dist.get_rank() 15 | if rank0_only: 16 | if rank == 0: 17 | print( 18 | f"{msg}: " 19 | f"max {a.abs().max().item()}, " 20 | f"mean {a.abs().mean().item()}", 21 | flush=True, 22 | ) 23 | return 24 | 25 | for i in range(world_size): 26 | if i == rank: 27 | if rank == 0: 28 | print(f"{msg}:") 29 | print( 30 | f"[{rank}] " 31 | f"max {a.abs().max().item()}, " 32 | f"mean {a.abs().mean().item()}", 33 | flush=True, 34 | ) 35 | dist.barrier() 36 | 37 | 38 | if __name__ == "__main__": 39 | dist.init_process_group("hccl") 40 | 41 | rank = dist.get_rank() 42 | world_size = dist.get_world_size() 43 | dtype = torch.bfloat16 44 | device = torch.device(f"npu:{rank}") 45 | 46 | batch_size = 2 47 | seqlen = 3816 48 | nheads = 8 49 | d = 128 50 | dropout_p = 0 51 | causal = True 52 | deterministic = False 53 | 54 | assert seqlen % world_size == 0 55 | assert d % 8 == 0 56 | # assert batch_size == 1 57 | 58 | q = torch.randn( 59 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 60 | ) 61 | k = torch.randn( 62 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 63 | ) 64 | v = torch.randn( 65 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 66 | ) 67 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 68 | 69 | dist.broadcast(q, src=0) 70 | dist.broadcast(k, src=0) 71 | dist.broadcast(v, src=0) 72 | dist.broadcast(dout, src=0) 73 | 74 | local_q = q.chunk(world_size, dim=1)[rank].detach().clone() 75 | local_q.requires_grad = True 76 | local_k = k.chunk(world_size, dim=1)[rank].detach().clone() 77 | local_k.requires_grad = True 78 | local_v = v.chunk(world_size, dim=1)[rank].detach().clone() 79 | local_v.requires_grad = True 80 | 81 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 82 | 83 | # prcess_group == sequence_process_group 84 | sp_pg = None #dist.new_group(ranks=[i for i in range(world_size)]) 85 | 86 | dist_attn = 
UlyssesAttention(sp_pg, attn_type=AttnType.NPU) 87 | 88 | if rank == 0: 89 | print("#" * 30) 90 | print("# ds-ulysses forward:") 91 | print("#" * 30) 92 | 93 | local_out = dist_attn( 94 | local_q, 95 | local_k, 96 | local_v, 97 | dropout_p=dropout_p, 98 | causal=causal, 99 | window_size=(-1, -1), 100 | softcap=0.0, 101 | alibi_slopes=None, 102 | deterministic=deterministic, 103 | return_attn_probs=True, 104 | ) 105 | 106 | if rank == 0: 107 | print("#" * 30) 108 | print("# ds-ulysses backward:") 109 | print("#" * 30) 110 | 111 | local_out.backward(local_dout) 112 | 113 | dist.barrier() 114 | 115 | if rank == 0: 116 | print("#" * 30) 117 | print("# local forward:") 118 | print("#" * 30) 119 | # reference, a local flash attn 120 | 121 | softmax_scale = q.shape[-1] ** (-0.5) 122 | out_ref = torch_npu.npu_fusion_attention_v2(q, k, v, 123 | head_num = q.shape[-2], 124 | input_layout = "BSND", 125 | scale = softmax_scale, 126 | pre_tokens=65535, 127 | next_tokens=65535)[0] 128 | 129 | if rank == 0: 130 | print("#" * 30) 131 | print("# local forward:") 132 | print("#" * 30) 133 | 134 | out_ref.backward(dout) 135 | 136 | dist.barrier() 137 | 138 | # check correctness 139 | 140 | local_out_ref = out_ref.chunk(world_size, dim=1)[rank] 141 | 142 | log("out", local_out, rank0_only=True) 143 | log("out diff", local_out_ref - local_out) 144 | -------------------------------------------------------------------------------- /yunchang/globals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | class Singleton: 6 | _instance = None 7 | 8 | def __new__(cls, *args, **kwargs): 9 | if not cls._instance: 10 | cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs) 11 | return cls._instance 12 | 13 | 14 | class ProcessGroupSingleton(Singleton): 15 | def __init__(self): 16 | self.ULYSSES_PG = None 17 | self.RING_PG = None 18 | 19 | 20 | PROCESS_GROUP = ProcessGroupSingleton() 21 | 22 | def set_seq_parallel_pg( 23 | sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True 24 | ): 25 | """ 26 | sp_ulysses_degree x sp_ring_degree = seq_parallel_degree 27 | (ulysses_degree, dp_degree) 28 | """ 29 | sp_degree = sp_ring_degree * sp_ulysses_degree 30 | dp_degree = world_size // sp_degree 31 | 32 | assert ( 33 | world_size % sp_degree == 0 34 | ), f"world_size {world_size} % sp_degree {sp_ulysses_degree} == 0" 35 | 36 | num_ulysses_pgs = sp_ring_degree # world_size // sp_ulysses_degree 37 | num_ring_pgs = sp_ulysses_degree # world_size // sp_ring_degree 38 | 39 | if use_ulysses_low: 40 | for dp_rank in range(dp_degree): 41 | offset = dp_rank * sp_degree 42 | for i in range(num_ulysses_pgs): 43 | ulysses_ranks = list( 44 | range( 45 | i * sp_ulysses_degree + offset, 46 | (i + 1) * sp_ulysses_degree + offset, 47 | ) 48 | ) 49 | group = torch.distributed.new_group(ulysses_ranks) 50 | if rank in ulysses_ranks: 51 | ulyssess_pg = group 52 | 53 | for i in range(num_ring_pgs): 54 | ring_ranks = list(range(i + offset, sp_degree + offset, num_ring_pgs)) 55 | group = torch.distributed.new_group(ring_ranks) 56 | if rank in ring_ranks: 57 | ring_pg = group 58 | 59 | else: 60 | for dp_rank in range(dp_degree): 61 | offset = dp_rank * sp_degree 62 | for i in range(num_ring_pgs): 63 | ring_ranks = list( 64 | range( 65 | i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset 66 | ) 67 | ) 68 | group = torch.distributed.new_group(ring_ranks) 69 | if rank in ring_ranks: 70 | ring_pg = group 71 | 72 | for i in 
range(num_ulysses_pgs): 73 | ulysses_ranks = list( 74 | range(i + offset, sp_degree + offset, num_ulysses_pgs) 75 | ) 76 | group = torch.distributed.new_group(ulysses_ranks) 77 | if rank in ulysses_ranks: 78 | ulyssess_pg = group 79 | 80 | PROCESS_GROUP.ULYSSES_PG = ulyssess_pg 81 | PROCESS_GROUP.RING_PG = ring_pg 82 | 83 | # test if flash_attn is available 84 | try: 85 | import flash_attn 86 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 87 | HAS_FLASH_ATTN = True 88 | except ImportError: 89 | HAS_FLASH_ATTN = False 90 | 91 | try: 92 | from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper 93 | from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward 94 | from flash_attn_interface import flash_attn_func as flash3_attn_func 95 | HAS_FLASH_ATTN_HOPPER = True 96 | except ImportError: 97 | HAS_FLASH_ATTN_HOPPER = False 98 | 99 | try: 100 | from flashinfer.prefill import single_prefill_with_kv_cache 101 | HAS_FLASHINFER = True 102 | def get_cuda_arch(): 103 | major, minor = torch.cuda.get_device_capability() 104 | return f"{major}.{minor}" 105 | 106 | cuda_arch = get_cuda_arch() 107 | os.environ['TORCH_CUDA_ARCH_LIST'] = cuda_arch 108 | print(f"Set TORCH_CUDA_ARCH_LIST to {cuda_arch}") 109 | except ImportError: 110 | HAS_FLASHINFER = False 111 | 112 | try: 113 | import aiter 114 | from aiter import flash_attn_func as flash_attn_func_aiter 115 | HAS_AITER = True 116 | except ImportError: 117 | HAS_AITER = False 118 | 119 | try: 120 | import sageattention 121 | HAS_SAGE_ATTENTION = True 122 | except ImportError: 123 | HAS_SAGE_ATTENTION = False 124 | 125 | try: 126 | import spas_sage_attn 127 | HAS_SPARSE_SAGE_ATTENTION = True 128 | except ImportError: 129 | HAS_SPARSE_SAGE_ATTENTION = False 130 | 131 | try: 132 | import torch_npu 133 | HAS_NPU = True 134 | except ImportError: 135 | HAS_NPU = False -------------------------------------------------------------------------------- /test/test_ulysses_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from yunchang import UlyssesAttention 4 | 5 | from flash_attn import flash_attn_func 6 | from yunchang.kernels import AttnType 7 | 8 | def log(msg, a, rank0_only=False): 9 | world_size = dist.get_world_size() 10 | rank = dist.get_rank() 11 | if rank0_only: 12 | if rank == 0: 13 | print( 14 | f"{msg}: " 15 | f"max {a.abs().max().item()}, " 16 | f"mean {a.abs().mean().item()}", 17 | flush=True, 18 | ) 19 | return 20 | 21 | for i in range(world_size): 22 | if i == rank: 23 | if rank == 0: 24 | print(f"{msg}:") 25 | print( 26 | f"[{rank}] " 27 | f"max {a.abs().max().item()}, " 28 | f"mean {a.abs().mean().item()}", 29 | flush=True, 30 | ) 31 | dist.barrier() 32 | 33 | 34 | if __name__ == "__main__": 35 | dist.init_process_group("nccl") 36 | 37 | rank = dist.get_rank() 38 | world_size = dist.get_world_size() 39 | dtype = torch.bfloat16 40 | device = torch.device(f"cuda:{rank}") 41 | 42 | batch_size = 2 43 | seqlen = 3816 44 | nheads = 8 45 | d = 128 46 | dropout_p = 0 47 | causal = True 48 | deterministic = False 49 | 50 | assert seqlen % world_size == 0 51 | assert d % 8 == 0 52 | # assert batch_size == 1 53 | 54 | q = torch.randn( 55 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 56 | ) 57 | k = torch.randn( 58 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 59 | ) 60 | v = torch.randn( 61 | 
batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 62 | ) 63 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 64 | 65 | dist.broadcast(q, src=0) 66 | dist.broadcast(k, src=0) 67 | dist.broadcast(v, src=0) 68 | dist.broadcast(dout, src=0) 69 | 70 | local_q = q.chunk(world_size, dim=1)[rank].detach().clone() 71 | local_q.requires_grad = True 72 | local_k = k.chunk(world_size, dim=1)[rank].detach().clone() 73 | local_k.requires_grad = True 74 | local_v = v.chunk(world_size, dim=1)[rank].detach().clone() 75 | local_v.requires_grad = True 76 | 77 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 78 | 79 | # prcess_group == sequence_process_group 80 | sp_pg = None #dist.new_group(ranks=[i for i in range(world_size)]) 81 | 82 | dist_attn = UlyssesAttention(sp_pg, attn_type=AttnType.FA) 83 | 84 | if rank == 0: 85 | print("#" * 30) 86 | print("# ds-ulysses forward:") 87 | print("#" * 30) 88 | 89 | local_out = dist_attn( 90 | local_q, 91 | local_k, 92 | local_v, 93 | dropout_p=dropout_p, 94 | causal=causal, 95 | window_size=(-1, -1), 96 | softcap=0.0, 97 | alibi_slopes=None, 98 | deterministic=deterministic, 99 | return_attn_probs=True, 100 | ) 101 | 102 | if rank == 0: 103 | print("#" * 30) 104 | print("# ds-ulysses backward:") 105 | print("#" * 30) 106 | 107 | local_out.backward(local_dout) 108 | 109 | dist.barrier() 110 | 111 | if rank == 0: 112 | print("#" * 30) 113 | print("# local forward:") 114 | print("#" * 30) 115 | # reference, a local flash attn 116 | out_ref, _, _ = flash_attn_func( 117 | q, 118 | k, 119 | v, 120 | dropout_p=dropout_p, 121 | causal=causal, 122 | window_size=(-1, -1), 123 | softcap=0.0, 124 | alibi_slopes=None, 125 | deterministic=deterministic, 126 | return_attn_probs=True, 127 | ) 128 | if rank == 0: 129 | print("#" * 30) 130 | print("# local forward:") 131 | print("#" * 30) 132 | 133 | out_ref.backward(dout) 134 | 135 | dist.barrier() 136 | 137 | # check correctness 138 | 139 | local_out_ref = out_ref.chunk(world_size, dim=1)[rank] 140 | 141 | log("out", local_out, rank0_only=True) 142 | log("out diff", local_out_ref - local_out) 143 | 144 | local_dq_ref = q.grad.chunk(world_size, dim=1)[rank] 145 | log("load_dq", local_q.grad) 146 | log("dq diff", local_dq_ref - local_q.grad) 147 | 148 | local_dk_ref = k.grad.chunk(world_size, dim=1)[rank] 149 | log("load_dk", local_k.grad) 150 | log("dk diff", local_dk_ref - local_k.grad) 151 | 152 | local_dv_ref = v.grad.chunk(world_size, dim=1)[rank] 153 | log("load_dk", local_v.grad) 154 | log("dv diff", local_dv_ref - local_v.grad) 155 | -------------------------------------------------------------------------------- /yunchang/ring/ring_pytorch_attn.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py 2 | # Copyright 2024 The HuggingFace Inc. team and Jiarui Fang. 
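# Ring attention overview (descriptive note for this file): each rank keeps its local Q shard
# while the K/V blocks rotate around the ring through RingComm's batched P2P send/recv; every
# received block produces a partial attention result that update_out_and_lse folds in with a
# numerically stable log-sum-exp correction, and causal masking is preserved by skipping
# blocks that originate from ranks ahead of the current one.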
3 | 4 | import math 5 | import torch 6 | import torch.nn.functional as F 7 | from typing import Any, Optional, Tuple 8 | from yunchang.kernels import select_flash_attn_impl, AttnType 9 | from .utils import RingComm, update_out_and_lse 10 | from yunchang.kernels.attention import pytorch_attn_forward, pytorch_attn_backward 11 | 12 | def ring_pytorch_attn_func( 13 | q, 14 | k, 15 | v, 16 | dropout_p=0.0, 17 | softmax_scale=None, 18 | causal=False, 19 | window_size=(-1, -1), 20 | softcap=0.0, 21 | alibi_slopes=None, 22 | deterministic=False, 23 | return_attn_probs=False, 24 | group=None, 25 | attn_type: AttnType = AttnType.FA, 26 | attn_processor=None, 27 | ): 28 | return RingAttentionFunc.apply(group, q, k, v, softmax_scale, causal) 29 | 30 | class RingAttentionFunc(torch.autograd.Function): 31 | 32 | @staticmethod 33 | def forward(ctx, group, q, k, v, sm_scale, is_causal): 34 | 35 | comm = RingComm(group) 36 | #TODO(fmom): add flex attention 37 | #TODO(fmom): add flash attention 38 | #TODO(fmom): Find a better to save these tensors without cloning 39 | k_og = k.clone() 40 | v_og = v.clone() 41 | out, lse = None, None 42 | next_k, next_v = None, None 43 | 44 | if sm_scale is None: 45 | sm_scale = q.shape[-1] ** -0.5 46 | 47 | for step in range(comm.world_size): 48 | if step + 1 != comm.world_size: 49 | next_k = comm.send_recv(k) 50 | next_v = comm.send_recv(v) 51 | comm.commit() 52 | 53 | if not is_causal or step <= comm.rank: 54 | block_out, block_lse = pytorch_attn_forward( 55 | q, k, v, softmax_scale = sm_scale, causal = is_causal and step == 0 56 | ) 57 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 58 | 59 | if step + 1 != comm.world_size: 60 | comm.wait() 61 | k = next_k 62 | v = next_v 63 | 64 | out = out.to(q.dtype) 65 | 66 | ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1)) 67 | ctx.sm_scale = sm_scale 68 | ctx.is_causal = is_causal 69 | ctx.group = group 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, dout, *args): 75 | 76 | 77 | q, k, v, out, softmax_lse = ctx.saved_tensors 78 | sm_scale = ctx.sm_scale 79 | is_causal = ctx.is_causal 80 | 81 | kv_comm = RingComm(ctx.group) 82 | d_kv_comm = RingComm(ctx.group) 83 | 84 | dq, dk, dv = None, None, None 85 | next_dk, next_dv = None, None 86 | 87 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 88 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 89 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 90 | 91 | next_dk, next_dv = None, None 92 | next_k, next_v = None, None 93 | 94 | for step in range(kv_comm.world_size): 95 | if step + 1 != kv_comm.world_size: 96 | next_k = kv_comm.send_recv(k) 97 | next_v = kv_comm.send_recv(v) 98 | kv_comm.commit() 99 | 100 | if step <= kv_comm.rank or not is_causal: 101 | bwd_causal = is_causal and step == 0 102 | 103 | block_dq_buffer, block_dk_buffer, block_dv_buffer = pytorch_attn_backward( 104 | dout, q, k, v, out, softmax_lse = softmax_lse, softmax_scale = sm_scale, causal = bwd_causal 105 | ) 106 | 107 | if dq is None: 108 | dq = block_dq_buffer.to(torch.float32) 109 | dk = block_dk_buffer.to(torch.float32) 110 | dv = block_dv_buffer.to(torch.float32) 111 | else: 112 | dq += block_dq_buffer 113 | d_kv_comm.wait() 114 | dk = block_dk_buffer + next_dk 115 | dv = block_dv_buffer + next_dv 116 | elif step != 0: 117 | d_kv_comm.wait() 118 | dk = next_dk 119 | dv = next_dv 120 | 121 | if step + 1 != kv_comm.world_size: 122 | kv_comm.wait() 123 | k = next_k 124 | v = next_v 125 | 126 | 
next_dk = d_kv_comm.send_recv(dk) 127 | next_dv = d_kv_comm.send_recv(dv) 128 | d_kv_comm.commit() 129 | 130 | d_kv_comm.wait() 131 | 132 | return dq, next_dk, next_dv, None, None 133 | -------------------------------------------------------------------------------- /yunchang/ulysses/attn_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation and Jiarui Fang 2 | # SPDX-License-Identifier: Apache-2.0 3 | # DeepSpeed Team & Jiarui Fang 4 | 5 | 6 | import torch 7 | 8 | from typing import Any 9 | from torch import Tensor 10 | from yunchang.kernels import AttnType, select_flash_attn_impl 11 | import torch.distributed as dist 12 | from yunchang.comm.all_to_all import SeqAllToAll4D 13 | 14 | 15 | class UlyssesAttention(torch.nn.Module): 16 | """Initialization. 17 | 18 | Arguments: 19 | local_attention (Module): local attention with q,k,v 20 | sequence_process_group (ProcessGroup): sequence parallel process group 21 | scatter_idx (int): scatter_idx for all2all comm 22 | gather_idx (int): gather_idx for all2all comm 23 | use_sync (bool): whether to synchronize after all-to-all. This flag can save cuda memory but will slow down the speed. 24 | attn_type (AttnType): attention type enum 25 | """ 26 | 27 | def __init__( 28 | self, 29 | sequence_process_group: dist.ProcessGroup = None, 30 | scatter_idx: int = 2, 31 | gather_idx: int = 1, 32 | use_sync: bool = False, 33 | attn_type : AttnType = AttnType.FA, 34 | ) -> None: 35 | 36 | super(UlyssesAttention, self).__init__() 37 | self.spg = sequence_process_group 38 | self.scatter_idx = scatter_idx 39 | self.gather_idx = gather_idx 40 | self.use_sync = use_sync 41 | self.attn_type = attn_type 42 | 43 | try: 44 | import torch_npu 45 | device = torch.device("npu") 46 | except: 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | gpu_name = torch.cuda.get_device_name(device) 49 | if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: 50 | self.attn_type = AttnType.TORCH 51 | self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd") 52 | 53 | def forward( 54 | self, 55 | query: Tensor, 56 | key: Tensor, 57 | value: Tensor, 58 | dropout_p=0.0, 59 | softmax_scale=None, 60 | causal=False, 61 | window_size=(-1, -1), 62 | softcap=0.0, 63 | alibi_slopes=None, 64 | deterministic=False, 65 | return_attn_probs=False, 66 | *args: Any 67 | ) -> Tensor: 68 | """forward 69 | 70 | Arguments: 71 | query (Tensor): query input to the layer 72 | key (Tensor): key input to the layer 73 | value (Tensor): value input to the layer 74 | args: other args 75 | 76 | Returns: 77 | * output (Tensor): context output 78 | """ 79 | # TODO Merge three alltoall calls into one 80 | # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
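# Ulysses pattern: the three all-to-alls below trade the sequence shard for a head shard,
# so each rank runs ordinary attention over the full sequence for its own subset of heads;
# the final all-to-all after the attention call restores the sequence-parallel layout.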
81 | # in shape : e.g., [s/p:h:] 82 | # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) 83 | 84 | # scatter 2, gather 1 85 | q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync) 86 | k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync) 87 | v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync) 88 | 89 | if softmax_scale is None: 90 | softmax_scale = q.shape[-1] ** -0.5 91 | 92 | if self.attn_type is AttnType.NPU: 93 | context_layer = self.attn_fn( 94 | q, 95 | k, 96 | v, 97 | head_num = q.shape[-2], 98 | input_layout = "BSND", 99 | scale = softmax_scale, 100 | pre_tokens=65535, 101 | next_tokens=65535, 102 | ) 103 | else: 104 | context_layer = self.attn_fn( 105 | q, 106 | k, 107 | v, 108 | dropout_p=dropout_p, 109 | softmax_scale = softmax_scale, 110 | causal=causal, 111 | window_size=window_size, 112 | softcap=softcap, 113 | alibi_slopes=alibi_slopes, 114 | deterministic=deterministic, 115 | return_attn_probs=return_attn_probs, 116 | ) 117 | 118 | if isinstance(context_layer, tuple): 119 | context_layer = context_layer[0] 120 | 121 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 122 | # scatter 1, gather 2 123 | output = SeqAllToAll4D.apply( 124 | self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync 125 | ) 126 | 127 | # out e.g., [s/p::h] 128 | return output 129 | 130 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, repeat 2 | import math 3 | 4 | import torch 5 | 6 | 7 | # adpated from flash-attention 8 | def construct_local_mask( 9 | seqlen_q, 10 | seqlen_k, 11 | window_size=(-1, -1), # -1 means infinite window size 12 | query_padding_mask=None, 13 | key_padding_mask=None, 14 | device=None, 15 | key_leftpad=None, 16 | ): 17 | row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") 18 | col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) 19 | if key_leftpad is not None: 20 | key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") 21 | col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) 22 | col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) 23 | sk = ( 24 | seqlen_k 25 | if key_padding_mask is None 26 | else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") 27 | ) 28 | sq = ( 29 | seqlen_q 30 | if query_padding_mask is None 31 | else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") 32 | ) 33 | if window_size[0] < 0: 34 | return col_idx > row_idx + sk - sq + window_size[1] 35 | else: 36 | sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk 37 | return torch.logical_or( 38 | col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), 39 | col_idx < row_idx + sk - sq - window_size[0], 40 | ) 41 | 42 | # adpated from flash-attention 43 | def attention_ref( 44 | q, 45 | k, 46 | v, 47 | query_padding_mask=None, 48 | key_padding_mask=None, 49 | attn_bias=None, 50 | dropout_p=0.0, 51 | dropout_mask=None, 52 | causal=False, 53 | window_size=(-1, -1), # -1 means infinite window size 54 | softcap=0.0, 55 | upcast=True, 56 | reorder_ops=False, 57 | key_leftpad=None, 58 | ): 59 | """ 60 | Arguments: 61 | q: (batch_size, seqlen_q, nheads, head_dim) 62 | k: (batch_size, seqlen_k, nheads_k, head_dim) 63 | v: (batch_size, 
seqlen_k, nheads_k, head_dim) 64 | query_padding_mask: (batch_size, seqlen_q) 65 | key_padding_mask: (batch_size, seqlen_k) 66 | attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) 67 | dropout_p: float 68 | dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) 69 | causal: whether to apply causal masking 70 | window_size: (int, int), left and right window size 71 | upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast 72 | output back to fp16/bf16. 73 | reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) 74 | without changing the math. This is to estimate the numerical error from operation 75 | reordering. 76 | Output: 77 | output: (batch_size, seqlen_q, nheads, head_dim) 78 | attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout 79 | """ 80 | if causal: 81 | window_size = (window_size[0], 0) 82 | dtype_og = q.dtype 83 | if upcast: 84 | q, k, v = q.float(), k.float(), v.float() 85 | seqlen_q, seqlen_k = q.shape[1], k.shape[1] 86 | k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) 87 | v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) 88 | d = q.shape[-1] 89 | if not reorder_ops: 90 | scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) 91 | else: 92 | scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) 93 | if softcap > 0: 94 | scores = scores / softcap 95 | scores = scores.tanh() 96 | scores = scores * softcap 97 | if key_padding_mask is not None: 98 | scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) 99 | if window_size[0] >= 0 or window_size[1] >= 0: 100 | local_mask = construct_local_mask( 101 | seqlen_q, 102 | seqlen_k, 103 | window_size, 104 | query_padding_mask, 105 | key_padding_mask, 106 | q.device, 107 | key_leftpad=key_leftpad, 108 | ) 109 | scores.masked_fill_(local_mask, float("-inf")) 110 | if attn_bias is not None: 111 | scores = scores + attn_bias 112 | attention = torch.softmax(scores, dim=-1).to(v.dtype) 113 | # Some rows might be completely masked out so we fill them with zero instead of NaN 114 | if window_size[0] >= 0 or window_size[1] >= 0: 115 | attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) 116 | # We want to mask here so that the attention matrix doesn't have any NaNs 117 | # Otherwise we'll get NaN in dV 118 | if query_padding_mask is not None: 119 | attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 120 | dropout_scaling = 1.0 / (1 - dropout_p) 121 | # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling 122 | # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) 123 | if dropout_mask is not None: 124 | attention_drop = attention.masked_fill(~dropout_mask, 0.0) 125 | else: 126 | attention_drop = attention 127 | output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) 128 | if query_padding_mask is not None: 129 | output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) 130 | return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) -------------------------------------------------------------------------------- /test/test_hybrid_qkvpacked_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from yunchang import ( 4 | LongContextAttentionQKVPacked, 5 | set_seq_parallel_pg, 6 | EXTRACT_FUNC_DICT, 7 | RING_IMPL_QKVPACKED_DICT 8 | 
) 9 | from yunchang.kernels import AttnType 10 | 11 | 12 | def log(msg, a, rank0_only=False): 13 | world_size = dist.get_world_size() 14 | rank = dist.get_rank() 15 | if rank0_only: 16 | if rank == 0: 17 | print( 18 | f"{msg}: " 19 | f"max {a.abs().max().item()}, " 20 | f"mean {a.abs().mean().item()}", 21 | flush=True, 22 | ) 23 | return 24 | 25 | for i in range(world_size): 26 | if i == rank: 27 | if rank == 0: 28 | print(f"{msg}:") 29 | print( 30 | f"[{rank}] " 31 | f"max {a.abs().max().item()}, " 32 | f"mean {a.abs().mean().item()}", 33 | flush=True, 34 | ) 35 | dist.barrier() 36 | 37 | import os 38 | 39 | def get_local_rank(): 40 | local_rank = int(os.getenv('LOCAL_RANK', '0')) 41 | return local_rank 42 | 43 | def test(ring_impl_type="zigzag"): 44 | 45 | rank = dist.get_rank() 46 | local_rank = get_local_rank() 47 | world_size = dist.get_world_size() 48 | dtype = torch.bfloat16 49 | device = torch.device(f"cuda:{local_rank}") 50 | print(f"rank {rank} local_rank {local_rank} world_size {world_size}") 51 | 52 | batch_size = 2 53 | seqlen = 1024 54 | nheads = 8 55 | d = 32 56 | dropout_p = 0.0 57 | causal = True 58 | deterministic = False 59 | 60 | assert seqlen % world_size == 0 61 | assert d % 8 == 0 62 | 63 | sp_ulysses_degree = 2 # min(world_size, nheads) 64 | sp_ring_degree = world_size // sp_ulysses_degree 65 | 66 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 67 | 68 | longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=ring_impl_type, 69 | attn_type=AttnType.FA) 70 | 71 | ## prepare input and output tensors 72 | 73 | # global tensors 74 | qkv = torch.randn( 75 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True 76 | ) 77 | 78 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 79 | 80 | with torch.no_grad(): 81 | dist.broadcast(qkv, src=0) 82 | dist.broadcast(dout, src=0) 83 | 84 | # sharded tensors for long context attn 85 | local_qkv = ( 86 | EXTRACT_FUNC_DICT[ring_impl_type]( 87 | qkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 88 | ) 89 | .detach() 90 | .clone() 91 | ) 92 | local_qkv.requires_grad = True 93 | 94 | local_dout = ( 95 | EXTRACT_FUNC_DICT[ring_impl_type]( 96 | dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 97 | ) 98 | .detach() 99 | .clone() 100 | ) 101 | # shared tensors for reference 102 | local_qkv_ref = local_qkv.detach().clone() 103 | local_qkv_ref.requires_grad = True 104 | 105 | dist.barrier() 106 | if rank == 0: 107 | print("#" * 30) 108 | print("# forward:") 109 | print("#" * 30) 110 | 111 | print(f"local_qkv shape {local_qkv.shape}") 112 | local_out = longctx_attn( 113 | local_qkv, 114 | dropout_p=dropout_p, 115 | causal=causal, 116 | window_size=(-1, -1), 117 | softcap=0.0, 118 | alibi_slopes=None, 119 | deterministic=deterministic, 120 | return_attn_probs=True, 121 | ) 122 | 123 | from flash_attn import flash_attn_qkvpacked_func 124 | # local_out = out.chunk(world_size, dim=1)[rank] 125 | # local_lse = lse.chunk(world_size, dim=-1)[rank] 126 | 127 | out, lse, _ = flash_attn_qkvpacked_func( 128 | qkv, 129 | dropout_p=dropout_p, 130 | causal=causal, 131 | window_size=(-1, -1), 132 | softcap=0.0, 133 | alibi_slopes=None, 134 | deterministic=deterministic, 135 | return_attn_probs=True, 136 | ) 137 | 138 | local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 139 | out, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 140 | ) 141 | 142 | log("out_ref", local_out_ref, rank0_only=True) 143 | 
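    # local_out from LongContextAttentionQKVPacked should match local_out_ref, the corresponding
    # shard of the single-device flash_attn_qkvpacked_func reference computed above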
log("out", local_out, rank0_only=True) 144 | 145 | # log("lse", lse, rank0_only=True) 146 | log("out diff", local_out - local_out_ref) 147 | # log("lse diff", local_lse - ring_lse) 148 | 149 | dist.barrier() 150 | 151 | # if rank == 0: 152 | # print(local_out_ref) 153 | # print(local_out) 154 | 155 | if rank == 0: 156 | print("#" * 30) 157 | print("# backward:") 158 | print("#" * 30) 159 | 160 | # long context attn backward 161 | local_out.backward(local_dout) 162 | local_dqkv = local_qkv.grad 163 | 164 | # local ring backward 165 | out.backward(dout) 166 | dqkv = qkv.grad 167 | 168 | local_dqkv_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 169 | dqkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 170 | ) 171 | 172 | log("load_dq", local_dqkv_ref) 173 | log("dq diff", local_dqkv - local_dqkv_ref) 174 | 175 | 176 | 177 | if __name__ == "__main__": 178 | dist.init_process_group("nccl") 179 | for ring_impl_type in ["basic", "zigzag"]: 180 | print(f"ring_impl_type: {ring_impl_type}") 181 | test(ring_impl_type) 182 | if dist.is_initialized(): 183 | dist.destroy_process_group() 184 | -------------------------------------------------------------------------------- /benchmark/benchmark_longctx_qkvpacked.py: -------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_varlen_qkvpacked_func 2 | import torch 3 | import torch.distributed as dist 4 | from yunchang import set_seq_parallel_pg, LongContextAttentionQKVPacked 5 | from yunchang.comm import EXTRACT_FUNC_DICT 6 | import torch.cuda 7 | 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser(description="args for benchmark.") 11 | 12 | parser.add_argument( 13 | "--ring_impl_type", 14 | type=str, 15 | default="basic", 16 | choices=["basic", "zigzag", "strip"], 17 | help="ring attn implementation type", 18 | ) 19 | parser.add_argument("--nheads", type=int, default=2, help="head number") 20 | parser.add_argument("--head_size", type=int, default=128, help="head number") 21 | parser.add_argument( 22 | "--seq_len", 23 | type=int, 24 | default=4 * 1024, 25 | help="local sequence length, the global sequence length is seq_len * world_size", 26 | ) 27 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 28 | parser.add_argument( 29 | "--fwd_only", action="store_true", help="benchmark forward pass only" 30 | ) 31 | parser.add_argument( 32 | "--use_ulysses_lowdim", 33 | action="store_true", 34 | default=True, 35 | help="ulysses process group on low dimension", 36 | ) 37 | parser.add_argument( 38 | "--ulysses_degree", 39 | type=int, 40 | default=1, 41 | help="ulysses attention sequence parallel degree", 42 | ) 43 | args = parser.parse_args() 44 | 45 | 46 | def color_print(text): 47 | print("\033[91m {}\033[00m".format(text)) 48 | 49 | import os 50 | def get_local_rank(): 51 | local_rank = int(os.getenv('LOCAL_RANK', '0')) 52 | return local_rank 53 | 54 | 55 | def benchmark(num_iter=100, forward_only=True, log=True): 56 | dtype = torch.bfloat16 57 | rank = dist.get_rank() 58 | local_rank = get_local_rank() 59 | world_size = dist.get_world_size() 60 | device = torch.device(f"cuda:{local_rank}") 61 | torch.cuda.set_device(device) 62 | 63 | batch_size = args.batch_size 64 | seqlen = args.seq_len 65 | nheads = args.nheads 66 | d = args.head_size 67 | 68 | dropout_p = 0 69 | causal = True 70 | deterministic = False 71 | 72 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 73 | assert d % 8 == 0 74 | 75 | qkv = torch.randn( 76 
| batch_size, 77 | seqlen * world_size, 78 | 3, 79 | nheads, 80 | d, 81 | device=device, 82 | dtype=dtype, 83 | requires_grad=True, 84 | ) 85 | dout = torch.randn( 86 | batch_size, seqlen * world_size, nheads, d, device=device, dtype=dtype 87 | ) 88 | 89 | sp_ulysses_degree = min(args.ulysses_degree, world_size) 90 | sp_ring_degree = world_size // sp_ulysses_degree 91 | 92 | set_seq_parallel_pg( 93 | sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim 94 | ) 95 | 96 | longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=args.ring_impl_type) 97 | 98 | # NOTE() using zigzag and stripe have a special layout. 99 | qkv = ( 100 | EXTRACT_FUNC_DICT[args.ring_impl_type]( 101 | qkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 102 | ) 103 | .detach() 104 | .clone() 105 | ) 106 | qkv.requires_grad = True 107 | dout = ( 108 | EXTRACT_FUNC_DICT[args.ring_impl_type]( 109 | dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 110 | ) 111 | .detach() 112 | .clone() 113 | ) 114 | 115 | out = longctx_attn( 116 | qkv, 117 | dropout_p=dropout_p, 118 | causal=causal, 119 | window_size=(-1, -1), 120 | softcap=0.0, 121 | alibi_slopes=None, 122 | deterministic=deterministic, 123 | return_attn_probs=False, 124 | ) 125 | out.backward(dout) 126 | 127 | begin = torch.cuda.Event(enable_timing=True) 128 | begin.record() 129 | 130 | if forward_only: 131 | with torch.no_grad(): 132 | for _ in range(num_iter): 133 | _ = longctx_attn( 134 | qkv, 135 | dropout_p=dropout_p, 136 | causal=causal, 137 | window_size=(-1, -1), 138 | softcap=0.0, 139 | alibi_slopes=None, 140 | deterministic=deterministic, 141 | return_attn_probs=False, 142 | ) 143 | 144 | else: 145 | for _ in range(num_iter): 146 | qkv.grad = None 147 | out = longctx_attn( 148 | qkv, 149 | dropout_p=dropout_p, 150 | causal=causal, 151 | window_size=(-1, -1), 152 | softcap=0.0, 153 | alibi_slopes=None, 154 | deterministic=deterministic, 155 | return_attn_probs=False, 156 | ) 157 | out.backward(dout) 158 | end = torch.cuda.Event(enable_timing=True) 159 | end.record() 160 | torch.cuda.synchronize(device=device) 161 | time = begin.elapsed_time(end) / 1000.0 162 | 163 | if rank == 0 and log: 164 | color_print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") 165 | 166 | 167 | if __name__ == "__main__": 168 | dist.init_process_group("nccl") 169 | rank = dist.get_rank() 170 | 171 | forward_only = args.fwd_only 172 | 173 | torch.cuda.empty_cache() 174 | if rank == 0: 175 | color_print(vars(args)) 176 | color_print( 177 | f"# long context attention qkvpacked {args.ring_impl_type}. 
ulysses_degree : {args.ulysses_degree} " 178 | f"fwd_only {forward_only} " 179 | f"use_ulysses_lowdim {args.use_ulysses_lowdim} " 180 | ) 181 | benchmark(forward_only=forward_only, log=False) 182 | benchmark(forward_only=forward_only, log=True) 183 | -------------------------------------------------------------------------------- /yunchang/ring/ring_npu_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from .utils import RingComm, update_out_and_lse 4 | from yunchang.kernels.attention import ( 5 | npu_fused_attn_forward, 6 | npu_fused_attn_backward, 7 | ) 8 | from datetime import datetime 9 | 10 | 11 | def ring_npu_flash_attn_forward( 12 | process_group, 13 | q: torch.Tensor, 14 | k: torch.Tensor, 15 | v: torch.Tensor, 16 | head_num: int=None, 17 | input_layout: str="BSND" 18 | ): 19 | comm = RingComm(process_group) 20 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()}, ring_npu_flash_attn_forward") 21 | # single-device case: compute directly 22 | if comm.world_size == 1: 23 | return npu_fused_attn_forward(q, k, v, head_num, input_layout) 24 | 25 | attention_out,softmax_max, softmax_sum, scale_value = None,None,None,None 26 | 27 | next_k, next_v = None, None 28 | 29 | for step in range(comm.world_size): 30 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward step: {step}") 31 | # not the last step: launch the (asynchronous) communication for the next kv 32 | if step + 1 != comm.world_size: 33 | next_k = comm.send_recv(k) 34 | next_v = comm.send_recv(v) 35 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward commit: {step}") 36 | comm.commit() 37 | 38 | # compute for the current step (only process the local kv when step <= current rank) 39 | if step <= comm.rank: 40 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward calculation: {step}") 41 | attention_out, softmax_max, softmax_sum, scale_value = npu_fused_attn_forward(q, k, v, head_num, input_layout) 42 | 43 | # not the last step: wait for the communication to finish, then update kv 44 | if step + 1 != comm.world_size: 45 | comm.wait() 46 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward wait: {step}") 47 | k = next_k 48 | v = next_v 49 | return attention_out, softmax_max, softmax_sum, scale_value 50 | 51 | 52 | def ring_npu_flash_attn_backward( 53 | process_group,q, k, v, grad_attention_out, head_num=None, input_layout="BSND", softmax_max=None,softmax_sum=None,attention_in=None, scale_value=None): 54 | kv_comm = RingComm(process_group) 55 | d_kv_comm = RingComm(process_group) 56 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()}, ring_npu_flash_attn_backward") 57 | 58 | # initialize gradient tensors (avoid None by starting from zero tensors) 59 | dq = torch.zeros_like(q, dtype=torch.float32) 60 | dk = torch.zeros_like(k, dtype=torch.float32) 61 | dv = torch.zeros_like(v, dtype=torch.float32) 62 | next_k, next_v = None, None 63 | next_dk, next_dv = None, None 64 | 65 | for step in range(kv_comm.world_size): 66 | # 1. launch the kv communication (fetch the kv for the next step) 67 | if step + 1 != kv_comm.world_size: 68 | next_k = kv_comm.send_recv(k) 69 | next_v = kv_comm.send_recv(v) 70 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward commit: {step}") 71 | kv_comm.commit() 72 | 73 | # 2. compute the gradients for the current step
74 | if step <= kv_comm.rank: 75 | grad_query, grad_key, grad_value = npu_fused_attn_backward( 76 | q, k, v, grad_attention_out, head_num, input_layout, softmax_max=softmax_max, softmax_sum=softmax_sum, attention_in=attention_in, scale_value=scale_value) 77 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward calculation: {step}") 78 | # accumulate the query gradient (each rank only computes the gradient of its own q) 79 | dq += grad_query.to(torch.float32) 80 | 81 | # accumulate the kv gradients: after the first step, add the gradients received from the previous round 82 | if step > 0: 83 | d_kv_comm.wait() # wait for the previous round of dk/dv communication to finish 84 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm wait: {step}") 85 | dk += grad_key.to(torch.float32) + next_dk 86 | dv += grad_value.to(torch.float32) + next_dv 87 | else: 88 | # first step: assign directly 89 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward dkdv: {step}") 90 | dk = grad_key.to(torch.float32) 91 | dv = grad_value.to(torch.float32) 92 | else: 93 | # step > current rank: only receive the dk/dv from the previous round 94 | if step > 0: 95 | d_kv_comm.wait() 96 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm wait to next_dk: {step}") 97 | dk = next_dk 98 | dv = next_dv 99 | 100 | # 3. wait for the kv communication to finish, then update kv 101 | if step + 1 != kv_comm.world_size: 102 | kv_comm.wait() 103 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward kv_comm wait for update: {step}") 104 | k = next_k 105 | v = next_v 106 | 107 | next_dk = d_kv_comm.send_recv(dk) 108 | next_dv = d_kv_comm.send_recv(dv) 109 | d_kv_comm.commit() 110 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm commit: {step}") 111 | 112 | # wait for the last round of dk/dv communication to finish 113 | d_kv_comm.wait() 114 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm wait for last: {step}") 115 | 116 | # cast back to the input dtype and return 117 | return (dq.to(q.dtype), dk.to(q.dtype), dv.to(q.dtype)) 118 | 119 | class RingNpuFlashAttnFunc(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, group, q, k, v, head_num, input_layout="BSND"): 122 | # forward pass logic 123 | attention_out,softmax_max, softmax_sum, scale = ring_npu_flash_attn_forward(group,q=q, k=k, v=v, head_num=head_num, input_layout=input_layout) 124 | # save intermediate results for use in the backward pass 125 | ctx.save_for_backward(q, k, v, attention_out,softmax_max, softmax_sum) 126 | ctx.head_num = head_num 127 | ctx.input_layout = input_layout 128 | ctx.group = group 129 | ctx.scale=scale 130 | 131 | return attention_out 132 | 133 | @staticmethod 134 | def backward(ctx, grad_attention_out): 135 | # retrieve the saved intermediate results 136 | q, k, v, attention_out,softmax_max, softmax_sum = ctx.saved_tensors 137 | # backward pass logic 138 | # here we assume there is a function `npu_fusion_attention_backward` that implements the backward pass 139 | grad_query, grad_key, grad_value = ring_npu_flash_attn_backward(ctx.group,q, k, v, grad_attention_out, 140 | ctx.head_num, ctx.input_layout,softmax_max, softmax_sum, attention_out, ctx.scale) 141 | return None, grad_query, grad_key, grad_value,None,None 142 | 143 | def ring_npu_flash_attn_func( 144 | group, 145 | q: torch.Tensor, 146 | k: torch.Tensor, 147 | v: torch.Tensor, 148 | head_num: int=None, 149 | input_layout: str="BSND" 150 | ): 151 | head_num = q.shape[-2] 152 | return RingNpuFlashAttnFunc.apply( 153 | group, 154 | q, 155 | k, 156 | v, 157 | head_num, 158 | input_layout 159 | )
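The NPU ring attention above follows the same schedule as the CUDA ring kernels in this directory: key/value shards are rotated around the process group with send_recv/commit/wait, and each rank only processes blocks whose step index does not exceed its rank. The snippet below is a minimal usage sketch, not code from this repository; it assumes an Ascend environment with torch_npu installed and a torchrun launch using the HCCL backend (as in scripts/run_npu.sh), with tensors in the BSND layout used throughout this file.

import os
import torch
import torch_npu  # noqa: F401  (assumption: registers the "npu" device)
import torch.distributed as dist
from yunchang.ring.ring_npu_flash_attn import ring_npu_flash_attn_func

dist.init_process_group("hccl")  # assumption: HCCL backend on Ascend NPUs
local_rank = int(os.getenv("LOCAL_RANK", "0"))
device = torch.device(f"npu:{local_rank}")
world_size = dist.get_world_size()

# each rank holds a 1/world_size shard of the sequence in BSND layout
bs, nheads, head_dim = 2, 8, 128
local_seqlen = 4096 // world_size
q = torch.randn(bs, local_seqlen, nheads, head_dim, device=device,
                dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# ring attention over the default (global) process group
out = ring_npu_flash_attn_func(dist.group.WORLD, q, k, v, input_layout="BSND")
out.backward(torch.randn_like(out))

In the hybrid layer this path is reached with attn_type=AttnType.NPU, which passes the ring process group as the first positional argument (see the AttnType.NPU branch of LongContextAttention.forward later in this listing).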
-------------------------------------------------------------------------------- /yunchang/hybrid/async_attn_layer.py: -------------------------------------------------------------------------------- 1 | from yunchang.comm.all_to_all import SeqAllToAll4D, SeqAllToAll5D 2 | 3 | import torch 4 | 5 | from typing import Any 6 | from torch import Tensor 7 | 8 | import torch.distributed as dist 9 | from .utils import RING_IMPL_DICT, RING_IMPL_QKVPACKED_DICT 10 | from yunchang.globals import PROCESS_GROUP 11 | 12 | 13 | class AsyncLongContextAttention(torch.nn.Module): 14 | """Initialization. 15 | 16 | Arguments: 17 | ulysses_pg (ProcessGroup): ulysses process group 18 | ring_pg (ProcessGroup): ring process group 19 | scatter_idx (int): scatter_idx for all2all comm 20 | gather_idx (int): gather_idx for all2all comm 21 | """ 22 | 23 | def __init__( 24 | self, 25 | scatter_idx: int = 2, 26 | gather_idx: int = 1, 27 | ring_impl_type: str = "basic", 28 | ) -> None: 29 | 30 | super(AsyncLongContextAttention, self).__init__() 31 | self.ring_pg = PROCESS_GROUP.RING_PG 32 | self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG 33 | 34 | self.stream = torch.cuda.Stream() 35 | self._async_op = True 36 | 37 | assert ( 38 | self.ulysses_pg is not None or self.ring_pg is not None 39 | ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}" 40 | self.scatter_idx = scatter_idx 41 | self.gather_idx = gather_idx 42 | self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type] 43 | 44 | def forward( 45 | self, 46 | query: Tensor, 47 | key: Tensor, 48 | value: Tensor, 49 | dropout_p=0.0, 50 | softmax_scale=None, 51 | causal=False, 52 | window_size=(-1, -1), 53 | softcap=0.0, 54 | alibi_slopes=None, 55 | deterministic=False, 56 | return_attn_probs=False, 57 | *args: Any, 58 | ) -> Tensor: 59 | """forward 60 | 61 | Arguments: 62 | query (Tensor): query input to the layer (bs, seqlen/P, hc, hs) 63 | key (Tensor): key input to the layer (bs, seqlen/P, hc_kv, hs) 64 | value (Tensor): value input to the layer (bs, seqlen/P, hc_kv, hs) 65 | args: other args 66 | 67 | Returns: 68 | * output (Tensor): context output 69 | """ 70 | 71 | # un*ud = hc 72 | 73 | ulysses_degree = dist.get_world_size(self.ulysses_pg) 74 | 75 | bs, shard_seqlen, hc, hs = query.shape 76 | bs, shard_seqlen, hc_kv, hs = key.shape 77 | seq_len = shard_seqlen * ulysses_degree 78 | un = hc // ulysses_degree 79 | un_kv = hc_kv // ulysses_degree 80 | 81 | assert un_kv == un, f"un_kv {un_kv} un {un}" 82 | 83 | qkv = torch.cat([query, key, value]).contiguous() 84 | # (3*bs, seqlen/P, hc, hs) -> (hc, seqlen/P, 3*bs, hs) -> (un, ud, seqlen/P, 3*bs, hs), where hc = un*ud 85 | qkv_list = torch.unbind( 86 | qkv.transpose(0, 2) 87 | .contiguous() 88 | .reshape(un, ulysses_degree, shard_seqlen, 3 * bs, hs) 89 | ) 90 | # 3xall-to-all output buffer 91 | qkv_trans_list = [ 92 | torch.zeros( 93 | ulysses_degree, 94 | 1, 95 | shard_seqlen, 96 | 3 * bs, 97 | hs, 98 | dtype=query.dtype, 99 | device=query.device, 100 | ) 101 | for i in range(len(qkv_list)) 102 | ] 103 | # last all-to-all buffter 104 | context_layer_list = [ 105 | torch.zeros( 106 | ulysses_degree, 107 | 1, 108 | shard_seqlen, 109 | bs, 110 | hs, 111 | dtype=query.dtype, 112 | device=query.device, 113 | ) 114 | for i in range(len(qkv_list)) 115 | ] 116 | 117 | comm_handle_list = [] 118 | 119 | # un * (ud, shard_seqlen, 3*bs, hs) 120 | for i, qkv in enumerate(qkv_list): 121 | with torch.cuda.stream(self.stream): 122 | ret = dist.all_to_all_single( 123 | qkv_trans_list[i], 124 | qkv, 
125 | group=self.ulysses_pg, 126 | async_op=self._async_op, 127 | ) 128 | comm_handle_list.append(ret) 129 | 130 | last_comm_handle_list = [] 131 | for i, qkv_trans in enumerate(qkv_trans_list): 132 | if comm_handle_list[i] is not None: 133 | comm_handle_list[i].wait() 134 | qkv_trans = ( 135 | qkv_trans.reshape(seq_len, 3 * bs, 1, hs) 136 | .transpose(0, 1) 137 | .contiguous() 138 | .reshape(3 * bs, seq_len, 1, hs) 139 | ) 140 | 141 | # qkv_trans = all_to_all_4D_async(qkv, qkv_trans_list[i], self.scatter_idx, self.gather_idx, self.ulysses_pg) 142 | qkv_trans = torch.chunk(qkv_trans, 3, dim=0) 143 | 144 | out = self.ring_attn_fn( 145 | qkv_trans[0], 146 | qkv_trans[1], 147 | qkv_trans[2], 148 | dropout_p=dropout_p, 149 | softmax_scale=softmax_scale, 150 | causal=causal, 151 | window_size=window_size, 152 | softcap=softcap, 153 | alibi_slopes=alibi_slopes, 154 | deterministic=deterministic, 155 | return_attn_probs=return_attn_probs, 156 | group=self.ring_pg, 157 | ) 158 | 159 | if type(out) == tuple: 160 | context_layer, _, _ = out 161 | else: 162 | context_layer = out 163 | 164 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 165 | # scatter 1, gather 2 166 | 167 | context_layer = ( 168 | context_layer.reshape(bs, ulysses_degree, shard_seqlen, 1, hs) 169 | .transpose(0, 3) 170 | .transpose(0, 1) 171 | .contiguous() 172 | .reshape(ulysses_degree, 1, shard_seqlen, bs, hs) 173 | ) 174 | with torch.cuda.stream(self.stream): 175 | ret = dist.all_to_all_single( 176 | context_layer_list[i], 177 | context_layer, 178 | group=self.ulysses_pg, 179 | async_op=self._async_op, 180 | ) 181 | last_comm_handle_list.append(ret) 182 | 183 | # hc = un * P 184 | # un x (hc = P, seq_len/P, bs, hs) -> (bs, seq_len, hc = P, hs) 185 | for i, ret in enumerate(last_comm_handle_list): 186 | if ret is not None: 187 | ret.wait() 188 | context_layer_list[i] = ( 189 | context_layer_list[i] 190 | .reshape(ulysses_degree, shard_seqlen, bs, hs) 191 | .transpose(0, 2) 192 | .contiguous() 193 | .reshape(bs, shard_seqlen, ulysses_degree, hs) 194 | ) 195 | 196 | output = torch.cat(context_layer_list, dim=2) 197 | return output 198 | 199 | def backward(self, *args, **kwargs): 200 | raise RuntimeError( 201 | "Backward computation is not allowed for AsyncLongContextAttention." 
202 | ) 203 | -------------------------------------------------------------------------------- /benchmark/benchmark_longctx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from yunchang import ( 4 | AsyncLongContextAttention, 5 | LongContextAttention, 6 | set_seq_parallel_pg, 7 | UlyssesAttention, 8 | ) 9 | from yunchang.comm import EXTRACT_FUNC_DICT 10 | import torch.cuda 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description="args for benchmark.") 14 | 15 | parser.add_argument( 16 | "--ring_impl_type", 17 | type=str, 18 | default="basic", 19 | choices=["basic", "zigzag", "strip"], 20 | help="ring attn implementation type", 21 | ) 22 | parser.add_argument("--nheads", type=int, default=2, help="head number") 23 | parser.add_argument("--head_size", type=int, default=128, help="head size") 24 | parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length") 25 | parser.add_argument("--group_num", type=int, default=1, help="group number") 26 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 27 | parser.add_argument( 28 | "--fwd_only", action="store_true", help="benchmark forward pass only" 29 | ) 30 | parser.add_argument( 31 | "--use_ulysses_lowdim", 32 | action="store_true", 33 | default=True, 34 | help="ulysses process group on low dimension", 35 | ) 36 | parser.add_argument( 37 | "--use_qkvpack", 38 | action="store_true", 39 | default=False, 40 | help="pack qkv before all-to-all", 41 | ) 42 | parser.add_argument( 43 | "--ulysses_degree", 44 | type=int, 45 | default=1, 46 | help="ulysses attention sequence parallel degree", 47 | ) 48 | parser.add_argument( 49 | "--use_profiler", 50 | action="store_true", 51 | default=False, 52 | help="use torch profiler", 53 | ) 54 | parser.add_argument( 55 | "--use_ulysses", 56 | action="store_true", 57 | default=False, 58 | help="use ulysses", 59 | ) 60 | parser.add_argument( 61 | "--attn_type", 62 | type=str, 63 | default="fa", 64 | choices=["fa", "fa3", "torch"], 65 | help="attention type", 66 | ) 67 | # decault causal=True for LLM. no_causal is for DiT. 
68 | parser.add_argument( 69 | "--no_causal", 70 | action="store_true", 71 | default=False, 72 | help="use no causal attention", 73 | ) 74 | 75 | args = parser.parse_args() 76 | 77 | 78 | def color_print(text): 79 | print("\033[91m {}\033[00m".format(text)) 80 | 81 | 82 | def init_prof(use_profiler): 83 | activities = [] 84 | activities.append(torch.profiler.ProfilerActivity.CPU) 85 | activities.append(torch.profiler.ProfilerActivity.CUDA) 86 | 87 | from contextlib import nullcontext 88 | 89 | ctx = ( 90 | torch.profiler.profile( 91 | activities=activities, 92 | schedule=torch.profiler.schedule(wait=0, warmup=2, active=4, repeat=1), 93 | on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile/"), 94 | record_shapes=True, 95 | with_stack=True, 96 | ) 97 | if use_profiler 98 | else nullcontext() 99 | ) 100 | return ctx 101 | 102 | import os 103 | 104 | def get_local_rank(): 105 | local_rank = int(os.getenv('LOCAL_RANK', '0')) 106 | return local_rank 107 | 108 | def benchmark(num_iter=10, forward_only=True, log=True, profile=False): 109 | dtype = torch.float16 110 | rank = dist.get_rank() 111 | local_rank = get_local_rank() 112 | world_size = dist.get_world_size() 113 | device = torch.device(f"cuda:{local_rank}") 114 | torch.cuda.set_device(device) 115 | 116 | batch_size = args.batch_size 117 | seqlen = args.seq_len 118 | nheads = args.nheads 119 | group_num = args.group_num 120 | d = args.head_size 121 | 122 | dropout_p = 0.0 123 | causal = not args.no_causal 124 | deterministic = False 125 | 126 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 127 | assert d % 8 == 0 128 | assert nheads % group_num == 0, f"nheads {nheads} group_num {group_num}" 129 | assert ( 130 | nheads // group_num % args.ulysses_degree == 0 131 | ), f"nheads {nheads}, group_num {group_num}, ulysses_degree {args.ulysses_degree}" 132 | 133 | q = torch.randn( 134 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 135 | ) 136 | k = torch.randn( 137 | batch_size, 138 | seqlen, 139 | nheads // group_num, 140 | d, 141 | device=device, 142 | dtype=dtype, 143 | requires_grad=True, 144 | ) 145 | v = torch.randn( 146 | batch_size, 147 | seqlen, 148 | nheads // group_num, 149 | d, 150 | device=device, 151 | dtype=dtype, 152 | requires_grad=True, 153 | ) 154 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 155 | 156 | sp_ulysses_degree = min(args.ulysses_degree, world_size) 157 | sp_ring_degree = world_size // sp_ulysses_degree 158 | 159 | set_seq_parallel_pg( 160 | sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim 161 | ) 162 | 163 | from yunchang.kernels import AttnType 164 | attn_type = AttnType.from_string(args.attn_type) 165 | if args.use_ulysses: 166 | longctx_attn = UlyssesAttention(attn_type=attn_type) 167 | else: 168 | longctx_attn = LongContextAttention(ring_impl_type=args.ring_impl_type, attn_type=attn_type) 169 | 170 | out = longctx_attn( 171 | q, 172 | k, 173 | v, 174 | dropout_p=dropout_p, 175 | causal=causal, 176 | window_size=(-1, -1), 177 | softcap=0.0, 178 | alibi_slopes=None, 179 | deterministic=deterministic, 180 | return_attn_probs=False, 181 | ) 182 | if not args.fwd_only: 183 | out.backward(dout) 184 | 185 | out = longctx_attn( 186 | q, 187 | k, 188 | v, 189 | dropout_p=dropout_p, 190 | causal=causal, 191 | window_size=(-1, -1), 192 | softcap=0.0, 193 | alibi_slopes=None, 194 | deterministic=deterministic, 195 | return_attn_probs=False, 196 | ) 197 | if not args.fwd_only: 198 | 
out.backward(dout) 199 | 200 | begin = torch.cuda.Event(enable_timing=True) 201 | begin.record() 202 | 203 | ctx = init_prof(profile) 204 | 205 | with ctx as prof: 206 | if forward_only: 207 | with torch.no_grad(): 208 | for _ in range(num_iter): 209 | _ = longctx_attn( 210 | q, 211 | k, 212 | v, 213 | dropout_p=dropout_p, 214 | causal=causal, 215 | window_size=(-1, -1), 216 | softcap=0.0, 217 | alibi_slopes=None, 218 | deterministic=deterministic, 219 | return_attn_probs=False, 220 | ) 221 | 222 | torch.cuda.synchronize(device=device) 223 | 224 | if profile: 225 | prof.step() 226 | else: 227 | for _ in range(num_iter): 228 | q.grad = None 229 | k.grad = None 230 | v.grad = None 231 | out = longctx_attn( 232 | q, 233 | k, 234 | v, 235 | dropout_p=dropout_p, 236 | causal=causal, 237 | window_size=(-1, -1), 238 | softcap=0.0, 239 | alibi_slopes=None, 240 | deterministic=deterministic, 241 | return_attn_probs=False, 242 | ) 243 | out.backward(dout) 244 | 245 | if profile: 246 | prof.step() 247 | 248 | end = torch.cuda.Event(enable_timing=True) 249 | end.record() 250 | 251 | torch.cuda.synchronize(device=device) 252 | elapse = begin.elapsed_time(end) / 1000.0 253 | 254 | if rank == 0 and log: 255 | color_print(f"{num_iter / elapse:.3f} iter/s, {elapse:.3f} sec") 256 | 257 | 258 | if __name__ == "__main__": 259 | dist.init_process_group("nccl") 260 | rank = dist.get_rank() 261 | 262 | forward_only = args.fwd_only 263 | 264 | torch.cuda.empty_cache() 265 | if rank == 0: 266 | color_print( 267 | f"ring_impl_type: {args.ring_impl_type}. " 268 | f"nheads: {args.nheads} head_size: {args.head_size} seq_len: {args.seq_len} " 269 | f"ulysses_degree : {args.ulysses_degree} fwd_only {forward_only} use_ulysses_lowdim {args.use_ulysses_lowdim}. " 270 | f"use_qkvpack: {args.use_qkvpack} " 271 | f"use_ulysses: {args.use_ulysses} " 272 | f"causal: {not args.no_causal} " 273 | f"attn_type: {args.attn_type} " 274 | ) 275 | torch.cuda.empty_cache() 276 | benchmark(forward_only=forward_only, log=False) 277 | benchmark(forward_only=forward_only, log=True, profile=args.use_profiler) 278 | dist.destroy_process_group() -------------------------------------------------------------------------------- /yunchang/hybrid/attn_layer.py: -------------------------------------------------------------------------------- 1 | from yunchang.comm.all_to_all import SeqAllToAll4D, SeqAllToAll5D 2 | 3 | import torch 4 | 5 | from typing import Any 6 | from torch import Tensor 7 | 8 | import torch.distributed as dist 9 | from .utils import RING_IMPL_DICT, RING_IMPL_QKVPACKED_DICT 10 | from yunchang.globals import PROCESS_GROUP, HAS_SPARSE_SAGE_ATTENTION 11 | from yunchang.kernels import AttnType 12 | 13 | 14 | class LongContextAttention(torch.nn.Module): 15 | """Initialization. 
16 | 17 | Arguments: 18 | ulysses_pg (ProcessGroup): ulysses process group 19 | ring_pg (ProcessGroup): ring process group 20 | scatter_idx (int): scatter_idx for all2all comm 21 | gather_idx (int): gather_idx for all2all comm 22 | use_sync (bool): whether to synchronize after all-to-all 23 | """ 24 | 25 | def __init__( 26 | self, 27 | scatter_idx: int = 2, 28 | gather_idx: int = 1, 29 | ring_impl_type: str = "basic", 30 | use_pack_qkv: bool = False, 31 | use_sync: bool = False, 32 | attn_type: AttnType = AttnType.FA, 33 | attn_processor: torch.nn.Module = None, 34 | ) -> None: 35 | 36 | super(LongContextAttention, self).__init__() 37 | self.ring_pg = PROCESS_GROUP.RING_PG 38 | self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG 39 | 40 | self.use_pack_qkv = use_pack_qkv 41 | self.use_sync = use_sync 42 | self.attn_type = attn_type 43 | assert ( 44 | self.ulysses_pg is not None or self.ring_pg is not None 45 | ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}" 46 | self.scatter_idx = scatter_idx 47 | self.gather_idx = gather_idx 48 | self.attn_processor = attn_processor 49 | self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type] 50 | 51 | if HAS_SPARSE_SAGE_ATTENTION: 52 | from spas_sage_attn.autotune import SparseAttentionMeansim 53 | if isinstance(attn_processor, SparseAttentionMeansim) and dist.get_world_size(self.ring_pg) > 1: 54 | raise RuntimeError("Sparse Sage attention does not support ring degree > 1.") 55 | 56 | 57 | def forward( 58 | self, 59 | query: Tensor, 60 | key: Tensor, 61 | value: Tensor, 62 | dropout_p=0.0, 63 | softmax_scale=None, 64 | causal=False, 65 | window_size=(-1, -1), 66 | softcap=0.0, 67 | alibi_slopes=None, 68 | deterministic=False, 69 | return_attn_probs=False, 70 | *args: Any, 71 | ) -> Tensor: 72 | """forward 73 | 74 | Arguments: 75 | query (Tensor): query input to the layer 76 | key (Tensor): key input to the layer 77 | value (Tensor): value input to the layer 78 | args: other args 79 | 80 | Returns: 81 | * output (Tensor): context output 82 | """ 83 | 84 | # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) 85 | # scatter 2, gather 1 86 | if self.use_pack_qkv: 87 | # (3*bs, seq_len/N, head_cnt, head_size) 88 | qkv = torch.cat([query, key, value]).continous() 89 | # (3*bs, seq_len, head_cnt/N, head_size) 90 | qkv = SeqAllToAll4D.apply( 91 | self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, use_sync=self.use_sync 92 | ) 93 | qkv = torch.chunk(qkv, 3, dim=0) 94 | out = self.ring_attn_fn( 95 | qkv[0], 96 | qkv[1], 97 | qkv[2], 98 | dropout_p=dropout_p, 99 | softmax_scale=softmax_scale, 100 | causal=causal, 101 | window_size=window_size, 102 | softcap=softcap, 103 | alibi_slopes=alibi_slopes, 104 | deterministic=deterministic, 105 | return_attn_probs=return_attn_probs, 106 | group=self.ring_pg, 107 | attn_type=self.attn_type, 108 | attn_processor=self.attn_processor, 109 | ) 110 | else: 111 | query_layer = SeqAllToAll4D.apply( 112 | self.ulysses_pg, query, self.scatter_idx, self.gather_idx, self.use_sync 113 | ) 114 | key_layer = SeqAllToAll4D.apply( 115 | self.ulysses_pg, key, self.scatter_idx, self.gather_idx, self.use_sync 116 | ) 117 | value_layer = SeqAllToAll4D.apply( 118 | self.ulysses_pg, value, self.scatter_idx, self.gather_idx, self.use_sync 119 | ) 120 | if self.attn_type is AttnType.NPU: 121 | out = self.ring_attn_fn( 122 | self.ring_pg, 123 | query_layer, 124 | key_layer, 125 | value_layer 126 | ) 127 | else: 128 | out = self.ring_attn_fn( 129 | query_layer, 130 | 
key_layer, 131 | value_layer, 132 | dropout_p=dropout_p, 133 | softmax_scale=softmax_scale, 134 | causal=causal, 135 | window_size=window_size, 136 | softcap=softcap, 137 | alibi_slopes=alibi_slopes, 138 | deterministic=deterministic, 139 | return_attn_probs=return_attn_probs, 140 | group=self.ring_pg, 141 | attn_type=self.attn_type, 142 | attn_processor=self.attn_processor, 143 | ) 144 | 145 | if type(out) == tuple: 146 | context_layer, _, _ = out 147 | else: 148 | context_layer = out 149 | 150 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 151 | # scatter 1, gather 2 152 | output = SeqAllToAll4D.apply( 153 | self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync 154 | ) 155 | 156 | # out e.g., [s/p::h] 157 | return output 158 | 159 | 160 | class LongContextAttentionQKVPacked(torch.nn.Module): 161 | """Initialization. 162 | 163 | Arguments: 164 | ulysses_pg (ProcessGroup): ulysses process group 165 | ring_pg (ProcessGroup): ring process group 166 | scatter_idx (int): scatter_idx for all2all comm 167 | gather_idx (int): gather_idx for all2all comm 168 | use_sync (bool): whether to synchronize after all-to-all 169 | """ 170 | 171 | def __init__( 172 | self, 173 | scatter_idx: int = 3, 174 | gather_idx: int = 1, 175 | ring_impl_type: str = "basic", 176 | use_sync: bool = False, 177 | attn_type: AttnType = AttnType.FA, 178 | ) -> None: 179 | 180 | super(LongContextAttentionQKVPacked, self).__init__() 181 | 182 | self.ring_pg = PROCESS_GROUP.RING_PG 183 | self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG 184 | 185 | assert ( 186 | self.ulysses_pg is not None or self.ring_pg is not None 187 | ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}" 188 | self.scatter_idx = scatter_idx 189 | self.gather_idx = gather_idx 190 | self.use_sync = use_sync 191 | self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type] 192 | self.attn_type = attn_type 193 | 194 | def forward( 195 | self, 196 | qkv, 197 | dropout_p=0.0, 198 | softmax_scale=None, 199 | causal=False, 200 | window_size=(-1, -1), 201 | softcap=0.0, 202 | alibi_slopes=None, 203 | deterministic=False, 204 | return_attn_probs=False, 205 | *args: Any, 206 | ) -> Tensor: 207 | """forward 208 | 209 | Arguments: 210 | query (Tensor): query input to the layer 211 | key (Tensor): key input to the layer 212 | value (Tensor): value input to the layer 213 | args: other args 214 | 215 | Returns: 216 | * output (Tensor): context output 217 | """ 218 | 219 | # scatter 3, gather 1 220 | 221 | world_size = dist.get_world_size(self.ulysses_pg) 222 | 223 | if world_size > 1: 224 | qkv = SeqAllToAll5D.apply( 225 | self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, self.use_sync 226 | ) 227 | 228 | out = self.ring_attn_fn( 229 | qkv, 230 | dropout_p=dropout_p, 231 | softmax_scale=softmax_scale, 232 | causal=causal, 233 | window_size=window_size, 234 | softcap=softcap, 235 | alibi_slopes=alibi_slopes, 236 | deterministic=deterministic, 237 | return_attn_probs=return_attn_probs, 238 | group=self.ring_pg, 239 | attn_type=self.attn_type, 240 | ) 241 | 242 | # print(f"out {out.shape}") 243 | 244 | if type(out) == tuple: 245 | out = out[0] 246 | 247 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 248 | # scatter 1, gather 2 249 | 250 | if world_size > 1: 251 | out = SeqAllToAll4D.apply( 252 | self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1, self.use_sync 253 | ) 254 | # out e.g., [s/p::h] 255 | return out 256 | 
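LongContextAttention above is the hybrid (USP) path: the Ulysses all-to-all over ulysses_pg redistributes attention heads, while the ring implementation selected from RING_IMPL_DICT covers the sequence dimension over ring_pg. The sketch below shows the typical call pattern, modeled on test/test_hybrid_qkvpacked_attn.py and benchmark/benchmark_longctx.py above; it is illustrative only and assumes an NCCL torchrun launch with flash-attn installed.

import os
import torch
import torch.distributed as dist
from yunchang import LongContextAttention, set_seq_parallel_pg
from yunchang.kernels import AttnType

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
local_rank = int(os.getenv("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

# split world_size between Ulysses (head) and ring (sequence) parallelism
sp_ulysses_degree = 2  # must divide the head count
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

attn = LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.FA)

# each rank holds its local shard (bs, seqlen/world_size, nheads, head_dim); real inputs
# should be sharded with EXTRACT_FUNC_DICT[ring_impl_type] as the tests above do
bs, nheads, head_dim = 2, 8, 128
local_seqlen = 8192 // world_size
q = torch.randn(bs, local_seqlen, nheads, head_dim, device=device,
                dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = attn(q, k, v, dropout_p=0.0, causal=True, window_size=(-1, -1),
           softcap=0.0, alibi_slopes=None, deterministic=False,
           return_attn_probs=False)
out.backward(torch.randn_like(out))

Switching attn_type (e.g. AttnType.FA3 or AttnType.TORCH) only changes the per-block kernel picked by select_flash_attn_impl in yunchang/kernels; the all-to-all and ring communication pattern stays the same.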
-------------------------------------------------------------------------------- /yunchang/ring/ring_flashinfer_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from .utils import RingComm, update_out_and_lse 5 | from yunchang.kernels import select_flash_attn_impl, AttnType 6 | 7 | 8 | def ring_flashinfer_attn_forward( 9 | process_group, 10 | q: torch.Tensor, 11 | k: torch.Tensor, 12 | v: torch.Tensor, 13 | softmax_scale, 14 | dropout_p=0, 15 | causal=True, 16 | window_size=(-1, -1), 17 | softcap=0.0, 18 | alibi_slopes=None, 19 | deterministic=False, 20 | attn_type: AttnType = AttnType.FLASHINFER, 21 | attn_processor=None, 22 | ): 23 | comm = RingComm(process_group) 24 | 25 | out = None 26 | lse = None 27 | 28 | next_k, next_v = None, None 29 | 30 | for step in range(comm.world_size): 31 | if step + 1 != comm.world_size: 32 | next_k: torch.Tensor = comm.send_recv(k) 33 | next_v: torch.Tensor = comm.send_recv(v) 34 | comm.commit() 35 | 36 | if not causal or step <= comm.rank: 37 | fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor) 38 | block_out, block_lse = fn( 39 | q, 40 | k, 41 | v, 42 | dropout_p=dropout_p, 43 | softmax_scale=softmax_scale, 44 | causal=causal and step == 0, 45 | window_size=window_size, 46 | softcap=softcap, 47 | alibi_slopes=alibi_slopes, 48 | return_softmax=True and dropout_p > 0, 49 | ) 50 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 51 | 52 | if step + 1 != comm.world_size: 53 | comm.wait() 54 | k = next_k 55 | v = next_v 56 | 57 | out = out.to(q.dtype) 58 | lse = lse.squeeze(dim=-1).transpose(1, 2) 59 | return out, lse 60 | 61 | 62 | def ring_flashinfer_attn_backward( 63 | process_group, 64 | dout, 65 | q, 66 | k, 67 | v, 68 | out, 69 | softmax_lse, 70 | softmax_scale, 71 | dropout_p=0, 72 | causal=True, 73 | window_size=(-1, -1), 74 | softcap=0.0, 75 | alibi_slopes=None, 76 | deterministic=False, 77 | attn_type: AttnType = AttnType.FLASHINFER, 78 | ): 79 | kv_comm = RingComm(process_group) 80 | d_kv_comm = RingComm(process_group) 81 | dq, dk, dv = None, None, None 82 | next_dk, next_dv = None, None 83 | 84 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 85 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 86 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 87 | 88 | next_dk, next_dv = None, None 89 | next_k, next_v = None, None 90 | 91 | for step in range(kv_comm.world_size): 92 | if step + 1 != kv_comm.world_size: 93 | next_k = kv_comm.send_recv(k) 94 | next_v = kv_comm.send_recv(v) 95 | kv_comm.commit() 96 | if step <= kv_comm.rank or not causal: 97 | bwd_causal = causal and step == 0 98 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 99 | fn( 100 | dout, 101 | q, 102 | k, 103 | v, 104 | out, 105 | softmax_lse, 106 | block_dq_buffer, 107 | block_dk_buffer, 108 | block_dv_buffer, 109 | dropout_p, 110 | softmax_scale, 111 | bwd_causal, 112 | window_size, 113 | softcap, 114 | alibi_slopes, 115 | deterministic, 116 | rng_state=None, 117 | ) 118 | 119 | if dq is None: 120 | dq = block_dq_buffer.to(torch.float32) 121 | dk = block_dk_buffer.to(torch.float32) 122 | dv = block_dv_buffer.to(torch.float32) 123 | else: 124 | dq += block_dq_buffer 125 | d_kv_comm.wait() 126 | dk = block_dk_buffer + next_dk 127 | dv = block_dv_buffer + next_dv 128 | elif step != 0: 129 | d_kv_comm.wait() 130 | dk = next_dk 131 | dv = next_dv 132 | 133 | 
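        # advance the ring: adopt the k/v just received for the next step and circulate the
        # accumulated dk/dv gradients to the neighbouring rank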
if step + 1 != kv_comm.world_size: 134 | kv_comm.wait() 135 | k = next_k 136 | v = next_v 137 | 138 | next_dk = d_kv_comm.send_recv(dk) 139 | next_dv = d_kv_comm.send_recv(dv) 140 | d_kv_comm.commit() 141 | 142 | d_kv_comm.wait() 143 | 144 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) 145 | 146 | 147 | WORKSPACE_BUFFER = None 148 | PREFILL_WRAPPER = None 149 | 150 | 151 | class RingFlashInferAttnFunc(torch.autograd.Function): 152 | @staticmethod 153 | def forward( 154 | ctx, 155 | q, 156 | k, 157 | v, 158 | dropout_p, 159 | softmax_scale, 160 | causal, 161 | window_size, 162 | softcap, 163 | alibi_slopes, 164 | deterministic, 165 | return_softmax, 166 | group, 167 | attn_type, 168 | attn_processor, 169 | ): 170 | if softmax_scale is None: 171 | softmax_scale = q.shape[-1] ** (-0.5) 172 | 173 | assert alibi_slopes is None 174 | k = k.contiguous() 175 | v = v.contiguous() 176 | out, softmax_lse = ring_flashinfer_attn_forward( 177 | group, 178 | q, 179 | k, 180 | v, 181 | softmax_scale=softmax_scale, 182 | dropout_p=dropout_p, 183 | causal=causal, 184 | window_size=window_size, 185 | softcap=softcap, 186 | alibi_slopes=alibi_slopes, 187 | deterministic=False, 188 | attn_type=attn_type, 189 | attn_processor=attn_processor, 190 | ) 191 | # this should be out_padded 192 | ctx.save_for_backward(q, k, v, out, softmax_lse) 193 | ctx.dropout_p = dropout_p 194 | ctx.softmax_scale = softmax_scale 195 | ctx.causal = causal 196 | ctx.window_size = window_size 197 | ctx.softcap = softcap 198 | ctx.alibi_slopes = alibi_slopes 199 | ctx.deterministic = deterministic 200 | ctx.group = group 201 | ctx.attn_type = attn_type 202 | ctx.attn_processor = attn_processor 203 | return out if not return_softmax else (out, softmax_lse, None) 204 | 205 | @staticmethod 206 | def backward(ctx, dout, *args): 207 | q, k, v, out, softmax_lse = ctx.saved_tensors 208 | dq, dk, dv = ring_flashinfer_attn_backward( 209 | ctx.group, 210 | dout, 211 | q, 212 | k, 213 | v, 214 | out, 215 | softmax_lse, 216 | softmax_scale=ctx.softmax_scale, 217 | dropout_p=ctx.dropout_p, 218 | causal=ctx.causal, 219 | window_size=ctx.window_size, 220 | softcap=ctx.softcap, 221 | alibi_slopes=ctx.alibi_slopes, 222 | deterministic=ctx.deterministic, 223 | attn_type=ctx.attn_type, 224 | ) 225 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 226 | 227 | 228 | def ring_flashinfer_attn_qkvpacked_func( 229 | qkv, 230 | dropout_p=0.0, 231 | softmax_scale=None, 232 | causal=False, 233 | window_size=(-1, -1), 234 | softcap=0.0, 235 | alibi_slopes=None, 236 | deterministic=False, 237 | return_attn_probs=False, 238 | group=None, 239 | attn_type: AttnType = AttnType.FLASHINFER, 240 | ): 241 | return RingFlashInferAttnFunc.apply( 242 | qkv[:, :, 0], 243 | qkv[:, :, 1], 244 | qkv[:, :, 2], 245 | dropout_p, 246 | softmax_scale, 247 | causal, 248 | window_size, 249 | softcap, 250 | alibi_slopes, 251 | deterministic, 252 | return_attn_probs, 253 | group, 254 | attn_type, 255 | ) 256 | 257 | 258 | def ring_flashinfer_attn_kvpacked_func( 259 | q, 260 | kv, 261 | dropout_p=0.0, 262 | softmax_scale=None, 263 | causal=False, 264 | window_size=(-1, -1), 265 | softcap=0.0, 266 | alibi_slopes=None, 267 | deterministic=False, 268 | return_attn_probs=False, 269 | group=None, 270 | attn_type: AttnType = AttnType.FLASHINFER, 271 | ): 272 | return RingFlashInferAttnFunc.apply( 273 | q, 274 | kv[:, :, 0], 275 | kv[:, :, 1], 276 | dropout_p, 277 | softmax_scale, 278 | causal, 279 | window_size, 280 | softcap, 281 | 
alibi_slopes, 282 | deterministic, 283 | return_attn_probs, 284 | group, 285 | attn_type, 286 | ) 287 | 288 | 289 | def ring_flashinfer_attn_func( 290 | q, 291 | k, 292 | v, 293 | dropout_p=0.0, 294 | softmax_scale=None, 295 | causal=False, 296 | window_size=(-1, -1), 297 | softcap=0.0, 298 | alibi_slopes=None, 299 | deterministic=False, 300 | return_attn_probs=False, 301 | group=None, 302 | attn_type: AttnType = AttnType.FLASHINFER, 303 | attn_processor=None, 304 | ): 305 | return RingFlashInferAttnFunc.apply( 306 | q, 307 | k, 308 | v, 309 | dropout_p, 310 | softmax_scale, 311 | causal, 312 | window_size, 313 | softcap, 314 | alibi_slopes, 315 | deterministic, 316 | return_attn_probs, 317 | group, 318 | attn_type, 319 | attn_processor, 320 | ) 321 | -------------------------------------------------------------------------------- /yunchang/ring/ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | # from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 4 | from .utils import RingComm, update_out_and_lse 5 | from yunchang.kernels import select_flash_attn_impl, AttnType 6 | 7 | def ring_flash_attn_forward( 8 | process_group, 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | softmax_scale, 13 | dropout_p=0, 14 | causal=True, 15 | window_size=(-1, -1), 16 | softcap=0.0, 17 | alibi_slopes=None, 18 | deterministic=False, 19 | attn_type: AttnType = AttnType.FA, 20 | attn_processor=None, 21 | ): 22 | comm = RingComm(process_group) 23 | 24 | out = None 25 | lse = None 26 | 27 | next_k, next_v = None, None 28 | 29 | for step in range(comm.world_size): 30 | if step + 1 != comm.world_size: 31 | next_k: torch.Tensor = comm.send_recv(k) 32 | next_v: torch.Tensor = comm.send_recv(v) 33 | comm.commit() 34 | 35 | if not causal or step <= comm.rank: 36 | fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor) 37 | block_out, block_lse = fn( 38 | q, 39 | k, 40 | v, 41 | dropout_p=dropout_p, 42 | softmax_scale=softmax_scale, 43 | causal=causal and step == 0, 44 | window_size=window_size, 45 | softcap=softcap, 46 | alibi_slopes=alibi_slopes, 47 | return_softmax=True and dropout_p > 0, 48 | ) 49 | if attn_type == AttnType.SPARSE_SAGE: 50 | out, lse = block_out, block_lse 51 | else: 52 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 53 | 54 | if step + 1 != comm.world_size: 55 | comm.wait() 56 | k = next_k 57 | v = next_v 58 | 59 | out = out.to(q.dtype) 60 | if attn_type != AttnType.SPARSE_SAGE: 61 | lse = lse.squeeze(dim=-1).transpose(1, 2) 62 | return out, lse 63 | 64 | 65 | def ring_flash_attn_backward( 66 | process_group, 67 | dout, 68 | q, 69 | k, 70 | v, 71 | out, 72 | softmax_lse, 73 | softmax_scale, 74 | dropout_p=0, 75 | causal=True, 76 | window_size=(-1, -1), 77 | softcap=0.0, 78 | alibi_slopes=None, 79 | deterministic=False, 80 | attn_type: AttnType = AttnType.FA, 81 | ): 82 | kv_comm = RingComm(process_group) 83 | d_kv_comm = RingComm(process_group) 84 | dq, dk, dv = None, None, None 85 | next_dk, next_dv = None, None 86 | 87 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 88 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 89 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 90 | 91 | next_dk, next_dv = None, None 92 | next_k, next_v = None, None 93 | 94 | for step in range(kv_comm.world_size): 95 | if step + 1 != 
kv_comm.world_size: 96 | next_k = kv_comm.send_recv(k) 97 | next_v = kv_comm.send_recv(v) 98 | kv_comm.commit() 99 | if step <= kv_comm.rank or not causal: 100 | bwd_causal = causal and step == 0 101 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 102 | fn( 103 | dout, 104 | q, 105 | k, 106 | v, 107 | out, 108 | softmax_lse, 109 | block_dq_buffer, 110 | block_dk_buffer, 111 | block_dv_buffer, 112 | dropout_p, 113 | softmax_scale, 114 | bwd_causal, 115 | window_size, 116 | softcap, 117 | alibi_slopes, 118 | deterministic, 119 | rng_state=None, 120 | ) 121 | 122 | if dq is None: 123 | dq = block_dq_buffer.to(torch.float32) 124 | dk = block_dk_buffer.to(torch.float32) 125 | dv = block_dv_buffer.to(torch.float32) 126 | else: 127 | dq += block_dq_buffer 128 | d_kv_comm.wait() 129 | dk = block_dk_buffer + next_dk 130 | dv = block_dv_buffer + next_dv 131 | elif step != 0: 132 | d_kv_comm.wait() 133 | dk = next_dk 134 | dv = next_dv 135 | 136 | if step + 1 != kv_comm.world_size: 137 | kv_comm.wait() 138 | k = next_k 139 | v = next_v 140 | 141 | next_dk = d_kv_comm.send_recv(dk) 142 | next_dv = d_kv_comm.send_recv(dv) 143 | d_kv_comm.commit() 144 | 145 | d_kv_comm.wait() 146 | 147 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) 148 | 149 | 150 | class RingFlashAttnFunc(torch.autograd.Function): 151 | @staticmethod 152 | def forward( 153 | ctx, 154 | q, 155 | k, 156 | v, 157 | dropout_p, 158 | softmax_scale, 159 | causal, 160 | window_size, 161 | softcap, 162 | alibi_slopes, 163 | deterministic, 164 | return_softmax, 165 | group, 166 | attn_type, 167 | attn_processor, 168 | ): 169 | if softmax_scale is None: 170 | softmax_scale = q.shape[-1] ** (-0.5) 171 | 172 | assert alibi_slopes is None 173 | k = k.contiguous() 174 | v = v.contiguous() 175 | out, softmax_lse = ring_flash_attn_forward( 176 | group, 177 | q, 178 | k, 179 | v, 180 | softmax_scale=softmax_scale, 181 | dropout_p=dropout_p, 182 | causal=causal, 183 | window_size=window_size, 184 | softcap=softcap, 185 | alibi_slopes=alibi_slopes, 186 | deterministic=False, 187 | attn_type=attn_type, 188 | attn_processor=attn_processor, 189 | ) 190 | # this should be out_padded 191 | ctx.save_for_backward(q, k, v, out, softmax_lse) 192 | ctx.dropout_p = dropout_p 193 | ctx.softmax_scale = softmax_scale 194 | ctx.causal = causal 195 | ctx.window_size = window_size 196 | ctx.softcap = softcap 197 | ctx.alibi_slopes = alibi_slopes 198 | ctx.deterministic = deterministic 199 | ctx.group = group 200 | ctx.attn_type = attn_type 201 | ctx.attn_processor = attn_processor 202 | return out if not return_softmax else (out, softmax_lse, None) 203 | 204 | @staticmethod 205 | def backward(ctx, dout, *args): 206 | q, k, v, out, softmax_lse = ctx.saved_tensors 207 | dq, dk, dv = ring_flash_attn_backward( 208 | ctx.group, 209 | dout, 210 | q, 211 | k, 212 | v, 213 | out, 214 | softmax_lse, 215 | softmax_scale=ctx.softmax_scale, 216 | dropout_p=ctx.dropout_p, 217 | causal=ctx.causal, 218 | window_size=ctx.window_size, 219 | softcap=ctx.softcap, 220 | alibi_slopes=ctx.alibi_slopes, 221 | deterministic=ctx.deterministic, 222 | attn_type=ctx.attn_type, 223 | ) 224 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None 225 | 226 | 227 | def ring_flash_attn_qkvpacked_func( 228 | qkv, 229 | dropout_p=0.0, 230 | softmax_scale=None, 231 | causal=False, 232 | window_size=(-1, -1), 233 | softcap=0.0, 234 | alibi_slopes=None, 235 | deterministic=False, 236 | return_attn_probs=False, 237 | group=None, 238 | 
attn_type: AttnType = AttnType.FA, 239 | ): 240 | return RingFlashAttnFunc.apply( 241 | qkv[:, :, 0], 242 | qkv[:, :, 1], 243 | qkv[:, :, 2], 244 | dropout_p, 245 | softmax_scale, 246 | causal, 247 | window_size, 248 | softcap, 249 | alibi_slopes, 250 | deterministic, 251 | return_attn_probs, 252 | group, 253 | attn_type, 254 | ) 255 | 256 | 257 | def ring_flash_attn_kvpacked_func( 258 | q, 259 | kv, 260 | dropout_p=0.0, 261 | softmax_scale=None, 262 | causal=False, 263 | window_size=(-1, -1), 264 | softcap=0.0, 265 | alibi_slopes=None, 266 | deterministic=False, 267 | return_attn_probs=False, 268 | group=None, 269 | attn_type: AttnType = AttnType.FA, 270 | ): 271 | return RingFlashAttnFunc.apply( 272 | q, 273 | kv[:, :, 0], 274 | kv[:, :, 1], 275 | dropout_p, 276 | softmax_scale, 277 | causal, 278 | window_size, 279 | softcap, 280 | alibi_slopes, 281 | deterministic, 282 | return_attn_probs, 283 | group, 284 | attn_type, 285 | ) 286 | 287 | 288 | def ring_flash_attn_func( 289 | q, 290 | k, 291 | v, 292 | dropout_p=0.0, 293 | softmax_scale=None, 294 | causal=False, 295 | window_size=(-1, -1), 296 | softcap=0.0, 297 | alibi_slopes=None, 298 | deterministic=False, 299 | return_attn_probs=False, 300 | group=None, 301 | attn_type: AttnType = AttnType.FA, 302 | attn_processor=None, 303 | ): 304 | return RingFlashAttnFunc.apply( 305 | q, 306 | k, 307 | v, 308 | dropout_p, 309 | softmax_scale, 310 | causal, 311 | window_size, 312 | softcap, 313 | alibi_slopes, 314 | deterministic, 315 | return_attn_probs, 316 | group, 317 | attn_type, 318 | attn_processor, 319 | ) 320 | -------------------------------------------------------------------------------- /yunchang/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from .attention import ( 5 | flash_attn_forward_aiter, 6 | flash_attn_forward, 7 | flash_attn_backward, 8 | flash_attn3_func_forward, 9 | flash_attn3_func_backward, 10 | pytorch_attn_forward, 11 | pytorch_attn_backward, 12 | flashinfer_attn_forward, 13 | flashinfer_attn_backbward, 14 | npu_fused_attn_forward, 15 | npu_fused_attn_backward, 16 | HAS_FLASH_ATTN_HOPPER, 17 | ) 18 | from enum import Enum, auto 19 | 20 | from yunchang.globals import ( 21 | HAS_AITER, 22 | HAS_FLASH_ATTN, 23 | HAS_SAGE_ATTENTION, 24 | HAS_SPARSE_SAGE_ATTENTION, 25 | HAS_NPU, 26 | ) 27 | 28 | if HAS_FLASH_ATTN: 29 | from flash_attn import flash_attn_func 30 | 31 | if HAS_SAGE_ATTENTION: 32 | import sageattention 33 | 34 | if HAS_SPARSE_SAGE_ATTENTION: 35 | from spas_sage_attn.autotune import SparseAttentionMeansim 36 | 37 | 38 | class AttnType(Enum): 39 | AITER = "aiter" 40 | FA = "fa" 41 | FA3 = "fa3" 42 | FLASHINFER = "flashinfer" 43 | TORCH = "torch" 44 | SAGE_AUTO = "sage_auto" 45 | SAGE_FP16 = "sage_fp16" 46 | SAGE_FP16_TRITON = "sage_fp16_triton" 47 | SAGE_FP8 = "sage_fp8" 48 | SAGE_FP8_SM90 = "sage_fp8_sm90" 49 | SPARSE_SAGE = "sparse_sage" 50 | NPU = 'npu' 51 | 52 | @classmethod 53 | def from_string(cls, s: str): 54 | for member in cls: 55 | if member.value == s: 56 | return member 57 | raise ValueError(f"'{s}' is not a valid {cls.__name__}") 58 | 59 | 60 | def select_flash_attn_impl( 61 | impl_type: AttnType, stage: str = "fwd-bwd", attn_processor: torch.nn.Module = None 62 | ): 63 | if impl_type == AttnType.AITER: 64 | if stage == "fwd-only": 65 | return flash_attn_forward_aiter 66 | elif stage == "bwd-only": 67 | raise ValueError("Aiter does not support bwd-only stage.") 68 | elif stage == 
"fwd-bwd": 69 | raise ValueError("Aiter does not support fwd-bwd stage.") 70 | else: 71 | raise ValueError(f"Unknown stage: {stage}") 72 | 73 | elif impl_type == AttnType.FA: 74 | if stage == "fwd-only": 75 | return flash_attn_forward 76 | elif stage == "bwd-only": 77 | return flash_attn_backward 78 | elif stage == "fwd-bwd": 79 | assert HAS_FLASH_ATTN, "FlashAttention is not available" 80 | return flash_attn_func 81 | else: 82 | raise ValueError(f"Unknown stage: {stage}") 83 | 84 | elif impl_type == AttnType.FA3: 85 | if stage == "fwd-only": 86 | return flash_attn3_func_forward 87 | elif stage == "bwd-only": 88 | return flash_attn3_func_backward 89 | elif stage == "fwd-bwd": 90 | 91 | def fn( 92 | q, 93 | k, 94 | v, 95 | dropout_p=0.0, 96 | softmax_scale=None, 97 | causal=False, 98 | *args, 99 | **kwargs, 100 | ): 101 | assert ( 102 | HAS_FLASH_ATTN_HOPPER 103 | ), "FlashAttention3 is not available! install it from https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release" 104 | # (q, k, v, softmax_scale=None, causal=False, window_size=(-1, -1), 105 | # deterministic=False, descale_q=None, descale_k=None, descale_v=None, gqa_parallel=False) 106 | from .attention import flash3_attn_func 107 | 108 | assert softmax_scale is not None, f"softmax_scale is required for FA3" 109 | assert ( 110 | dropout_p == 0.0 111 | ), f"dropout_p: {dropout_p} is not supported for FA3" 112 | return flash3_attn_func( 113 | q, k, v, softmax_scale=softmax_scale, causal=causal 114 | ) 115 | 116 | return fn 117 | else: 118 | raise ValueError(f"Unknown stage: {stage}") 119 | 120 | elif impl_type == AttnType.FLASHINFER: 121 | if stage == "fwd-only": 122 | return flashinfer_attn_forward 123 | elif stage == "bwd-only": 124 | return flashinfer_attn_backbward 125 | elif stage == "fwd-bwd": 126 | raise ValueError("FlashInfer does not support fwd-bwd stage.") 127 | else: 128 | raise ValueError(f"Unknown stage: {stage}") 129 | 130 | elif impl_type == AttnType.TORCH: 131 | if stage == "fwd-only": 132 | return pytorch_attn_forward 133 | elif stage == "bwd-only": 134 | return pytorch_attn_backward 135 | elif stage == "fwd-bwd": 136 | from yunchang.ring.ring_pytorch_attn import pytorch_attn_func 137 | 138 | return pytorch_attn_func 139 | else: 140 | raise ValueError(f"Unknown stage: {stage}") 141 | 142 | elif impl_type == AttnType.SAGE_AUTO: 143 | if not HAS_SAGE_ATTENTION: 144 | raise ImportError("SageAttention is not available!") 145 | if stage == "fwd-only": 146 | return partial( 147 | sageattention.sageattn, 148 | tensor_layout="NHD", 149 | return_lse=True, 150 | ) 151 | else: 152 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 153 | 154 | elif impl_type == AttnType.SAGE_FP16: 155 | if not HAS_SAGE_ATTENTION: 156 | raise ImportError("SageAttention is not available!") 157 | 158 | if stage == "fwd-only": 159 | return partial( 160 | sageattention.sageattn_qk_int8_pv_fp16_cuda, 161 | pv_accum_dtype="fp32", 162 | tensor_layout="NHD", 163 | return_lse=True, 164 | ) 165 | else: 166 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 167 | 168 | elif impl_type == AttnType.SAGE_FP16_TRITON: 169 | if not HAS_SAGE_ATTENTION: 170 | raise ImportError("SageAttention is not available!") 171 | 172 | if stage == "fwd-only": 173 | return partial( 174 | sageattention.sageattn_qk_int8_pv_fp16_triton, 175 | tensor_layout="NHD", 176 | return_lse=True, 177 | ) 178 | else: 179 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 180 | 181 | elif impl_type == AttnType.SAGE_FP8: 182 | if not 
HAS_SAGE_ATTENTION: 183 | raise ImportError("SageAttention is not available!") 184 | if stage == "fwd-only": 185 | return partial( 186 | sageattention.sageattn_qk_int8_pv_fp8_cuda, 187 | pv_accum_dtype="fp32+fp32", 188 | tensor_layout="NHD", 189 | return_lse=True, 190 | ) 191 | else: 192 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 193 | 194 | elif impl_type == AttnType.SAGE_FP8_SM90: 195 | if not HAS_SAGE_ATTENTION: 196 | raise ImportError("SageAttention is not available!") 197 | if stage == "fwd-only": 198 | return partial( 199 | sageattention.sageattn_qk_int8_pv_fp8_cuda_sm90, 200 | pv_accum_dtype="fp32+fp32", 201 | tensor_layout="NHD", 202 | return_lse=True, 203 | ) 204 | else: 205 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 206 | 207 | elif impl_type == AttnType.SAGE_FP16_TRITON: 208 | if not HAS_SAGE_ATTENTION: 209 | raise ImportError("SageAttention is not available!") 210 | if stage == "fwd-only": 211 | return partial( 212 | sageattention.sageattn_qk_int8_pv_fp16_triton, 213 | pv_accum_dtype="fp32", 214 | tensor_layout="NHD", 215 | quantization_backend="cuda", 216 | return_lse=True, 217 | ) 218 | else: 219 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 220 | 221 | elif impl_type == AttnType.SPARSE_SAGE: 222 | if not HAS_SPARSE_SAGE_ATTENTION: 223 | raise ImportError("SparseSageAttention is not available!") 224 | if not isinstance(attn_processor, SparseAttentionMeansim): 225 | raise ImportError( 226 | "SparseSageAttention is only available with a SparseAttentionProcessor class passed in" 227 | ) 228 | if stage == "fwd-only": 229 | 230 | def fn(q, k, v, causal=False, softmax_scale=None, *args, **kwargs): 231 | return ( 232 | attn_processor( 233 | q, 234 | k, 235 | v, 236 | is_causal=causal, 237 | scale=softmax_scale, 238 | tensor_layout="NHD", 239 | ), 240 | None, 241 | ) 242 | 243 | return fn 244 | else: 245 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 246 | 247 | elif impl_type == AttnType.NPU: 248 | if stage == "fwd-only": 249 | return npu_fused_attn_forward 250 | elif stage == "bwd-only": 251 | return npu_fused_attn_backward 252 | elif stage == "fwd-bwd": 253 | return npu_fused_attn_forward 254 | else: 255 | raise ValueError(f"Unknown stage: {stage}") 256 | 257 | elif attn_processor is not None: 258 | return attn_processor 259 | else: 260 | raise ValueError(f"Unknown flash attention implementation: {impl_type}") 261 | 262 | 263 | __all__ = [ 264 | "flash_attn_forward", 265 | "flash_attn_backward", 266 | "flash_attn3_func_forward", 267 | "flash_attn3_func_forward", 268 | "flashinfer_attn_forward", 269 | "flashinfer_attn_backbward", 270 | "AttnType", 271 | ] 272 | -------------------------------------------------------------------------------- /yunchang/ring/ring_flash_attn_varlen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from yunchang.globals import HAS_FLASH_ATTN 5 | 6 | if HAS_FLASH_ATTN: 7 | from flash_attn.flash_attn_interface import ( 8 | _flash_attn_varlen_forward, 9 | _flash_attn_varlen_backward, 10 | ) 11 | from .utils import ( 12 | RingComm, 13 | update_out_and_lse, 14 | ) 15 | 16 | try: 17 | from .triton_utils import ( 18 | flatten_varlen_lse, 19 | unflatten_varlen_lse, 20 | ) 21 | except: 22 | from .utils import ( 23 | flatten_varlen_lse, 24 | unflatten_varlen_lse, 25 | ) 26 | 27 | 28 | def ring_flash_attn_varlen_forward( 29 | process_group, 30 | q: torch.Tensor, 31 | k: torch.Tensor, 32 | v: torch.Tensor, 33 | 
cu_seqlens, 34 | max_seqlen, 35 | softmax_scale, 36 | dropout_p=0, 37 | causal=True, 38 | window_size=(-1, -1), 39 | softcap=0.0, 40 | alibi_slopes=None, 41 | deterministic=False, 42 | ): 43 | comm = RingComm(process_group) 44 | 45 | out = None 46 | lse = None 47 | next_k, next_v = None, None 48 | 49 | for step in range(comm.world_size): 50 | if step + 1 != comm.world_size: 51 | next_k: torch.Tensor = comm.send_recv(k) 52 | next_v: torch.Tensor = comm.send_recv(v) 53 | comm.commit() 54 | if not causal or step <= comm.rank: 55 | assert HAS_FLASH_ATTN, "FlashAttention is not available" 56 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( 57 | q, 58 | k, 59 | v, 60 | cu_seqlens, 61 | cu_seqlens, 62 | max_seqlen, 63 | max_seqlen, 64 | dropout_p, 65 | softmax_scale, 66 | causal=causal and step == 0, 67 | window_size=window_size, 68 | softcap=softcap, 69 | alibi_slopes=alibi_slopes, 70 | return_softmax=True and dropout_p > 0, 71 | ) 72 | block_lse = flatten_varlen_lse( 73 | block_lse, 74 | cu_seqlens=cu_seqlens, 75 | ) 76 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 77 | 78 | if step + 1 != comm.world_size: 79 | comm.wait() 80 | k = next_k 81 | v = next_v 82 | 83 | out = out.to(q.dtype) 84 | lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) 85 | return out, lse 86 | 87 | 88 | def ring_flash_attn_varlen_backward( 89 | process_group, 90 | dout, 91 | q, 92 | k, 93 | v, 94 | out, 95 | softmax_lse, 96 | cu_seqlens, 97 | max_seqlen, 98 | softmax_scale, 99 | dropout_p=0, 100 | causal=True, 101 | window_size=(-1, -1), 102 | softcap=0.0, 103 | alibi_slopes=None, 104 | deterministic=False, 105 | ): 106 | kv_comm = RingComm(process_group) 107 | d_kv_comm = RingComm(process_group) 108 | dq, dk, dv = None, None, None 109 | next_dk, next_dv = None, None 110 | 111 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 112 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 113 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 114 | 115 | next_dk, next_dv = None, None 116 | next_k, next_v = None, None 117 | for step in range(kv_comm.world_size): 118 | if step + 1 != kv_comm.world_size: 119 | next_k = kv_comm.send_recv(k) 120 | next_v = kv_comm.send_recv(v) 121 | kv_comm.commit() 122 | if step <= kv_comm.rank or not causal: 123 | bwd_causal = causal and step == 0 124 | assert HAS_FLASH_ATTN, "FlashAttention is not available" 125 | _flash_attn_varlen_backward( 126 | dout, 127 | q, 128 | k, 129 | v, 130 | out, 131 | softmax_lse, 132 | block_dq_buffer, 133 | block_dk_buffer, 134 | block_dv_buffer, 135 | cu_seqlens, 136 | cu_seqlens, 137 | max_seqlen, 138 | max_seqlen, 139 | dropout_p, 140 | softmax_scale, 141 | bwd_causal, 142 | window_size, 143 | softcap, 144 | alibi_slopes, 145 | deterministic, 146 | rng_state=None, 147 | ) 148 | 149 | if dq is None: 150 | dq = block_dq_buffer.to(torch.float32) 151 | dk = block_dk_buffer.to(torch.float32) 152 | dv = block_dv_buffer.to(torch.float32) 153 | else: 154 | dq += block_dq_buffer 155 | d_kv_comm.wait() 156 | dk = block_dk_buffer + next_dk 157 | dv = block_dv_buffer + next_dv 158 | elif step != 0: 159 | d_kv_comm.wait() 160 | dk = next_dk 161 | dv = next_dv 162 | 163 | if step + 1 != kv_comm.world_size: 164 | kv_comm.wait() 165 | k = next_k 166 | v = next_v 167 | 168 | next_dk = d_kv_comm.send_recv(dk) 169 | next_dv = d_kv_comm.send_recv(dv) 170 | d_kv_comm.commit() 171 | 172 | d_kv_comm.wait() 173 | 174 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), 
next_dv.to(q.dtype) 175 | 176 | 177 | class RingFlashAttnVarlenFunc(torch.autograd.Function): 178 | @staticmethod 179 | def forward( 180 | ctx, 181 | q, 182 | k, 183 | v, 184 | cu_seqlens, 185 | max_seqlen, 186 | dropout_p, 187 | softmax_scale, 188 | causal, 189 | window_size, 190 | softcap, 191 | alibi_slopes, 192 | deterministic, 193 | return_softmax, 194 | group, 195 | ): 196 | if softmax_scale is None: 197 | softmax_scale = q.shape[-1] ** (-0.5) 198 | 199 | assert alibi_slopes is None 200 | k = k.contiguous() 201 | v = v.contiguous() 202 | out, softmax_lse = ring_flash_attn_varlen_forward( 203 | group, 204 | q, 205 | k, 206 | v, 207 | cu_seqlens, 208 | max_seqlen, 209 | softmax_scale=softmax_scale, 210 | dropout_p=dropout_p, 211 | causal=causal, 212 | window_size=window_size, 213 | softcap=softcap, 214 | alibi_slopes=alibi_slopes, 215 | deterministic=False, 216 | ) 217 | # this should be out_padded 218 | ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) 219 | ctx.max_seqlen = max_seqlen 220 | ctx.dropout_p = dropout_p 221 | ctx.softmax_scale = softmax_scale 222 | ctx.causal = causal 223 | ctx.window_size = window_size 224 | ctx.softcap = softcap 225 | ctx.alibi_slopes = alibi_slopes 226 | ctx.deterministic = deterministic 227 | ctx.group = group 228 | return out if not return_softmax else (out, softmax_lse, None) 229 | 230 | @staticmethod 231 | def backward(ctx, dout, *args): 232 | q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors 233 | dq, dk, dv = ring_flash_attn_varlen_backward( 234 | ctx.group, 235 | dout, 236 | q, 237 | k, 238 | v, 239 | out, 240 | softmax_lse, 241 | cu_seqlens, 242 | ctx.max_seqlen, 243 | softmax_scale=ctx.softmax_scale, 244 | dropout_p=ctx.dropout_p, 245 | causal=ctx.causal, 246 | window_size=ctx.window_size, 247 | softcap=ctx.softcap, 248 | alibi_slopes=ctx.alibi_slopes, 249 | deterministic=ctx.deterministic, 250 | ) 251 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None 252 | 253 | 254 | def ring_flash_attn_varlen_qkvpacked_func( 255 | qkv, 256 | cu_seqlens, 257 | max_seqlen, 258 | dropout_p=0.0, 259 | softmax_scale=None, 260 | causal=False, 261 | window_size=(-1, -1), # -1 means infinite context window 262 | softcap=0.0, 263 | alibi_slopes=None, 264 | deterministic=False, 265 | return_attn_probs=False, 266 | group=None, 267 | ): 268 | return RingFlashAttnVarlenFunc.apply( 269 | qkv[:, 0], 270 | qkv[:, 1], 271 | qkv[:, 2], 272 | cu_seqlens, 273 | max_seqlen, 274 | dropout_p, 275 | softmax_scale, 276 | causal, 277 | window_size, 278 | softcap, 279 | alibi_slopes, 280 | deterministic, 281 | return_attn_probs, 282 | group, 283 | ) 284 | 285 | 286 | def ring_flash_attn_varlen_kvpacked_func( 287 | q, 288 | kv, 289 | cu_seqlens, 290 | max_seqlen, 291 | dropout_p=0.0, 292 | softmax_scale=None, 293 | causal=False, 294 | window_size=(-1, -1), # -1 means infinite context window 295 | softcap=0.0, 296 | alibi_slopes=None, 297 | deterministic=False, 298 | return_attn_probs=False, 299 | group=None, 300 | ): 301 | return RingFlashAttnVarlenFunc.apply( 302 | q, 303 | kv[:, 0], 304 | kv[:, 1], 305 | cu_seqlens, 306 | max_seqlen, 307 | dropout_p, 308 | softmax_scale, 309 | causal, 310 | window_size, 311 | softcap, 312 | alibi_slopes, 313 | deterministic, 314 | return_attn_probs, 315 | group, 316 | ) 317 | 318 | 319 | def ring_flash_attn_varlen_func( 320 | q, 321 | k, 322 | v, 323 | cu_seqlens, 324 | max_seqlen, 325 | dropout_p=0.0, 326 | softmax_scale=None, 327 | causal=False, 328 | window_size=(-1, -1), # -1 means infinite 
context window 329 | softcap=0.0, 330 | alibi_slopes=None, 331 | deterministic=False, 332 | return_attn_probs=False, 333 | group=None, 334 | ): 335 | return RingFlashAttnVarlenFunc.apply( 336 | q, 337 | k, 338 | v, 339 | cu_seqlens, 340 | max_seqlen, 341 | dropout_p, 342 | softmax_scale, 343 | causal, 344 | window_size, 345 | softcap, 346 | alibi_slopes, 347 | deterministic, 348 | return_attn_probs, 349 | group, 350 | ) 351 | -------------------------------------------------------------------------------- /test/test_hybrid_attn_npu.py: -------------------------------------------------------------------------------- 1 | import os 2 | from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT 3 | import torch 4 | import torch_npu 5 | from torch_npu.contrib import transfer_to_npu 6 | import torch.distributed as dist 7 | 8 | 9 | from yunchang.kernels import AttnType 10 | from test_utils import attention_ref 11 | import argparse 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description="Test hybrid attention with configurable sequence length" 17 | ) 18 | parser.add_argument( 19 | "--seqlen", type=int, default=1024, help="sequence length (default: 1024)" 20 | ) 21 | parser.add_argument( 22 | "--use_bwd", 23 | action="store_true", 24 | help="whether to test backward pass (default: False)", 25 | ) 26 | parser.add_argument( 27 | "--sp_ulysses_degree", 28 | type=int, 29 | default=None, 30 | help="sp_ulysses_degree (default: world_size)", 31 | ) 32 | parser.add_argument( 33 | "--ring_impl_type", 34 | type=str, 35 | default="basic_npu", 36 | choices=["basic_npu"], 37 | help="ring implementation type (default: basic_npu)", 38 | ) 39 | parser.add_argument( 40 | "--causal", 41 | action="store_true", 42 | help="whether to use causal attention (default: False)", 43 | ) 44 | parser.add_argument( 45 | "--attn_impl", 46 | type=str, 47 | default="npu", 48 | choices=[ 49 | "npu", 50 | ], 51 | help="attention implementation type (default: torch)", 52 | ) 53 | parser.add_argument( 54 | "--sparse_sage_l1", 55 | type=float, 56 | default=0.07, 57 | help="l1 for sparse sage attention (default: 0.07)", 58 | ) 59 | parser.add_argument( 60 | "--sparse_sage_pv_l1", 61 | type=float, 62 | default=0.08, 63 | help="pv_l1 for sparse sage attention (default: 0.08)", 64 | ) 65 | parser.add_argument( 66 | "--sparse_sage_tune_mode", 67 | action="store_true", 68 | default=False, 69 | help="enable tune mode for sparse sage attention (default: False)", 70 | ) 71 | parser.add_argument( 72 | "--sparse_sage_tune_path", 73 | type=str, 74 | default="./sparsesage_autotune.pt", 75 | help="path to the sparse sage autotune results (default: ./sparsesage_autotune.pt)", 76 | ) 77 | return parser.parse_args() 78 | 79 | 80 | def log(msg, a, rank0_only=False): 81 | world_size = dist.get_world_size() 82 | rank = dist.get_rank() 83 | if rank0_only: 84 | if rank == 0: 85 | print( 86 | f"[Rank#0] {msg}: " 87 | f"max {a.abs().max().item()}, " 88 | f"mean {a.abs().mean().item()}", 89 | flush=True, 90 | ) 91 | return 92 | 93 | for i in range(world_size): 94 | if i == rank: 95 | if rank == 0: 96 | print(f"{msg}:") 97 | print( 98 | f"[Rank#{rank}] " 99 | f"max {a.abs().max().item()}, " 100 | f"mean {a.abs().mean().item()}", 101 | flush=True, 102 | ) 103 | dist.barrier() 104 | 105 | 106 | # test it with: 107 | # torchrun --nproc_per_node=4 test/test_hybrid_attn.py 108 | if __name__ == "__main__": 109 | args = parse_args() 110 | 111 | torch.random.manual_seed(0) 112 | 113 | 
dist.init_process_group("hccl") 114 | 115 | rank = dist.get_rank() 116 | world_size = dist.get_world_size() 117 | 118 | # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon 119 | dtype = torch.bfloat16 120 | device = torch.device(f"npu:{rank}") 121 | 122 | batch_size = 1 123 | seqlen = args.seqlen 124 | nheads = 32 125 | d = 2048 // 32 126 | dropout_p = 0 127 | causal = args.causal 128 | deterministic = False 129 | 130 | use_bwd = args.use_bwd 131 | 132 | assert seqlen % world_size == 0 133 | assert d % 8 == 0 134 | 135 | ring_impl_type = args.ring_impl_type 136 | 137 | # Prepare inputs 138 | q = torch.randn( 139 | batch_size, 140 | seqlen, 141 | nheads, 142 | d, 143 | device=device, 144 | dtype=dtype, 145 | requires_grad=True if use_bwd else False, 146 | ) 147 | k = torch.randn( 148 | batch_size, 149 | seqlen, 150 | nheads, 151 | d, 152 | device=device, 153 | dtype=dtype, 154 | requires_grad=True if use_bwd else False, 155 | ) 156 | v = torch.randn( 157 | batch_size, 158 | seqlen, 159 | nheads, 160 | d, 161 | device=device, 162 | dtype=dtype, 163 | requires_grad=True if use_bwd else False, 164 | ) 165 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 166 | 167 | dist.broadcast(q, src=0) 168 | dist.broadcast(k, src=0) 169 | dist.broadcast(v, src=0) 170 | dist.broadcast(dout, src=0) 171 | 172 | # prepare process group for hybrid sequence parallelism 173 | use_ring_low_dim = True 174 | 175 | sp_ulysses_degree = ( 176 | args.sp_ulysses_degree if args.sp_ulysses_degree is not None else world_size 177 | ) 178 | sp_ring_degree = world_size // sp_ulysses_degree 179 | 180 | print( 181 | f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}" 182 | ) 183 | 184 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 185 | 186 | # Use EXTRACT_FUNC_DICT to shard the tensors 187 | local_q = ( 188 | EXTRACT_FUNC_DICT[ring_impl_type]( 189 | q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 190 | ) 191 | .detach() 192 | .clone() 193 | ) 194 | 195 | local_k = ( 196 | EXTRACT_FUNC_DICT[ring_impl_type]( 197 | k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 198 | ) 199 | .detach() 200 | .clone() 201 | ) 202 | 203 | local_v = ( 204 | EXTRACT_FUNC_DICT[ring_impl_type]( 205 | v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 206 | ) 207 | .detach() 208 | .clone() 209 | ) 210 | 211 | if use_bwd: 212 | local_q.requires_grad = True 213 | local_k.requires_grad = True 214 | local_v.requires_grad = True 215 | 216 | # Map argument to AttnType enum 217 | attn_impl_map = { 218 | "npu": AttnType.NPU, 219 | } 220 | 221 | usp_attn = LongContextAttention( 222 | ring_impl_type=ring_impl_type, 223 | attn_type=attn_impl_map[args.attn_impl], 224 | ) 225 | 226 | if rank == 0: 227 | print("#" * 30) 228 | print("# ds-ulysses forward:") 229 | print("#" * 30) 230 | 231 | # common test parameters 232 | window_size = (-1, -1) 233 | alibi_slopes, attn_bias = None, None 234 | dropout_mask = None 235 | 236 | print(f"before usp attn forward: {local_q.shape} {local_k.shape} {local_v.shape}") 237 | 238 | # usp attn forward 239 | local_out = usp_attn( 240 | local_q, 241 | local_k, 242 | local_v 243 | ) 244 | 245 | # extract local dout 246 | local_dout = ( 247 | EXTRACT_FUNC_DICT[ring_impl_type]( 248 | dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 249 | ) 250 | .detach() 251 | .clone() 252 | ) 253 | 254 | 
max_memory = torch.cuda.max_memory_allocated(device) / ( 255 | 1024 * 1024 256 | ) # Convert to MB 257 | print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB") 258 | torch.cuda.reset_peak_memory_stats(device) # Reset stats 259 | 260 | if rank == 0: 261 | print("#" * 30) 262 | print("# ds-ulysses backward:") 263 | print("#" * 30) 264 | 265 | # usp attn backward 266 | if use_bwd: 267 | local_out.backward(local_dout) 268 | 269 | dist.barrier() 270 | 271 | if rank == 0: 272 | print("#" * 30) 273 | print("# local forward:") 274 | print("#" * 30) 275 | # reference, a local flash attn 276 | softmax_scale = q.shape[-1] ** -0.5 277 | out_ref = torch_npu.npu_fusion_attention_v2(q, k, v, 278 | head_num = q.shape[-2], 279 | input_layout = "BSND", 280 | scale = softmax_scale, 281 | pre_tokens=65535, 282 | next_tokens=65535)[0] 283 | if rank == 0: 284 | print("#" * 30) 285 | print("# local forward:") 286 | print("#" * 30) 287 | 288 | if use_bwd: 289 | out_ref.backward(dout) 290 | 291 | dist.barrier() 292 | 293 | # check correctness 294 | # When checking correctness, use EXTRACT_FUNC_DICT for reference outputs 295 | local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 296 | out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 297 | ) 298 | 299 | log("local (rank) out", local_out, rank0_only=True) 300 | log("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out) 301 | 302 | # log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref) 303 | 304 | torch.testing.assert_close(local_out, local_out_ref, atol=1e-1, rtol=0) 305 | # torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0) 306 | 307 | if use_bwd: 308 | local_dq_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 309 | q.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 310 | ) 311 | log("load_dq", local_q.grad) 312 | log("dq diff", local_dq_ref - local_q.grad) 313 | 314 | local_dk_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 315 | k.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 316 | ) 317 | log("load_dk", local_k.grad) 318 | log("dk diff", local_dk_ref - local_k.grad) 319 | 320 | local_dv_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 321 | v.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 322 | ) 323 | log("load_dv", local_v.grad) 324 | log("dv diff", local_dv_ref - local_v.grad) 325 | 326 | if dist.is_initialized(): 327 | dist.destroy_process_group() 328 | -------------------------------------------------------------------------------- /yunchang/comm/all_to_all.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation and Jiarui Fang 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | 6 | import torch 7 | 8 | from typing import Any, Tuple 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | import torch.distributed as dist 13 | 14 | 15 | def all_to_all_4D( 16 | input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False 17 | ) -> torch.tensor: 18 | """ 19 | all-to-all for QKV 20 | 21 | Args: 22 | input (torch.tensor): a tensor sharded along dim scatter dim 23 | scatter_idx (int): default 1 24 | gather_idx (int): default 2 25 | group : torch process group 26 | use_sync (bool): whether to synchronize after all-to-all 27 | 28 | Returns: 29 | torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) 30 | """ 31 | assert ( 32 | input.dim() == 4 33 | ), f"input must 
be 4D tensor, got {input.dim()} and shape {input.shape}" 34 | 35 | seq_world_size = dist.get_world_size(group) 36 | 37 | if scatter_idx == 2 and gather_idx == 1: 38 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) 39 | bs, shard_seqlen, hc, hs = input.shape 40 | seqlen = shard_seqlen * seq_world_size 41 | shard_hc = hc // seq_world_size 42 | 43 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 44 | # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) 45 | input_t = ( 46 | input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) 47 | .transpose(0, 2) 48 | .contiguous() 49 | ) 50 | 51 | output = torch.empty_like(input_t) 52 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 53 | # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head 54 | 55 | if seq_world_size > 1: 56 | dist.all_to_all_single(output, input_t, group=group) 57 | if use_sync: 58 | torch.cuda.synchronize() 59 | else: 60 | output = input_t 61 | # if scattering the seq-dim, transpose the heads back to the original dimension 62 | output = output.reshape(seqlen, bs, shard_hc, hs) 63 | 64 | # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) 65 | output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) 66 | 67 | return output 68 | 69 | elif scatter_idx == 1 and gather_idx == 2: 70 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) 71 | bs, seqlen, shard_hc, hs = input.shape 72 | hc = shard_hc * seq_world_size 73 | shard_seqlen = seqlen // seq_world_size 74 | seq_world_size = dist.get_world_size(group) 75 | 76 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
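        # --- Editor's sketch (not part of the original file) ----------------------
        # Concrete shape walk-through for this branch, using illustrative numbers
        # bs=1, seqlen=8, hc=8, hs=64, P=seq_world_size=4 (so hc/P=2, seqlen/P=2):
        #   input                        (1, 8, 2, 64)    = (bs, seqlen, hc/P, hs)
        #   .reshape                     (1, 4, 2, 2, 64)  = (bs, P, seqlen/P, hc/P, hs)
        #   .transpose(0,3).transpose(0,1)
        #                                (4, 2, 2, 1, 64)  = (P, hc/P, seqlen/P, bs, hs)
        #   all_to_all_single over dim 0 (4, 2, 2, 1, 64), P now indexes gathered heads
        #   .reshape(hc, seqlen/P, bs, hs)                 = (8, 2, 1, 64)
        #   .transpose(0,2).reshape      (1, 2, 8, 64)    = (bs, seqlen/P, hc, hs)
        # ---------------------------------------------------------------------------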
77 | # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) 78 | input_t = ( 79 | input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) 80 | .transpose(0, 3) 81 | .transpose(0, 1) 82 | .contiguous() 83 | .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) 84 | ) 85 | 86 | output = torch.empty_like(input_t) 87 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 88 | # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head 89 | if seq_world_size > 1: 90 | dist.all_to_all_single(output, input_t, group=group) 91 | if use_sync: 92 | torch.cuda.synchronize() 93 | else: 94 | output = input_t 95 | 96 | # if scattering the seq-dim, transpose the heads back to the original dimension 97 | output = output.reshape(hc, shard_seqlen, bs, hs) 98 | 99 | # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) 100 | output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) 101 | 102 | return output 103 | else: 104 | raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") 105 | 106 | 107 | class SeqAllToAll4D(torch.autograd.Function): 108 | @staticmethod 109 | def forward( 110 | ctx: Any, 111 | group: dist.ProcessGroup, 112 | input: Tensor, 113 | scatter_idx: int, 114 | gather_idx: int, 115 | use_sync: bool = False, 116 | ) -> Tensor: 117 | 118 | ctx.group = group 119 | ctx.scatter_idx = scatter_idx 120 | ctx.gather_idx = gather_idx 121 | ctx.use_sync = use_sync 122 | return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) 123 | 124 | @staticmethod 125 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: 126 | return ( 127 | None, 128 | SeqAllToAll4D.apply( 129 | ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync 130 | ), 131 | None, 132 | None, 133 | None, 134 | ) 135 | 136 | 137 | def all_to_all_5D( 138 | input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False 139 | ) -> torch.tensor: 140 | """ 141 | all-to-all for QKV 142 | forward (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs) 143 | 144 | Args: 145 | input (torch.tensor): a tensor sharded along dim scatter dim 146 | scatter_idx (int): default 1 147 | gather_idx (int): default 2 148 | group : torch process group 149 | use_sync: whether to synchronize after all-to-all 150 | 151 | Returns: 152 | torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs) 153 | """ 154 | assert ( 155 | input.dim() == 5 156 | ), f"input must be 5D tensor, got {input.dim()} and shape {input.shape}" 157 | 158 | seq_world_size = dist.get_world_size(group) 159 | 160 | if scatter_idx == 3 and gather_idx == 1: 161 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, 3, hc, hs) output: (bs, seqlen, 3, hc/P, hs) 162 | bs, shard_seqlen, t_cnt, hc, hs = input.shape 163 | 164 | assert t_cnt == 3 165 | seqlen = shard_seqlen * seq_world_size 166 | shard_hc = hc // seq_world_size 167 | 168 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
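        # --- Editor's sketch (not part of the original file) ----------------------
        # Same walk-through for the packed-QKV case, using illustrative numbers
        # bs=1, seqlen=8, hc=8, hs=64, P=seq_world_size=4 (so hc/P=2, seqlen/P=2):
        #   input                        (1, 2, 3, 8, 64)    = (bs, seqlen/P, 3, hc, hs)
        #   .reshape                     (1, 2, 3, 4, 2, 64)  = (bs, seqlen/P, 3, P, hc/P, hs)
        #   .transpose(0,3)              (4, 2, 3, 1, 2, 64)  = (P, seqlen/P, 3, bs, hc/P, hs)
        #   all_to_all_single over dim 0 (4, 2, 3, 1, 2, 64), P now indexes sequence chunks
        #   .reshape(seqlen, 3, bs, hc/P, hs)                 = (8, 3, 1, 2, 64)
        #   .transpose(0,2).transpose(1,2).reshape            = (1, 8, 3, 2, 64)
        #                                                     = (bs, seqlen, 3, hc/P, hs)
        # ---------------------------------------------------------------------------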
169 | # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> (P, seq_len/P, 3, bs, hc/P, hs) 170 | input_t = ( 171 | input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs) 172 | .transpose(0, 3) 173 | .contiguous() 174 | ) 175 | 176 | output = torch.empty_like(input_t) 177 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 178 | # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head 179 | if seq_world_size > 1: 180 | dist.all_to_all_single(output, input_t, group=group) 181 | if use_sync: 182 | torch.cuda.synchronize() 183 | else: 184 | output = input_t 185 | 186 | # if scattering the seq-dim, transpose the heads back to the original dimension 187 | output = output.reshape(seqlen, 3, bs, shard_hc, hs) 188 | 189 | # (seq_len, 3, bs, hc/P, hs) -trans-> (bs, seq_len, 3, hc/P, hs) 190 | output = output.transpose(0, 2).transpose(1, 2).contiguous() 191 | 192 | return output.reshape(bs, seqlen, 3, shard_hc, hs).contiguous() 193 | elif scatter_idx == 1 and gather_idx == 3: 194 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) 195 | bs, seqlen, _, shard_hc, hs = input.shape 196 | hc = shard_hc * seq_world_size 197 | shard_seqlen = seqlen // seq_world_size 198 | seq_world_size = dist.get_world_size(group) 199 | 200 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 201 | # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs) 202 | input_t = ( 203 | input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs) 204 | .transpose(0, 4) 205 | .transpose(0, 1) 206 | .contiguous() 207 | .reshape(seq_world_size, shard_hc, shard_seqlen, 3, bs, hs) 208 | ) 209 | 210 | output = torch.empty_like(input_t) 211 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 212 | # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head 213 | if seq_world_size > 1: 214 | dist.all_to_all_single(output, input_t, group=group) 215 | if use_sync: 216 | torch.cuda.synchronize() 217 | else: 218 | output = input_t 219 | 220 | # if scattering the seq-dim, transpose the heads back to the original dimension 221 | output = output.reshape(hc, shard_seqlen, 3, bs, hs) 222 | 223 | # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) 224 | output = output.transpose(0, 3).contiguous() 225 | 226 | return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous() 227 | else: 228 | raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3") 229 | 230 | 231 | class SeqAllToAll5D(torch.autograd.Function): 232 | @staticmethod 233 | def forward( 234 | ctx: Any, 235 | group: dist.ProcessGroup, 236 | input: Tensor, 237 | scatter_idx: int = 3, 238 | gather_idx: int = 1, 239 | use_sync: bool = False, 240 | ) -> Tensor: 241 | 242 | ctx.group = group 243 | ctx.scatter_idx = scatter_idx 244 | ctx.gather_idx = gather_idx 245 | ctx.use_sync = use_sync 246 | 247 | return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) 248 | 249 | @staticmethod 250 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: 251 | return ( 252 | None, 253 | SeqAllToAll5D.apply( 254 | ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync 255 | ), 256 | 
None, 257 | None, 258 | None, 259 | ) 260 | -------------------------------------------------------------------------------- /yunchang/ring/zigzag_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .utils import RingComm, update_out_and_lse 3 | from yunchang.kernels import AttnType, select_flash_attn_impl 4 | 5 | def zigzag_ring_flash_attn_forward( 6 | process_group, 7 | q: torch.Tensor, 8 | k: torch.Tensor, 9 | v: torch.Tensor, 10 | softmax_scale, 11 | dropout_p=0, 12 | causal=True, 13 | window_size=(-1, -1), 14 | softcap=0.0, 15 | alibi_slopes=None, 16 | deterministic=False, 17 | attn_type: AttnType = AttnType.FA, 18 | ): 19 | assert causal == True, "zigzag ring is meaningless for causal=False" 20 | comm = RingComm(process_group) 21 | 22 | block_seq_len = q.shape[1] // 2 23 | q1 = q[:, block_seq_len:] 24 | 25 | out = None 26 | lse = None 27 | next_k, next_v = None, None 28 | 29 | def forward(q, k, v, causal): 30 | fn = select_flash_attn_impl(attn_type, stage="fwd-only") 31 | block_out, block_lse = fn( 32 | q, 33 | k, 34 | v, 35 | dropout_p, 36 | softmax_scale, 37 | causal=causal, 38 | window_size=window_size, 39 | softcap=softcap, 40 | alibi_slopes=alibi_slopes, 41 | return_softmax=True and dropout_p > 0, 42 | ) 43 | return block_out, block_lse 44 | 45 | for step in range(comm.world_size): 46 | if step + 1 != comm.world_size: 47 | next_k: torch.Tensor = comm.send_recv(k) 48 | next_v: torch.Tensor = comm.send_recv(v) 49 | comm.commit() 50 | 51 | if step == 0: 52 | block_out, block_lse = forward(q, k, v, causal=True) 53 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 54 | elif step <= comm.rank: 55 | k0 = k[:, :block_seq_len] 56 | v0 = v[:, :block_seq_len] 57 | block_out, block_lse = forward(q, k0, v0, causal=False) 58 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 59 | else: 60 | block_out, block_lse = forward(q1, k, v, causal=False) 61 | out, lse = update_out_and_lse( 62 | out, 63 | lse, 64 | block_out, 65 | block_lse, 66 | slice_=(slice(None), slice(block_seq_len, None)), 67 | ) 68 | 69 | if step + 1 != comm.world_size: 70 | comm.wait() 71 | k = next_k 72 | v = next_v 73 | 74 | out = out.to(q.dtype) 75 | lse = lse.squeeze(dim=-1).transpose(1, 2) 76 | return out, lse 77 | 78 | 79 | def zigzag_ring_flash_attn_backward( 80 | process_group, 81 | dout, 82 | q, 83 | k, 84 | v, 85 | out, 86 | softmax_lse, 87 | softmax_scale, 88 | dropout_p=0, 89 | causal=True, 90 | window_size=(-1, -1), 91 | softcap=0.0, 92 | alibi_slopes=None, 93 | deterministic=False, 94 | attn_type: AttnType = AttnType.FA, 95 | ): 96 | assert causal == True, "zigzag ring is meaningless for causal=False" 97 | kv_comm = RingComm(process_group) 98 | d_kv_comm = RingComm(process_group) 99 | dq, dk, dv = None, None, None 100 | next_dk, next_dv = None, None 101 | next_k, next_v = None, None 102 | dk_comm_buffer, dv_comm_buffer = None, None 103 | 104 | dout1 = dout.chunk(2, dim=1)[1] 105 | q1 = q.chunk(2, dim=1)[1] 106 | out1 = out.chunk(2, dim=1)[1] 107 | softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() 108 | block_seq_len = q.shape[1] // 2 109 | 110 | # repeatly allocating buffer may be slow... 
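    # Editor's note (not in the original source): the three buffers below are
    # allocated once at the full local block size and re-sliced each step
    # (e.g. dq_buffer[:, :seqlen_q], dk_buffer[:, :seqlen_kv]), since under the
    # zigzag layout some steps attend with only half of the local q or kv chunk.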
111 | dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 112 | dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 113 | dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 114 | 115 | def backward(dout, q, k, v, out, softmax_lse, causal): 116 | seqlen_q = q.shape[1] 117 | seqlen_kv = k.shape[1] 118 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 119 | fn( 120 | dout, 121 | q, 122 | k, 123 | v, 124 | out, 125 | softmax_lse, 126 | dq_buffer[:, :seqlen_q], 127 | dk_buffer[:, :seqlen_kv], 128 | dv_buffer[:, :seqlen_kv], 129 | dropout_p, 130 | softmax_scale, 131 | causal, 132 | window_size, 133 | softcap, 134 | alibi_slopes, 135 | deterministic, 136 | rng_state=None, 137 | ) 138 | 139 | for step in range(kv_comm.world_size): 140 | if step + 1 != kv_comm.world_size: 141 | next_k = kv_comm.send_recv(k) 142 | next_v = kv_comm.send_recv(v) 143 | kv_comm.commit() 144 | 145 | if step == 0: 146 | backward(dout, q, k, v, out, softmax_lse, causal=True) 147 | dq = dq_buffer.to(torch.float32) 148 | dk = dk_buffer.to(torch.float32) 149 | dv = dv_buffer.to(torch.float32) 150 | else: 151 | if step <= kv_comm.rank: 152 | k0 = k[:, :block_seq_len] 153 | v0 = v[:, :block_seq_len] 154 | backward(dout, q, k0, v0, out, softmax_lse, causal=False) 155 | dq += dq_buffer 156 | else: 157 | backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) 158 | # always use the first half in dq_buffer. 159 | dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] 160 | 161 | d_kv_comm.wait() 162 | dk_comm_buffer, dv_comm_buffer = dk, dv 163 | dk, dv = next_dk, next_dv 164 | 165 | if step <= kv_comm.rank: 166 | dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] 167 | dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] 168 | else: 169 | dk += dk_buffer 170 | dv += dv_buffer 171 | 172 | if step + 1 != kv_comm.world_size: 173 | kv_comm.wait() 174 | k = next_k 175 | v = next_v 176 | 177 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) 178 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) 179 | d_kv_comm.commit() 180 | 181 | d_kv_comm.wait() 182 | 183 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 184 | 185 | 186 | class ZigZagRingFlashAttnFunc(torch.autograd.Function): 187 | @staticmethod 188 | def forward( 189 | ctx, 190 | q, 191 | k, 192 | v, 193 | dropout_p, 194 | softmax_scale, 195 | causal, 196 | window_size, 197 | softcap, 198 | alibi_slopes, 199 | deterministic, 200 | return_softmax, 201 | group, 202 | attn_type, 203 | ): 204 | if softmax_scale is None: 205 | softmax_scale = q.shape[-1] ** (-0.5) 206 | 207 | assert alibi_slopes is None 208 | k = k.contiguous() 209 | v = v.contiguous() 210 | out, softmax_lse = zigzag_ring_flash_attn_forward( 211 | group, 212 | q, 213 | k, 214 | v, 215 | softmax_scale=softmax_scale, 216 | dropout_p=dropout_p, 217 | causal=causal, 218 | window_size=window_size, 219 | softcap=softcap, 220 | alibi_slopes=alibi_slopes, 221 | deterministic=False, 222 | attn_type=attn_type, 223 | ) 224 | # this should be out_padded 225 | ctx.save_for_backward(q, k, v, out, softmax_lse) 226 | ctx.dropout_p = dropout_p 227 | ctx.softmax_scale = softmax_scale 228 | ctx.causal = causal 229 | ctx.window_size = window_size 230 | ctx.softcap = softcap 231 | ctx.alibi_slopes = alibi_slopes 232 | ctx.deterministic = deterministic 233 | ctx.group = group 234 | ctx.attn_type = attn_type 235 | return out if not return_softmax else (out, softmax_lse, None) 236 | 237 | @staticmethod 238 | def backward(ctx, dout, *args): 239 | q, k, v, out, softmax_lse 
= ctx.saved_tensors 240 | dq, dk, dv = zigzag_ring_flash_attn_backward( 241 | ctx.group, 242 | dout, 243 | q, 244 | k, 245 | v, 246 | out, 247 | softmax_lse, 248 | softmax_scale=ctx.softmax_scale, 249 | dropout_p=ctx.dropout_p, 250 | causal=ctx.causal, 251 | window_size=ctx.window_size, 252 | softcap=ctx.softcap, 253 | alibi_slopes=ctx.alibi_slopes, 254 | deterministic=ctx.deterministic, 255 | attn_type=ctx.attn_type, 256 | ) 257 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 258 | 259 | 260 | def zigzag_ring_flash_attn_qkvpacked_func( 261 | qkv, 262 | dropout_p=0.0, 263 | softmax_scale=None, 264 | causal=False, 265 | window_size=(-1, -1), 266 | softcap=0.0, 267 | alibi_slopes=None, 268 | deterministic=False, 269 | return_attn_probs=False, 270 | group=None, 271 | attn_type: AttnType = AttnType.FA, 272 | ): 273 | return ZigZagRingFlashAttnFunc.apply( 274 | qkv[:, :, 0], 275 | qkv[:, :, 1], 276 | qkv[:, :, 2], 277 | dropout_p, 278 | softmax_scale, 279 | causal, 280 | window_size, 281 | softcap, 282 | alibi_slopes, 283 | deterministic, 284 | return_attn_probs, 285 | group, 286 | attn_type, 287 | ) 288 | 289 | 290 | def zigzag_ring_flash_attn_kvpacked_func( 291 | q, 292 | kv, 293 | dropout_p=0.0, 294 | softmax_scale=None, 295 | causal=False, 296 | window_size=(-1, -1), 297 | softcap=0.0, 298 | alibi_slopes=None, 299 | deterministic=False, 300 | return_attn_probs=False, 301 | group=None, 302 | attn_type: AttnType = AttnType.FA, 303 | ): 304 | return ZigZagRingFlashAttnFunc.apply( 305 | q, 306 | kv[:, :, 0], 307 | kv[:, :, 1], 308 | dropout_p, 309 | softmax_scale, 310 | causal, 311 | window_size, 312 | softcap, 313 | alibi_slopes, 314 | deterministic, 315 | return_attn_probs, 316 | group, 317 | attn_type, 318 | ) 319 | 320 | 321 | def zigzag_ring_flash_attn_func( 322 | q, 323 | k, 324 | v, 325 | dropout_p=0.0, 326 | softmax_scale=None, 327 | causal=False, 328 | window_size=(-1, -1), 329 | softcap=0.0, 330 | alibi_slopes=None, 331 | deterministic=False, 332 | return_attn_probs=False, 333 | group=None, 334 | attn_type: AttnType = AttnType.FA, 335 | attn_processor=None, 336 | ): 337 | return ZigZagRingFlashAttnFunc.apply( 338 | q, 339 | k, 340 | v, 341 | dropout_p, 342 | softmax_scale, 343 | causal, 344 | window_size, 345 | softcap, 346 | alibi_slopes, 347 | deterministic, 348 | return_attn_probs, 349 | group, 350 | attn_type, 351 | ) 352 | -------------------------------------------------------------------------------- /yunchang/ring/stripe_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from yunchang.kernels import select_flash_attn_impl, AttnType 3 | from .utils import RingComm, update_out_and_lse 4 | 5 | 6 | def stripe_flash_attn_forward( 7 | process_group, 8 | q: torch.Tensor, 9 | k: torch.Tensor, 10 | v: torch.Tensor, 11 | softmax_scale, 12 | dropout_p=0, 13 | causal=True, 14 | window_size=(-1, -1), 15 | softcap=0.0, 16 | alibi_slopes=None, 17 | deterministic=False, 18 | attn_type: AttnType = AttnType.FA, 19 | ): 20 | assert ( 21 | causal 22 | ), "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" 23 | comm = RingComm(process_group) 24 | 25 | out = None 26 | lse = None 27 | 28 | next_k, next_v = None, None 29 | 30 | for step in range(comm.world_size): 31 | if step + 1 != comm.world_size: 32 | next_k: torch.Tensor = comm.send_recv(k) 33 | next_v: torch.Tensor = comm.send_recv(v) 34 | comm.commit() 35 | 36 | if step <= comm.rank: 37 | fn = 
select_flash_attn_impl(attn_type, stage="fwd-only") 38 | block_out, block_lse = fn( 39 | q, 40 | k, 41 | v, 42 | dropout_p, 43 | softmax_scale, 44 | causal=causal, 45 | window_size=window_size, 46 | softcap=softcap, 47 | alibi_slopes=alibi_slopes, 48 | return_softmax=True and dropout_p > 0, 49 | ) 50 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 51 | else: 52 | fn = select_flash_attn_impl(attn_type, stage="fwd-only") 53 | block_out, block_lse = fn( 54 | q[:, 1:], 55 | k[:, :-1], 56 | v[:, :-1], 57 | dropout_p, 58 | softmax_scale, 59 | causal=causal, 60 | window_size=window_size, 61 | softcap=softcap, 62 | alibi_slopes=alibi_slopes, 63 | return_softmax=True and dropout_p > 0, 64 | ) 65 | out, lse = update_out_and_lse( 66 | out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None)) 67 | ) 68 | 69 | if step + 1 != comm.world_size: 70 | comm.wait() 71 | k = next_k 72 | v = next_v 73 | 74 | out = out.to(q.dtype) 75 | lse = lse.squeeze(dim=-1).transpose(1, 2) 76 | return out, lse 77 | 78 | 79 | def stripe_flash_attn_backward( 80 | process_group, 81 | dout, 82 | q, 83 | k, 84 | v, 85 | out, 86 | softmax_lse, 87 | softmax_scale, 88 | dropout_p=0, 89 | causal=True, 90 | window_size=(-1, -1), 91 | softcap=0.0, 92 | alibi_slopes=None, 93 | deterministic=False, 94 | attn_type: AttnType = AttnType.FA, 95 | ): 96 | assert ( 97 | causal 98 | ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" 99 | kv_comm = RingComm(process_group) 100 | d_kv_comm = RingComm(process_group) 101 | dq, dk, dv = None, None, None 102 | next_dk, next_dv = None, None 103 | next_k, next_v = None, None 104 | dk_comm_buffer, dv_comm_buffer = None, None 105 | 106 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 107 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 108 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 109 | for step in range(kv_comm.world_size): 110 | if step + 1 != kv_comm.world_size: 111 | next_k = kv_comm.send_recv(k) 112 | next_v = kv_comm.send_recv(v) 113 | kv_comm.commit() 114 | 115 | shift_causal = step > kv_comm.rank 116 | softmax_lse_1 = None 117 | if not shift_causal: 118 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 119 | fn( 120 | dout, 121 | q, 122 | k, 123 | v, 124 | out, 125 | softmax_lse, 126 | block_dq_buffer, 127 | block_dk_buffer, 128 | block_dv_buffer, 129 | dropout_p, 130 | softmax_scale, 131 | causal, 132 | window_size, 133 | softcap, 134 | alibi_slopes, 135 | deterministic, 136 | rng_state=None, 137 | ) 138 | else: 139 | if softmax_lse_1 is None: 140 | # lazy init, since the last rank does not need softmax_lse_1 141 | softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() 142 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 143 | fn( 144 | dout[:, 1:], 145 | q[:, 1:], 146 | k[:, :-1], 147 | v[:, :-1], 148 | out[:, 1:], 149 | softmax_lse_1, 150 | block_dq_buffer[:, 1:], 151 | block_dk_buffer[:, :-1], 152 | block_dv_buffer[:, :-1], 153 | dropout_p, 154 | softmax_scale, 155 | causal, 156 | window_size, 157 | softcap, 158 | alibi_slopes, 159 | deterministic, 160 | rng_state=None, 161 | ) 162 | 163 | if dq is None: 164 | dq = block_dq_buffer.to(torch.float32) 165 | dk = block_dk_buffer.to(torch.float32) 166 | dv = block_dv_buffer.to(torch.float32) 167 | else: 168 | if not shift_causal: 169 | dq += block_dq_buffer 170 | else: 171 | dq[:, 1:] += block_dq_buffer[:, 1:] 172 | d_kv_comm.wait() 173 | dk_comm_buffer, dv_comm_buffer = dk, dv 174 | dk 
= next_dk 175 | dv = next_dv 176 | 177 | if not shift_causal: 178 | dk = block_dk_buffer + dk 179 | dv = block_dv_buffer + dv 180 | else: 181 | dk[:, :-1] += block_dk_buffer[:, :-1] 182 | dv[:, :-1] += block_dv_buffer[:, :-1] 183 | 184 | if step + 1 != kv_comm.world_size: 185 | kv_comm.wait() 186 | k = next_k 187 | v = next_v 188 | 189 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) 190 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) 191 | d_kv_comm.commit() 192 | 193 | d_kv_comm.wait() 194 | 195 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 196 | 197 | 198 | class StripeFlashAttnFunc(torch.autograd.Function): 199 | @staticmethod 200 | def forward( 201 | ctx, 202 | q, 203 | k, 204 | v, 205 | dropout_p, 206 | softmax_scale, 207 | causal, 208 | window_size, 209 | softcap, 210 | alibi_slopes, 211 | deterministic, 212 | return_softmax, 213 | group, 214 | attn_type: AttnType = AttnType.FA, 215 | ): 216 | if softmax_scale is None: 217 | softmax_scale = q.shape[-1] ** (-0.5) 218 | 219 | assert alibi_slopes is None 220 | k = k.contiguous() 221 | v = v.contiguous() 222 | out, softmax_lse = stripe_flash_attn_forward( 223 | group, 224 | q, 225 | k, 226 | v, 227 | softmax_scale=softmax_scale, 228 | dropout_p=dropout_p, 229 | causal=causal, 230 | window_size=window_size, 231 | softcap=softcap, 232 | alibi_slopes=alibi_slopes, 233 | deterministic=False, 234 | attn_type=attn_type, 235 | ) 236 | # this should be out_padded 237 | ctx.save_for_backward(q, k, v, out, softmax_lse) 238 | ctx.dropout_p = dropout_p 239 | ctx.softmax_scale = softmax_scale 240 | ctx.causal = causal 241 | ctx.window_size = window_size 242 | ctx.softcap = softcap 243 | ctx.alibi_slopes = alibi_slopes 244 | ctx.deterministic = deterministic 245 | ctx.group = group 246 | ctx.attn_type = attn_type 247 | return out if not return_softmax else (out, softmax_lse, None) 248 | 249 | @staticmethod 250 | def backward(ctx, dout, *args): 251 | q, k, v, out, softmax_lse = ctx.saved_tensors 252 | dq, dk, dv = stripe_flash_attn_backward( 253 | ctx.group, 254 | dout, 255 | q, 256 | k, 257 | v, 258 | out, 259 | softmax_lse, 260 | softmax_scale=ctx.softmax_scale, 261 | dropout_p=ctx.dropout_p, 262 | causal=ctx.causal, 263 | window_size=ctx.window_size, 264 | softcap=ctx.softcap, 265 | alibi_slopes=ctx.alibi_slopes, 266 | deterministic=ctx.deterministic, 267 | attn_type=ctx.attn_type, 268 | ) 269 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 270 | 271 | 272 | def stripe_flash_attn_qkvpacked_func( 273 | qkv, 274 | dropout_p=0.0, 275 | softmax_scale=None, 276 | causal=False, 277 | window_size=(-1, -1), # -1 means infinite context window 278 | softcap=0.0, 279 | alibi_slopes=None, 280 | deterministic=False, 281 | return_attn_probs=False, 282 | group=None, 283 | attn_type: AttnType = AttnType.FA, 284 | ): 285 | return StripeFlashAttnFunc.apply( 286 | qkv[:, :, 0], 287 | qkv[:, :, 1], 288 | qkv[:, :, 2], 289 | dropout_p, 290 | softmax_scale, 291 | causal, 292 | window_size, 293 | softcap, 294 | alibi_slopes, 295 | deterministic, 296 | return_attn_probs, 297 | group, 298 | attn_type, 299 | ) 300 | 301 | 302 | def stripe_flash_attn_kvpacked_func( 303 | q, 304 | kv, 305 | dropout_p=0.0, 306 | softmax_scale=None, 307 | causal=False, 308 | window_size=(-1, -1), # -1 means infinite context window 309 | softcap=0.0, 310 | alibi_slopes=None, 311 | deterministic=False, 312 | return_attn_probs=False, 313 | group=None, 314 | attn_type: AttnType = AttnType.FA, 315 | ): 316 | return StripeFlashAttnFunc.apply( 
317 | q, 318 | kv[:, :, 0], 319 | kv[:, :, 1], 320 | dropout_p, 321 | softmax_scale, 322 | causal, 323 | window_size, 324 | softcap, 325 | alibi_slopes, 326 | deterministic, 327 | return_attn_probs, 328 | group, 329 | attn_type, 330 | ) 331 | 332 | 333 | def stripe_flash_attn_func( 334 | q, 335 | k, 336 | v, 337 | dropout_p=0.0, 338 | softmax_scale=None, 339 | causal=False, 340 | window_size=(-1, -1), # -1 means infinite context window 341 | softcap=0.0, 342 | alibi_slopes=None, 343 | deterministic=False, 344 | return_attn_probs=False, 345 | group=None, 346 | attn_type: AttnType = AttnType.FA, 347 | attn_processor=None, 348 | ): 349 | return StripeFlashAttnFunc.apply( 350 | q, 351 | k, 352 | v, 353 | dropout_p, 354 | softmax_scale, 355 | causal, 356 | window_size, 357 | softcap, 358 | alibi_slopes, 359 | deterministic, 360 | return_attn_probs, 361 | group, 362 | attn_type, 363 | ) 364 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [2024] [feifeibear (Jiarui Fang)] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YunChang: A Unified Sequence Parallel (USP) Attention for Long Context LLM Model Training and Inference. 2 | 3 | [\[Tech Report\] USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) 4 | 5 | 6 |

7 | 8 |

9 | 10 | This repo provides a sequence parallel approach that synergizes the strengths of two popular distributed attention mechanisms, i.e. DeepSpeed-Ulysses-Attention and Ring-Attention, offering greater versatility and better performance than either method alone. 11 | The project is built on [zhuzilin/ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) and draws on the [DeepSpeed-Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md) design. 12 | 13 | USP has been applied in [NVIDIA/TransformerEngine](https://github.com/NVIDIA/TransformerEngine/blob/54aa12a9a1f166c53a20f17f309adeab5698f5f6/transformer_engine/pytorch/attention.py#L1542) as `AttnFuncWithCPAndKVP2P`. You can use it through the API `attn_forward_func_with_cp`. 14 | 15 | 16 | ## Why not apply Ulysses and Ring Attention Individually? 17 | 18 | - Ulysses is sensitive to the number of attention heads. 19 | The parallelism degree in Ulysses cannot exceed the number of heads. 20 | Consequently, it is not suitable for GQA (Grouped Query Attention) and MQA (Multi-Query Attention) scenarios. For instance, Ulysses does not operate effectively with a single head. 21 | In addition, since Tensor Parallelism also requires division across the head-number dimension, achieving compatibility between Ulysses and TP can be challenging. 22 | 23 | - Ring-Attention is less efficient than Ulysses in both computation and communication. 24 | Ring-Attention segments the Query, Key, and Value (QKV) into smaller blocks, which can lead to a decrease in efficiency when using FlashAttention. 25 | Even with the communication and computation processes fully overlapped, the total execution time lags behind that of Ulysses. 26 | Furthermore, Ring-Attention utilizes asynchronous peer-to-peer communication, which not only has lower bandwidth utilization compared to collective communication methods but also poses the risk of communication deadlocks in large-scale deployments. 27 | 28 | 29 | ## LongContextAttention, also known as Unified Sequence Parallelism and Hybrid Sequence Parallelism 30 | 31 | `LongContextAttention` is a **unified sequence parallel** (also known as **hybrid sequence parallel**) attention that hybridizes DeepSpeed-Ulysses-Attention and Ring-Attention, thereby addressing the limitations of both methods. A conceptual sketch of the two-level process-group layout is shown below. 32 | 33 |
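To make the hybrid layout concrete, here is a minimal sketch of one way a world of `ulysses_degree * ring_degree` ranks can be split into Ulysses groups (which run All-to-All over the head dimension) and Ring groups (which run P2P over sequence blocks). This is an illustration of the idea only, not yunchang's internal implementation; the helper `build_usp_groups` is hypothetical and introduced purely for exposition.

```python
import torch.distributed as dist

def build_usp_groups(rank, world_size, ulysses_degree, ring_degree):
    """Illustrative 2D decomposition with world_size = ulysses_degree * ring_degree.

    Each rank joins one Ulysses group (contiguous ranks, All-to-All on the head
    dimension) and one Ring group (strided ranks, P2P on the sequence dimension).
    """
    assert world_size == ulysses_degree * ring_degree
    ulysses_group, ring_group = None, None
    # Ulysses groups, e.g. [0,1], [2,3], [4,5], [6,7] for ulysses_degree=2, ring_degree=4.
    for i in range(ring_degree):
        ranks = list(range(i * ulysses_degree, (i + 1) * ulysses_degree))
        group = dist.new_group(ranks)  # every rank must call new_group for every subgroup
        if rank in ranks:
            ulysses_group = group
    # Ring groups, e.g. [0,2,4,6], [1,3,5,7] for the same degrees.
    for i in range(ulysses_degree):
        ranks = list(range(i, world_size, ulysses_degree))
        group = dist.new_group(ranks)
        if rank in ranks:
            ring_group = group
    return ulysses_group, ring_group
```

In practice you do not build these groups yourself: `set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)` creates and stores them (see `yunchang.globals.PROCESS_GROUP`, which the Megatron-DeepSpeed patch below reads), and `LongContextAttention` uses them internally.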

34 | 35 |

36 | 37 | 38 | ### 1. Installation 39 | 40 | FlashAttention is the most important external dependency and is often the cause of errors when installing and using yunchang. 41 | Yunchang supports flash_attn 2.6.x and 2.7.x, covering both the v2 and v3 kernels. Additionally, yunchang supports running without flash_attn, which is suitable for NPUs. 42 | 43 | As shown in the figure below, there are three usage methods, depending on the flash_attn situation: 44 | 45 | 1. For H100, B100, and other hardware that supports FA v3, ring_flash_attn uses FA v3. 46 | 47 | 2. For A100, L40, and other hardware that supports FA v2, ring_flash_attn uses FA v2. 48 | 49 | 3. For hardware such as NPUs that does not support FA, use torch to implement the attention computation. In this case, there is no need to install `flash_attn`, and you should apply `LongContextAttention(ring_impl_type="basic", attn_type=AttnType.TORCH)`. *Note: the backward pass is not supported for AttnType.TORCH.* 50 | 51 | Option 1: pip install 52 | 53 | `pip install flash-attn` 54 | 55 | `pip install yunchang` 56 | 57 | #### Apply FlashAttention V3: since FA V3 is a beta release, you need to install FlashAttention V3 from source. 58 | 59 | Follow the [FlashAttention beta-release](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release) instructions to install V3 for NVIDIA Hopper GPUs. 60 | 61 | We used the Nov 10, 2024 commit `b443207c1fc4c98e4532aad4e88cfee1d590d996`. 62 | 63 | 64 | Option 2: build from the local source. 65 | 66 | `pip install .` 67 | 68 | Installation for AMD GPUs: [install_amd.md](./docs/install_amd.md) 69 | 70 | 71 | ### 2. Usage 72 | 73 | Please refer to [test/test_hybrid_qkvpacked_attn.py](./test/test_hybrid_qkvpacked_attn.py) and [test/test_hybrid_attn.py](./test/test_hybrid_attn.py) for usage. 74 | 75 | In short, we take the `zigzag` ring attention implementation as an example: 76 | 77 | 1. apply `set_seq_parallel_pg` to set up the process groups 78 | 2. extract local tensors with `zigzag_extract_local`. We need to reorder the input tokens or input tensors for load-balanced ring attention. 79 | 3. then apply `LongContextAttention(ring_impl_type="zigzag")` as a drop-in replacement for the attention implementation. 80 | 81 | ```python 82 | from yunchang import ( 83 | AsyncLongContextAttention, 84 | LongContextAttention, 85 | set_seq_parallel_pg, 86 | EXTRACT_FUNC_DICT 87 | ) 88 | from yunchang.kernels import AttnType 89 | 90 | sp_ulysses_degree = 2 91 | sp_ring_degree = 4 92 | 93 | # assumes world_size = sp_ulysses_degree * sp_ring_degree = 8 94 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 95 | 96 | # attn_type could be FA, FA3, or TORCH. 97 | longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.FA) 98 | 99 | # if you use NPUs, where no flash_attn is supported, you can use the following instead: 100 | # LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.TORCH) 101 | 102 | # extract the local shard of the global Q, K, V. 103 | local_q = EXTRACT_FUNC_DICT["zigzag"]( 104 | Q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 105 | ).detach().clone() 106 | ...
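# Illustrative addition, not part of the original snippet: K and V are sharded in
# exactly the same way as Q, so local_k / local_v mirror the local_q extraction.
# The global tensors K and V are assumed to exist alongside Q above.
local_k = EXTRACT_FUNC_DICT["zigzag"](
    K, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()
local_v = EXTRACT_FUNC_DICT["zigzag"](
    V, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()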
107 | 108 | local_out = longctx_attn( 109 | local_q, 110 | local_k, 111 | local_v, 112 | dropout_p=dropout_p, 113 | causal=True, # zigzag and stripe are load-balance strategies for causal=True 114 | window_size=window_size, 115 | softcap=0.0, 116 | alibi_slopes=alibi_slopes, 117 | deterministic=deterministic, 118 | return_attn_probs=True, 119 | ) 120 | 121 | ``` 122 | 123 | ### 3. Test 124 | 125 | If you have not installed yunchang, add the project root directory to the PYTHONPATH: 126 | ```bash 127 | export PYTHONPATH=$PWD:$PYTHONPATH 128 | ``` 129 | 130 | - FlashAttn/Torch Test 131 | ```bash 132 | torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --ring_impl_type "zigzag" --causal --attn_impl fa --use_bwd 133 | torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --ring_impl_type "zigzag" --causal --attn_impl torch 134 | torchrun --nproc_per_node 8 test/test_hybrid_qkvpacked_attn.py 135 | ``` 136 | 137 | - Sage/SpargeAttention Test (fwd only, non-causal) 138 | 139 | You need to install [SpargeAttn](https://github.com/thu-ml/SpargeAttn) and [SageAttention](https://github.com/thu-ml/SageAttention) from source. 140 | 141 | ```bash 142 | torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --attn_impl sage_fp8 143 | ``` 144 | 145 | ```bash 146 | torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 4 --attn_impl sparse_sage --sparse_sage_tune_mode 147 | ``` 148 | 149 | - FlashInfer Test (fwd only) 150 | 151 | Install FlashInfer from [here](https://docs.flashinfer.ai/installation.html#quick-start). 152 | 153 | ```bash 154 | torchrun --nproc_per_node=4 --master_port=1234 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --ring_impl_type 'basic_flashinfer' --attn_impl flashinfer 155 | ``` 156 | 157 | ### 4. Verified in Megatron-LM 158 | The loss curves for Data Parallel (DP) and Unified Sequence Parallel (ulysses=2+ring=2) are closely aligned, as illustrated in the figure. This alignment confirms the correctness of the unified sequence parallelism. 159 | 160 |
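The unit tests above verify correctness at the operator level by comparing each rank's USP output shard against the matching shard of a single-device reference attention. Below is a minimal sketch of that kind of check, assuming the process group is already initialized and that `local_out` and `local_ref` are the distributed output and the correspondingly reordered reference shard on the current rank; the real tests in `test/test_hybrid_attn.py` define their own reference implementation and tolerances.

```python
import torch
import torch.distributed as dist

def check_against_reference(local_out, local_ref, atol=1e-2, rtol=1e-2):
    """Compare a USP output shard with the matching shard of a reference attention.

    Illustrative only: the reference shard is assumed to come from running plain
    attention on the full sequence and slicing it with the same EXTRACT_FUNC_DICT
    reordering used for the inputs.
    """
    max_err = (local_out - local_ref).abs().max()
    # Track the worst-case error across all ranks so rank 0 can report it.
    dist.all_reduce(max_err, op=dist.ReduceOp.MAX)
    if dist.get_rank() == 0:
        print(f"max |usp - reference| = {max_err.item():.3e}")
    assert torch.allclose(local_out, local_ref, atol=atol, rtol=rtol)
```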

161 | 162 |

163 | 164 | When utilizing load-balanced Ring Attention with a causal mask, it is essential to reorder the Query tensors using the functions in [EXTRACT_FUNC_DICT](./yunchang/comm/extract_local.py). 165 | 166 | In Megatron-LM, you can reorder the input tokens before feeding them into the model and apply the same reordering to the RoPE parameters. For detailed instructions, please refer to our paper. 167 | 168 | For an example implementation, you can check out this [PR](https://github.com/FlagOpen/FlagScale/commit/f98ee1e293bd906cc77f512f7a884b2030c10a12), which integrates USP into BAAI's Megatron-LM-based FlagScale framework. 169 | 170 | ### 6. Benchmark 171 | 172 | 173 | ```bash 174 | bash ./scripts/run_qkvpack_compare.sh 175 | ``` 176 | 177 | On an 8xA100 NVLink machine, the benchmark results are as follows: 178 | 179 |

180 | 181 |

182 | 183 | On an 8xL20 PCIe machine and a 4xA100 PCIe machine, the benchmark results are as follows: 184 | 185 |

186 | 187 |

188 | 189 | Some Conclusions: 190 | 191 | 1. If the number of heads is sufficient, Ulysses outperforms Ring-Attention. The All-to-All communication of Ulysses is highly efficient within a single machine, with a very low overhead ratio. In contrast, Ring-Attention splits computation and communication into blocks, which increases the overall computation time; even with complete overlap, it is slower than Ulysses. 192 | 193 | 2. The QKV-packed version (`LongContextAttentionQKVPacked`) is better than the non-packed version (`LongContextAttention`), with the difference becoming more pronounced as the sequence length decreases. MQA and GQA can only use the non-packed version. 194 | 195 | 3. Among the variants of the Ring-Attention implementation, `zigzag` and `stripe` perform better than `basic`. Typically, zigzag is slightly better than stripe, but as the sequence length increases, the difference between zigzag and stripe becomes less noticeable. It is worth noting that both zigzag and stripe have specific layout requirements for the sequence dimension. 196 | 197 | 4. Hybrid parallelism adapts well to heterogeneous network hardware. For example, on an 8-GPU L20 setup, the optimal performance is achieved when ulysses_degree is set to 2 and ring_degree is set to 4. 198 | 199 | ### 7. Best Practice for 4D Parallelism 200 | 201 | We analyze the impact of introducing Sequence Parallelism to Data/ZeRO/Tensor/Pipeline Parallelism in a technical report, which can be found [here](https://arxiv.org/abs/2405.07719). 202 | 203 | Some best practices are listed here: 204 | 205 | 1. We suggest using Unified-SP in place of SP-Ring and SP-Ulysses, as it encompasses the capabilities of both while offering additional benefits. 206 | 207 | 2. DP (data parallelism) vs SP: we suggest prioritizing the use of DP over SP if possible. 208 | Only when the batch size (bs) is insufficient for partitioning should one consider whether to employ SP. 209 | 210 | 3. When utilizing SP, it should always be used in conjunction with ZeRO-1/2. 211 | 212 | 4. Unified-SP has a lower communication cost than Tensor Parallelism with Megatron-LM sequence parallelism (TP-sp), so you can use Unified-SP to replace TP for better speed. However, switching from TP (tensor parallelism) to SP+ZeRO2 currently cannot increase the trainable sequence length; SP+ZeRO3 can train a similar sequence length as TP-sp. We suggest that SP may have an advantage over TP in terms of communication cost when employing GQA, as GQA reduces the communication cost of SP without affecting TP. 213 | 214 | 5. SP can be set to a higher parallel degree than TP, which may require a large ring degree when the number of heads is limited, in order to train long sequences across a greater number of computational devices; TP cannot be scaled up in the same way. 215 | 216 | 217 | 218 | ### 8. Projects applying USP 219 | I am honored that this repository has contributed to the following projects: 220 | 221 | 1. [xdit-project/xDiT](https://github.com/xdit-project/xDiT) 222 | 2. [NVlabs/VILA](https://github.com/NVlabs/VILA/blob/main/LongVILA.md) 223 | 3. [feifeibear/Odysseus-Transformer](https://github.com/feifeibear/Odysseus-Transformer) 224 | 4. [Ascend/AscendSpeed](https://gitee.com/ascend/AscendSpeed/blob/master/docs/features/hybrid-context-parallel.md) 225 | 5. [jzhang38/EasyContext](https://github.com/jzhang38/EasyContext) 226 | 6. [FlagOpen/FlagScale](https://github.com/FlagOpen/FlagScale/commit/f98ee1e293bd906cc77f512f7a884b2030c10a12) 227 | 7. [zhiyuanhubj/LongRecipe](https://github.com/zhiyuanhubj/LongRecipe) 228 | 8. 
[NVIDIA/TransformerEngine](https://github.com/NVIDIA/TransformerEngine/blob/54aa12a9a1f166c53a20f17f309adeab5698f5f6/transformer_engine/pytorch/attention.py#L1542) 229 | 9. [xdit-project/mochi-xdit](https://github.com/xdit-project/mochi-xdit) 230 | 231 | ### 9. Cite Us 232 | 233 | [USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) 234 | 235 | ``` 236 | @article{fang2024unified, 237 | title={A Unified Sequence Parallelism Approach for Long Context Generative AI}, 238 | author={Fang, Jiarui and Zhao, Shangchun}, 239 | journal={arXiv preprint arXiv:2405.07719}, 240 | year={2024} 241 | } 242 | 243 | ``` 244 | -------------------------------------------------------------------------------- /patches/Megatron-DeepSpeed.patch: -------------------------------------------------------------------------------- 1 | From 01eb56347633f5f016ed9c9aa62b3e49d7cd37fa Mon Sep 17 00:00:00 2001 2 | From: root 3 | Date: Fri, 19 Apr 2024 06:19:51 +0000 4 | Subject: [PATCH 1/2] [cp] add hybrid context parallel 5 | 6 | --- 7 | megatron/arguments.py | 2 + 8 | megatron/core/parallel_state.py | 20 ++++ 9 | megatron/initialize.py | 3 +- 10 | megatron/model/transformer.py | 14 ++- 11 | start_gpt.sh | 176 ++++++++++++++++++++++++++++++++ 12 | 5 files changed, 211 insertions(+), 4 deletions(-) 13 | create mode 100755 start_gpt.sh 14 | 15 | diff --git a/megatron/arguments.py b/megatron/arguments.py 16 | index 631d4b1..a91db90 100644 17 | --- a/megatron/arguments.py 18 | +++ b/megatron/arguments.py 19 | @@ -951,6 +951,8 @@ def _add_training_args(parser): 20 | help='Enable Megatron-LM\'s sequence parallel optimization.') 21 | group.add_argument('--ds-sequence-parallel-size', type=int, default=1, 22 | help='Enable DeepSpeed\'s sequence parallel. Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.') 23 | + group.add_argument('--ds-ring-sequence-parallel-size', type=int, default=1, 24 | + help='Ring sequenceparallel degree.') 25 | group.add_argument('--force-ds-sequence-parallel', action='store_true', 26 | help='use DeepSpeed sequence parallelism regardless of sequence parallel size.') 27 | group.add_argument('--no-gradient-accumulation-fusion', 28 | diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py 29 | index 819760e..2f2aad9 100644 30 | --- a/megatron/core/parallel_state.py 31 | +++ b/megatron/core/parallel_state.py 32 | @@ -7,6 +7,12 @@ from typing import Optional 33 | 34 | from .utils import GlobalMemoryBuffer 35 | 36 | +try: 37 | + from yunchang import set_seq_parallel_pg 38 | + from yunchang.globals import PROCESS_GROUP as YUNCHANG_PROCESS_GROUP 39 | +except ImportError: 40 | + set_seq_parallel_pg = None 41 | + 42 | # Intra-layer model parallel group that the current rank belongs to. 43 | _TENSOR_MODEL_PARALLEL_GROUP = None 44 | # Inter-layer model parallel group that the current rank belongs to. 45 | @@ -70,6 +76,7 @@ def initialize_model_parallel( 46 | pipeline_model_parallel_split_rank: Optional[int] = None, 47 | use_fp8: bool = False, 48 | use_distributed_optimizer: bool = False, 49 | + ring_parallel_size: int =1, 50 | ) -> None: 51 | """Initialize model data parallel groups. 
52 | 53 | @@ -213,6 +220,15 @@ def initialize_model_parallel( 54 | if rank in ranks: 55 | _SEQUENCE_PARALLEL_GROUP = group 56 | 57 | + ring_degree = ring_parallel_size 58 | + ulysse_degree = sequence_parallel_size // ring_parallel_size 59 | + assert sequence_parallel_size % ulysse_degree == 0, f"sequence_parallel_size {sequence_parallel_size} is not divisible by ulysse_degree {ulysse_degree}" 60 | + assert sequence_parallel_size == ring_degree * ulysse_degree, f"sequence_parallel_size {sequence_parallel_size} is not equal to ring_degree {ring_degree} * ulysse_degree {ulysse_degree}" 61 | + if set_seq_parallel_pg is not None: 62 | + set_seq_parallel_pg(ulysse_degree, ring_degree, rank, world_size) 63 | + else: 64 | + print("set_seq_parallel_pg is not available") 65 | + 66 | # Build the sequence data parallel groups. 67 | global _SEQUENCE_DATA_PARALLEL_GROUP 68 | assert _SEQUENCE_DATA_PARALLEL_GROUP is None, \ 69 | @@ -445,6 +461,10 @@ def get_model_parallel_world_size(): 70 | assert get_pipeline_model_parallel_world_size() == 1, "legacy get_model_parallel_world_size is only supported if PP is disabled" 71 | return get_tensor_model_parallel_world_size() 72 | 73 | +def get_ulysses_sequence_parallel_world_size(): 74 | + """Return world size for the ulysses sequence parallel group.""" 75 | + return torch.distributed.get_world_size(group=YUNCHANG_PROCESS_GROUP.ULYSSES_PG) 76 | + 77 | def get_sequence_parallel_world_size(): 78 | """Return world size for the sequence parallel group.""" 79 | global _SEQUENCE_PARALLEL_WORLD_SIZE 80 | diff --git a/megatron/initialize.py b/megatron/initialize.py 81 | index 31f26c5..8b021be 100644 82 | --- a/megatron/initialize.py 83 | +++ b/megatron/initialize.py 84 | @@ -244,7 +244,8 @@ def _initialize_distributed(): 85 | args.ds_sequence_parallel_size, 86 | args.virtual_pipeline_model_parallel_size, 87 | args.pipeline_model_parallel_split_rank, 88 | - use_distributed_optimizer=args.use_distributed_optimizer) 89 | + use_distributed_optimizer=args.use_distributed_optimizer, 90 | + ring_parallel_size = args.ds_ring_sequence_parallel_size) 91 | if args.rank == 0: 92 | print(f'> initialized tensor model parallel with size ' 93 | f'{mpu.get_tensor_model_parallel_world_size()}') 94 | diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py 95 | index e75f13a..bcac2fb 100644 96 | --- a/megatron/model/transformer.py 97 | +++ b/megatron/model/transformer.py 98 | @@ -50,6 +50,10 @@ except ImportError: 99 | FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder") 100 | flash_attn_builder = None 101 | 102 | +try: 103 | + from yunchang import UlyssesAttention, LongContextAttention, set_seq_parallel_pg 104 | +except ImportError: 105 | + UlyssesAttention = None 106 | 107 | """ We use the following notation throughout this file: 108 | h: hidden size 109 | @@ -597,8 +601,11 @@ class ParallelAttention(MegatronModule): 110 | or args.force_ds_sequence_parallel 111 | if self.enable_ds_sequence_parallel: 112 | assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' 113 | - assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 114 | - self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group()) 115 | + # assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 116 | + # self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group()) 117 | + assert args.num_attention_heads % 
parallel_state.get_ulysses_sequence_parallel_world_size() == 0, \ 118 | + f"Number of attention heads {args.num_attention_heads} must be divisible by the number of Ulysses sequence parallel partitions {parallel_state.get_ulysses_sequence_parallel_world_size()}" 119 | + self.dist_attn = LongContextAttention() 120 | else: 121 | if self.use_flash_attn: 122 | self.core_attention_flash = local_attn 123 | @@ -616,7 +623,6 @@ class ParallelAttention(MegatronModule): 124 | input_is_parallel=True, 125 | skip_bias_add=True) 126 | 127 | - 128 | def _checkpointed_attention_forward(self, query_layer, key_layer, 129 | value_layer, attention_mask, 130 | rotary_pos_emb=None): 131 | @@ -808,11 +814,13 @@ class ParallelAttention(MegatronModule): 132 | query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous() 133 | for x in (query_layer, key_layer, value_layer)] 134 | 135 | + # print(f"fjr-debug use fa query_layer {query_layer.shape}") 136 | context_layer = self.dist_attn(query_layer, key_layer, value_layer) 137 | 138 | if not self.use_flash_attn_triton: 139 | context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() 140 | else: 141 | + # print(f"fjr-debug not use fa query_layer {query_layer.shape}") 142 | context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask) 143 | else: 144 | if self.use_flash_attn: 145 | diff --git a/start_gpt.sh b/start_gpt.sh 146 | new file mode 100755 147 | index 0000000..17efb7e 148 | --- /dev/null 149 | +++ b/start_gpt.sh 150 | @@ -0,0 +1,176 @@ 151 | +#! /bin/bash 152 | + 153 | +#################################################### 154 | +# 155 | +# usage: 156 | +# bash start.sh 157 | +# 158 | +# supported model size: {7, 13, 175} 159 | +# 160 | +#################################################### 161 | + 162 | + 163 | +# env var 164 | +export CUDA_DEVICE_MAX_CONNECTIONS=1 165 | + 166 | +# nccl settings 167 | +#export NCCL_DEBUG=INFO 168 | +export NCCL_SOCKET_IFNAME=eth0 169 | +export NCCL_IB_GID_INDEX=3 170 | +export NCCL_IB_DISABLE=0 171 | +export NCCL_NET_GDR_LEVEL=2 172 | +export NCCL_IB_QPS_PER_CONNECTION=4 173 | +export NCCL_IB_TC=160 174 | +export NCCL_IB_TIMEOUT=22 175 | + 176 | +export GLOO_SOCKET_IFNAME=eth0 177 | + 178 | +export PYTHONPATH=$PWD:$PYTHONPATH 179 | + 180 | +# data settings 181 | +BASE_DATA_PATH=/data/datasets/gpt-data/ 182 | +DATA_PATH=$BASE_DATA_PATH/my-gpt2_text_document 183 | +VOCAB_FILE=$BASE_DATA_PATH/gpt2-vocab.json 184 | +MERGE_FILE=$BASE_DATA_PATH/gpt2-merges.txt 185 | +CHECKPOINT_PATH=./output/ 186 | + 187 | + 188 | +ZERO_STAGE=3 189 | + 190 | +# create DS config 191 | +DS_CONFIG=ds_config.json 192 | +DATA_TYPE= 193 | +if [ ${ZERO_STAGE} -eq 1 ]; then 194 | + DATA_TYPE=" 195 | + \"data_types\":{ 196 | + \"grad_accum_dtype\":\"fp32\" 197 | + }, 198 | + " 199 | +fi 200 | + 201 | + 202 | +# model settings 203 | +SEQ_LEN=8192 204 | +MAX_SEQ_LEN=8192 205 | +MODEL_SIZE=${1:-7} 206 | +if [ $MODEL_SIZE == "7" ]; then 207 | + NUM_LAYERS=32 208 | + HIDDEN_SIZE=4096 209 | + NUM_ATTN_HEADS=32 210 | + MICRO_BATCH_SIZE=1 211 | + TP=1 212 | + PP=1 213 | + CP=4 214 | + RCP=2 215 | + MICRO_BATCH_NUM=32 216 | +elif [ $MODEL_SIZE == "13" ]; then 217 | + NUM_LAYERS=40 218 | + HIDDEN_SIZE=5120 219 | + NUM_ATTN_HEADS=40 220 | + MICRO_BATCH_SIZE=1 221 | + TP=1 222 | + PP=2 223 | + MICRO_BATCH_NUM=64 224 | +elif [ $MODEL_SIZE == "175" ]; then 225 | + NUM_LAYERS=96 226 | + HIDDEN_SIZE=12288 227 | + NUM_ATTN_HEADS=96 228 | + MICRO_BATCH_SIZE=1 229 | + TP=8 230 | + PP=4 231 | + MICRO_BATCH_NUM=256 
232 | +else 233 | + echo "ERROR: Please supplement new model configuration to test!" 234 | + exit -1 235 | +fi 236 | + 237 | +#fp8 settings 238 | +ENABLE_FP8=false 239 | +if [ $ENABLE_FP8 == "true" ]; then 240 | + FP8_OPTS="--transformer-impl transformer_engine --fp8-format hybrid " 241 | + DT="fp8" 242 | +else 243 | + FP8_OPTS="" 244 | + DT="bf16" 245 | +fi 246 | + 247 | +# node settings 248 | +MASTER_ADDR=${2:-localhost} 249 | +MASTER_PORT=6000 250 | +NNODES=${3:-1} 251 | +NODE_RANK=${4:-0} 252 | +GPUS_PER_NODE=8 253 | +WORLD_SIZE=$(( $GPUS_PER_NODE * $NNODES )) 254 | +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" 255 | + 256 | +DP=$(( $WORLD_SIZE / $TP / $PP / $CP)) 257 | +GLOBAL_BATCH_SIZE=$(( $DP * $MICRO_BATCH_SIZE * $MICRO_BATCH_NUM )) 258 | + 259 | + 260 | +cat << EOT > $DS_CONFIG 261 | +{ 262 | + "train_batch_size" : $GLOBAL_BATCH_SIZE, 263 | + "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE, 264 | + "steps_per_print": 1, 265 | + "gradient_clipping": 1.0, 266 | + "zero_optimization": { 267 | + "stage": $ZERO_STAGE 268 | + }, 269 | + "bf16": { 270 | + "enabled": true, 271 | + "accumulate_grads_via_hooks": true 272 | + }, 273 | + "fp16": {"enabled": false}, 274 | + "wall_clock_breakdown": false 275 | +} 276 | +EOT 277 | + 278 | + 279 | + 280 | +CMD="torchrun $DISTRIBUTED_ARGS \ 281 | + pretrain_gpt.py \ 282 | + --tensor-model-parallel-size $TP \ 283 | + --pipeline-model-parallel-size $PP \ 284 | + --ds-sequence-parallel-size $CP \ 285 | + --ds-ring-sequence-parallel-size $RCP \ 286 | + --num-layers $NUM_LAYERS \ 287 | + --hidden-size $HIDDEN_SIZE \ 288 | + --num-attention-heads $NUM_ATTN_HEADS \ 289 | + --micro-batch-size $MICRO_BATCH_SIZE \ 290 | + --global-batch-size $GLOBAL_BATCH_SIZE \ 291 | + --seq-length $SEQ_LEN \ 292 | + --max-position-embeddings $SEQ_LEN \ 293 | + --train-iters 500 \ 294 | + --lr-decay-iters 320000 \ 295 | + --save $CHECKPOINT_PATH \ 296 | + --data-path $DATA_PATH \ 297 | + --vocab-file $VOCAB_FILE \ 298 | + --merge-file $MERGE_FILE \ 299 | + --split 949,50,1 \ 300 | + --distributed-backend nccl \ 301 | + --lr 0.00015 \ 302 | + --lr-decay-style cosine \ 303 | + --min-lr 1.0e-5 \ 304 | + --weight-decay 1e-2 \ 305 | + --clip-grad 1.0 \ 306 | + --lr-warmup-fraction .01 \ 307 | + --log-interval 1 \ 308 | + --save-interval 10000 \ 309 | + --eval-interval 10000 \ 310 | + --exit-interval 10000 \ 311 | + --eval-iters 1000 \ 312 | + --use-flash-attn-v2 \ 313 | + --recompute-activations \ 314 | + --use-distributed-optimizer \ 315 | + --bf16 \ 316 | + $FP8_OPTS \ 317 | + --deepspeed \ 318 | + --deepspeed_config $DS_CONFIG \ 319 | + --zero-stage=$ZERO_STAGE \ 320 | + --no-pipeline-parallel \ 321 | + " 322 | + 323 | +echo ${CMD} 2>&1 | tee megatron_gpt-${MODEL_SIZE}B_tp${TP}_pp${PP}_dp${DP}_mb${MICRO_BATCH_SIZE}_gb${GLOBAL_BATCH_SIZE}_${DT}.log 324 | +eval ${CMD} 2>&1 | tee -a megatron_gpt-${MODEL_SIZE}B_tp${TP}_pp${PP}_dp${DP}_mb${MICRO_BATCH_SIZE}_gb${GLOBAL_BATCH_SIZE}_${DT}.log 325 | -- 326 | 2.34.1 327 | 328 | --------------------------------------------------------------------------------