├── yunchang ├── comm │ ├── __init__.py │ ├── extract_local.py │ └── all_to_all.py ├── ulysses │ ├── __init__.py │ └── attn_layer.py ├── __init__.py ├── hybrid │ ├── __init__.py │ ├── utils.py │ ├── async_attn_layer.py │ └── attn_layer.py ├── ring │ ├── __init__.py │ ├── triton_utils.py │ ├── utils.py │ ├── ring_pytorch_attn.py │ ├── ring_npu_flash_attn.py │ ├── ring_flashinfer_attn.py │ ├── ring_flash_attn.py │ ├── ring_flash_attn_varlen.py │ ├── zigzag_ring_flash_attn.py │ └── stripe_flash_attn.py ├── globals.py └── kernels │ └── __init__.py ├── media ├── gqa.png ├── loss.png ├── ring.png ├── usp.png ├── ulysses.png ├── usp_fa.png ├── yun_chang.jpg ├── long_ctx_h2.png ├── long_ctx_h8.png ├── pcie_machine.jpg └── benchmark_results.png ├── scripts ├── run_npu.sh ├── run_hybrid_npu.sh ├── run_dit.sh ├── run_gqa.sh └── run_qkvpack_compare.sh ├── .pre-commit-config.yaml ├── pyproject.toml ├── .github └── workflows │ └── python-publish.yml ├── docs └── install_amd.md ├── .gitignore ├── test ├── test_ulysses_attn_npu.py ├── test_ulysses_attn.py ├── test_utils.py ├── test_hybrid_qkvpacked_attn.py └── test_hybrid_attn_npu.py ├── benchmark ├── benchmark_longctx_qkvpacked.py └── benchmark_longctx.py ├── LICENSE.txt ├── README.md └── patches └── Megatron-DeepSpeed.patch /yunchang/comm/__init__.py: -------------------------------------------------------------------------------- 1 | from .all_to_all import * 2 | from .extract_local import * 3 | -------------------------------------------------------------------------------- /media/gqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/gqa.png -------------------------------------------------------------------------------- /media/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/loss.png -------------------------------------------------------------------------------- /media/ring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/ring.png -------------------------------------------------------------------------------- /media/usp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/usp.png -------------------------------------------------------------------------------- /media/ulysses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/ulysses.png -------------------------------------------------------------------------------- /media/usp_fa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/usp_fa.png -------------------------------------------------------------------------------- /media/yun_chang.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/yun_chang.jpg -------------------------------------------------------------------------------- /media/long_ctx_h2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/long_ctx_h2.png -------------------------------------------------------------------------------- /media/long_ctx_h8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/long_ctx_h8.png -------------------------------------------------------------------------------- /media/pcie_machine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/pcie_machine.jpg -------------------------------------------------------------------------------- /yunchang/ulysses/__init__.py: -------------------------------------------------------------------------------- 1 | from .attn_layer import UlyssesAttention 2 | 3 | __all__ = ['UlyssesAttention'] 4 | -------------------------------------------------------------------------------- /media/benchmark_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feifeibear/long-context-attention/HEAD/media/benchmark_results.png -------------------------------------------------------------------------------- /scripts/run_npu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash 3 | export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 4 | 5 | export MASTER_ADDR=127.0.0.1 6 | export MASTER_PORT=21289 7 | export HCCL_IF_BASE_PORT=64199 8 | 9 | torchrun --master_port=29600 --nproc_per_node 4 test/test_ulysses_attn_npu.py 10 | -------------------------------------------------------------------------------- /yunchang/__init__.py: -------------------------------------------------------------------------------- 1 | from .hybrid import * 2 | from .ring import * 3 | from .ulysses import * 4 | from .globals import set_seq_parallel_pg 5 | from .comm.extract_local import ( 6 | stripe_extract_local, 7 | basic_extract_local, 8 | zigzag_extract_local, 9 | EXTRACT_FUNC_DICT, 10 | ) 11 | 12 | __version__ = "0.6.3.post1" 13 | -------------------------------------------------------------------------------- /scripts/run_hybrid_npu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash 3 | export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 4 | 5 | export MASTER_ADDR=127.0.0.1 6 | export MASTER_PORT=21289 7 | export HCCL_IF_BASE_PORT=64199 8 | 9 | torchrun --master_port=29600 --nproc_per_node 4 test/test_hybrid_attn_npu.py --use_bwd 10 | -------------------------------------------------------------------------------- /yunchang/hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .attn_layer import LongContextAttention, LongContextAttentionQKVPacked 2 | from .async_attn_layer import AsyncLongContextAttention 3 | 4 | from .utils import RING_IMPL_QKVPACKED_DICT 5 | 6 | __all__ = [ 7 | "LongContextAttention", 8 | "LongContextAttentionQKVPacked", 9 | "RING_IMPL_QKVPACKED_DICT", 10 | "AsyncLongContextAttention", 11 | ] 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Using this mirror lets us use mypyc-compiled black, 
which is about 2x faster 3 | - repo: https://github.com/psf/black-pre-commit-mirror 4 | rev: 24.2.0 5 | hooks: 6 | - id: black 7 | # It is recommended to specify the latest version of Python 8 | # supported by your project here, or alternatively use 9 | # pre-commit's default_language_version, see 10 | # https://pre-commit.com/#top_level-default_language_version 11 | language_version: python3.10 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "yunchang" 7 | version = "0.6.3.post1" 8 | authors = [ 9 | { name="Jiarui Fang", email="fangjiarui123@gmail.com" }, 10 | ] 11 | description = "a package for long context attention" 12 | readme = "README.md" 13 | requires-python = ">=3.7" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Operating System :: OS Independent", 18 | ] 19 | dependencies = [] 20 | 21 | [project.optional-dependencies] 22 | flash = ["flash-attn>=2.6.0"] 23 | 24 | [project.urls] 25 | "Homepage" = "https://github.com/feifeibear/long-context-attention" 26 | "Bug Tracker" = "https://github.com/feifeibear/long-context-attention/issues" 27 | 28 | [tool.setuptools] 29 | packages = {find = {where = ["."], exclude = ["tests*", "benchmark*", "media*", "docs*", "patches*"]}} 30 | -------------------------------------------------------------------------------- /yunchang/hybrid/utils.py: -------------------------------------------------------------------------------- 1 | from yunchang.ring import ( 2 | ring_flash_attn_func, 3 | ring_flash_attn_qkvpacked_func, 4 | zigzag_ring_flash_attn_func, 5 | zigzag_ring_flash_attn_qkvpacked_func, 6 | stripe_flash_attn_func, 7 | stripe_flash_attn_qkvpacked_func, 8 | ring_pytorch_attn_func, 9 | ring_flashinfer_attn_func, 10 | ring_flashinfer_attn_qkvpacked_func, 11 | ring_npu_flash_attn_func 12 | ) 13 | 14 | RING_IMPL_DICT = { 15 | "basic": ring_flash_attn_func, 16 | "zigzag": zigzag_ring_flash_attn_func, 17 | "strip": stripe_flash_attn_func, 18 | "basic_pytorch": ring_pytorch_attn_func, 19 | "basic_flashinfer": ring_flashinfer_attn_func, 20 | "basic_npu": ring_npu_flash_attn_func 21 | } 22 | 23 | RING_IMPL_QKVPACKED_DICT = { 24 | "basic": ring_flash_attn_qkvpacked_func, 25 | "zigzag": zigzag_ring_flash_attn_qkvpacked_func, 26 | "strip": stripe_flash_attn_qkvpacked_func, 27 | "basic_flashinfer": ring_flashinfer_attn_qkvpacked_func, 28 | } 29 | -------------------------------------------------------------------------------- /yunchang/ring/__init__.py: -------------------------------------------------------------------------------- 1 | from .ring_flash_attn import ( 2 | ring_flash_attn_func, 3 | ring_flash_attn_kvpacked_func, 4 | ring_flash_attn_qkvpacked_func, 5 | ) 6 | from .ring_flash_attn_varlen import ( 7 | ring_flash_attn_varlen_func, 8 | ring_flash_attn_varlen_kvpacked_func, 9 | ring_flash_attn_varlen_qkvpacked_func, 10 | ) 11 | from .zigzag_ring_flash_attn import ( 12 | zigzag_ring_flash_attn_func, 13 | zigzag_ring_flash_attn_kvpacked_func, 14 | zigzag_ring_flash_attn_qkvpacked_func, 15 | ) 16 | from .zigzag_ring_flash_attn_varlen import ( 17 | zigzag_ring_flash_attn_varlen_func, 18 | zigzag_ring_flash_attn_varlen_qkvpacked_func, 19 | zigzag_ring_flash_attn_varlen_qkvpacked_func, 20 | ) 21 | from 
.stripe_flash_attn import ( 22 | stripe_flash_attn_func, 23 | stripe_flash_attn_kvpacked_func, 24 | stripe_flash_attn_qkvpacked_func, 25 | ) 26 | 27 | from .ring_pytorch_attn import ( 28 | ring_pytorch_attn_func, 29 | ) 30 | 31 | from .ring_flashinfer_attn import ( 32 | ring_flashinfer_attn_func, 33 | ring_flashinfer_attn_kvpacked_func, 34 | ring_flashinfer_attn_qkvpacked_func, 35 | ) 36 | 37 | from .ring_npu_flash_attn import ( 38 | ring_npu_flash_attn_func, 39 | ) -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build setuptools wheel 33 | - name: Build package 34 | run: | 35 | python -m build 36 | - name: Publish package 37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /scripts/run_dit.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PWD:$PYTHONPATH # export NCCL_PXN_DISABLE=1 2 | # export NCCL_DEBUG=INFO 3 | # export NCCL_SOCKET_IFNAME=eth0 4 | # export NCCL_IB_GID_INDEX=3 5 | # export NCCL_IB_DISABLE=0 6 | # export NCCL_NET_GDR_LEVEL=2 7 | # export NCCL_IB_QPS_PER_CONNECTION=4 8 | # export NCCL_IB_TC=160 9 | # export NCCL_IB_TIMEOUT=22 10 | 11 | export CUDA_DEVICE_MAX_CONNECTIONS=1 12 | 13 | # nccl settings 14 | #export NCCL_DEBUG=INFO 15 | export NCCL_SOCKET_IFNAME=eth0 16 | export NCCL_IB_GID_INDEX=3 17 | export NCCL_IB_DISABLE=0 18 | export NCCL_NET_GDR_LEVEL=2 19 | export NCCL_IB_QPS_PER_CONNECTION=4 20 | export NCCL_IB_TC=160 21 | export NCCL_IB_TIMEOUT=22 22 | 23 | export GLOO_SOCKET_IFNAME=eth0 24 | 25 | 26 | # comment this line fwd+bwd 27 | # FWD_FLAG="--fwd_only" 28 | 29 | NHEADS=24 30 | SEQLEN=1024 31 | GROUP_NUM=1 32 | GPU_NUM=2 33 | HEAD_SIZE=128 34 | ULYSSES_DEGREE=8 35 | 36 | NRANK=${NRANK:-0} 37 | # RING_IMPL_TYPE="zigzag" 38 | 39 | # make sure NHEADS // GROUP_NUM % ULYSSES_DEGREE == 0 40 | for attn_type in "torch" "fa" "fa3"; do 41 | for ULYSSES_DEGREE in 2; do 42 | for RING_IMPL_TYPE in "basic"; do 43 | torchrun --nproc_per_node $GPU_NUM --node_rank $NRANK benchmark/benchmark_longctx.py \ 44 | --nheads $NHEADS --group_num $GROUP_NUM --batch_size 1 $FWD_FLAG --seq_len $SEQLEN --head_size $HEAD_SIZE \ 45 | --ulysses_degree $ULYSSES_DEGREE --ring_impl_type $RING_IMPL_TYPE --no_causal --attn_type $attn_type --use_ulysses 46 | done 47 | done 48 | done 49 | 
-------------------------------------------------------------------------------- /scripts/run_gqa.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PWD:$PYTHONPATH 2 | # export NCCL_PXN_DISABLE=1 3 | # export NCCL_DEBUG=INFO 4 | # export NCCL_SOCKET_IFNAME=eth0 5 | # export NCCL_IB_GID_INDEX=3 6 | # export NCCL_IB_DISABLE=0 7 | # export NCCL_NET_GDR_LEVEL=2 8 | # export NCCL_IB_QPS_PER_CONNECTION=4 9 | # export NCCL_IB_TC=160 10 | # export NCCL_IB_TIMEOUT=22 11 | 12 | export CUDA_DEVICE_MAX_CONNECTIONS=1 13 | 14 | # nccl settings 15 | #export NCCL_DEBUG=INFO 16 | export NCCL_SOCKET_IFNAME=eth0 17 | export NCCL_IB_GID_INDEX=3 18 | export NCCL_IB_DISABLE=0 19 | export NCCL_NET_GDR_LEVEL=2 20 | export NCCL_IB_QPS_PER_CONNECTION=4 21 | export NCCL_IB_TC=160 22 | export NCCL_IB_TIMEOUT=22 23 | 24 | export GLOO_SOCKET_IFNAME=eth0 25 | 26 | 27 | # comment this line fwd+bwd 28 | # FWD_FLAG="--fwd_only" 29 | 30 | NHEADS=64 31 | SEQLEN=131072 32 | GROUP_NUM=8 33 | GPU_NUM=8 34 | ULYSSES_DEGREE=1 35 | 36 | NRANK=${NRANK:-0} 37 | # RING_IMPL_TYPE="zigzag" 38 | 39 | # make sure NHEADS // GROUP_NUM % ULYSSES_DEGREE == 0 40 | for ULYSSES_DEGREE in 8 4 2 1; do 41 | for RING_IMPL_TYPE in "zigzag"; do 42 | torchrun --nproc_per_node $GPU_NUM --node_rank $NRANK benchmark/benchmark_longctx.py --nheads $NHEADS --group_num $GROUP_NUM --batch_size 1 $FWD_FLAG --seq_len $SEQLEN --ulysses_degree $ULYSSES_DEGREE --ring_impl_type $RING_IMPL_TYPE 43 | done 44 | done 45 | 46 | torchrun --nproc_per_node $GPU_NUM --node_rank $NRANK benchmark/benchmark_ring_func.py --nheads $NHEADS --group_num $GROUP_NUM --batch_size 1 $FWD_FLAG --seq_len $SEQLEN 47 | 48 | -------------------------------------------------------------------------------- /scripts/run_qkvpack_compare.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PWD:$PYTHONPATH 2 | 3 | export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" 4 | # export NCCL_PXN_DISABLE=1 5 | # export NCCL_DEBUG=INFO 6 | # export NCCL_SOCKET_IFNAME=eth0 7 | # export NCCL_IB_GID_INDEX=3 8 | # export NCCL_IB_DISABLE=0 9 | # export NCCL_NET_GDR_LEVEL=2 10 | # export NCCL_IB_QPS_PER_CONNECTION=4 11 | # export NCCL_IB_TC=160 12 | # export NCCL_IB_TIMEOUT=22 13 | # export NCCL_P2P=0 14 | 15 | # torchrun --nproc_per_node 8 test/test_hybrid_attn.py 16 | 17 | FWD_FLAG="--fwd_only" 18 | 19 | 20 | 21 | # SEQLEN=512 22 | # SEQLEN=1024 23 | # SEQLEN=4096 24 | # SEQLEN=512 25 | # SEQLEN=16384 26 | SEQLEN=32768 #128K 27 | 28 | NHEADS=32 29 | 30 | # HEAD_SIZE=128 31 | HEAD_SIZE=32 32 | GROUP_NUM=4 33 | BS=1 34 | 35 | GPU_NUM=8 36 | 37 | # USE_PROFILE="--use_profiler" 38 | 39 | # NHEADS // GROUP_NUM > ulysses_degree 40 | 41 | for RING_IMPL_TYPE in "basic" "zigzag" "strip"; do 42 | for ULYSSES_DEGREE in 8 4 2 1; do 43 | 44 | torchrun --nproc_per_node $GPU_NUM benchmark/benchmark_longctx_qkvpacked.py \ 45 | --nheads $NHEADS \ 46 | --batch_size $BS \ 47 | --seq_len $SEQLEN \ 48 | --head_size $HEAD_SIZE \ 49 | --ulysses_degree $ULYSSES_DEGREE \ 50 | --ring_impl_type $RING_IMPL_TYPE \ 51 | $FWD_FLAG 52 | 53 | torchrun --nproc_per_node $GPU_NUM benchmark/benchmark_longctx.py \ 54 | --nheads $NHEADS \ 55 | --group_num $GROUP_NUM \ 56 | --batch_size $BS \ 57 | --seq_len $SEQLEN \ 58 | --head_size $HEAD_SIZE \ 59 | --ulysses_degree $ULYSSES_DEGREE \ 60 | --ring_impl_type $RING_IMPL_TYPE \ 61 | $FWD_FLAG 62 | 63 | done 64 | done 65 | 66 | 67 | 
-------------------------------------------------------------------------------- /yunchang/comm/extract_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from yunchang.globals import PROCESS_GROUP 5 | 6 | 7 | def stripe_extract_local(value, rank, world_size, rd, ud, *args, **kwargs): 8 | # ud at the highest dim 9 | input_dim = value.dim() 10 | assert input_dim >= 2 11 | 12 | batch_size, seqlen, *rest = value.shape 13 | 14 | assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd 15 | assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud 16 | 17 | value = value.reshape(batch_size, seqlen // rd, rd, -1).contiguous() 18 | value = value.transpose(1, 2).reshape(batch_size, seqlen, -1).contiguous() 19 | value = value.chunk(world_size, dim=1)[rank] 20 | 21 | new_shape = [batch_size, seqlen // world_size] + rest 22 | return value.reshape(new_shape) 23 | 24 | 25 | def basic_extract_local(value, rank, world_size, *args, **kwargs): 26 | return value.chunk(world_size, dim=1)[rank].detach().clone() 27 | 28 | 29 | def zigzag_extract_local(value, rank, world_size, rd, ud, dim=1, *args, **kwargs): 30 | """ 31 | value is a tensor of shape (bs, seqlen, ...) 32 | """ 33 | input_dim = value.dim() 34 | assert input_dim >= 2 35 | batch_size, seqlen, *rest = value.shape 36 | 37 | value_chunks = value.chunk(2 * rd, dim=dim) 38 | r_rank = dist.get_rank(group=PROCESS_GROUP.RING_PG) 39 | u_rank = dist.get_rank(group=PROCESS_GROUP.ULYSSES_PG) 40 | 41 | assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd 42 | assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud 43 | 44 | local_value = torch.cat( 45 | [value_chunks[r_rank], value_chunks[2 * rd - r_rank - 1]], dim=dim 46 | ).chunk(ud, dim=dim)[u_rank] 47 | 48 | new_shape = [batch_size, seqlen // world_size] + rest 49 | return local_value.reshape(new_shape).contiguous() 50 | 51 | 52 | 53 | EXTRACT_FUNC_DICT = { 54 | "basic": basic_extract_local, 55 | "strip": stripe_extract_local, 56 | "zigzag": zigzag_extract_local, 57 | "basic_pytorch": basic_extract_local, 58 | "basic_flashinfer": basic_extract_local, 59 | "basic_npu": basic_extract_local, 60 | } 61 | -------------------------------------------------------------------------------- /docs/install_amd.md: -------------------------------------------------------------------------------- 1 | ## Install for AMD GPU 2 | 3 | Supported GPUs: MI300X, MI308X 4 | 5 | GPU arch: gfx942 6 | 7 | Step 1: prepare the Docker environment 8 | 9 | Two recommended Docker containers to start with: 10 | 11 | - rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 : hosted on Docker Hub, no conda 12 | - [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04) : customized Dockerfile with a conda virtual env and development kit support 13 | 14 | An example of creating a Docker container: 15 | 16 | ```bash 17 | # create docker container 18 | IMG=rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 19 | tag=py310-rocm6.2-distattn-dev 20 | 21 | docker_args=$(echo -it --privileged \ 22 | --name $tag \ 23 | --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \ 24 | --device=/dev/kfd --device=/dev/dri \ 25 | --ipc=host \ 26 | --security-opt seccomp=unconfined \ 27 | --shm-size 16G \ 28 | --group-add video \ 29 | -v $(readlink -f `pwd`):/workspace \ 30 | --workdir /workspace \ 31 | --cpus=$((`nproc` / 2 - 1)) \ 32 | $IMG 33 | ) 34 | 35 
| docker_args=($docker_args) 36 | 37 | docker container create "${docker_args[@]}" 38 | 39 | # start it 40 | docker start -a -i $tag 41 | ``` 42 | 43 | Update the ROCm SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh): 44 | 45 | ```bash 46 | # e.g.: 47 | ROCM_VERSION=6.3 bash rocm/update_sdk.sh 48 | ``` 49 | 50 | Step 2: build from local source. 51 | 52 | Install flash_attn from source: 53 | 54 | ```bash 55 | pip install flash_attn@git+https://git@github.com/Dao-AILab/flash-attention.git 56 | ``` 57 | 58 | Then install yunchang: 59 | 60 | > MAX_JOBS=$(nproc) pip install . --verbose 61 | 62 | **Features:** 63 | 64 | 1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms. 65 | 66 | 2. Covers the Capabilities of Both Ulysses and Ring: By setting the ulysses_degree to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the ulysses_degree to 1 mirrors the functionality of Ring. 67 | 68 | 3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations. 69 | 70 | 4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing. 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ -------------------------------------------------------------------------------- /yunchang/ring/triton_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def flatten_kernel( 8 | # pointers to matrices 9 | OUT, 10 | LSE, 11 | CU_SEQLENS, 12 | # strides 13 | stride_out_nheads, 14 | stride_out_seqlen, 15 | stride_lse_batch, 16 | stride_lse_nheads, 17 | stride_lse_seqlen, 18 | # meta-parameters 19 | BLOCK_M: tl.constexpr, 20 | ): 21 | pid_m = tl.program_id(axis=0) 22 | pid_batch = tl.program_id(axis=1) 23 | pid_head = tl.program_id(axis=2) 24 | 25 | start_idx = tl.load(CU_SEQLENS + pid_batch) 26 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 27 | LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads 28 | OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen 29 | 30 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 31 | 32 | LSE = LSE + rm[:, None] * stride_lse_seqlen 33 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 34 | 35 | OUT = OUT + rm[:, None] * stride_out_seqlen 36 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 37 | 38 | 39 | def flatten_varlen_lse(lse, cu_seqlens): 40 | """ 41 | Arguments: 42 | lse: (batch_size, nheads, max_seqlen) 43 | cu_seqlens: (batch_size + 1,) 44 | Return: 45 | flatten_lse: (nheads, total_seqlen) 46 | """ 47 | total_seqlen = cu_seqlens[-1] 48 | batch_size, nheads, max_seqlen = lse.shape 49 | output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) 50 | 51 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) 52 | BLOCK_M = 4 53 | 54 | with torch.cuda.device(lse.device.index): 55 | flatten_kernel[grid]( 56 | output, 57 | lse, 58 | cu_seqlens, 59 | # strides 60 | output.stride(0), 61 | output.stride(1), 62 | lse.stride(0), 63 | lse.stride(1), 64 | lse.stride(2), 65 | BLOCK_M, 66 | ) 67 | return output 68 | 69 | 70 | @triton.jit 71 | def unflatten_kernel( 72 | # pointers to matrices 73 | OUT, 74 | LSE, 75 | CU_SEQLENS, 76 | # strides 77 | stride_out_batch, 78 | stride_out_nheads, 79 | stride_out_seqlen, 80 | stride_lse_seqlen, 81 | stride_lse_nheads, 82 | # meta-parameters 83 | BLOCK_M: tl.constexpr, 84 | ): 85 | pid_m = tl.program_id(axis=0) 86 | pid_batch = tl.program_id(axis=1) 87 | pid_head = tl.program_id(axis=2) 88 | 89 | start_idx = tl.load(CU_SEQLENS + pid_batch) 90 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 91 | LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen 92 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads 93 | 94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 95 | 96 | LSE = LSE + rm[:, None] * stride_lse_seqlen 97 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 98 | 99 | OUT = OUT + rm[:, None] * stride_out_seqlen 100 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 101 | 102 | 103 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 104 | """ 105 | Arguments: 106 | lse: (total_seqlen, nheads, 1) 107 | cu_seqlens: (batch_size + 1,) 108 | max_seqlen: int 109 | Return: 110 | unflatten_lse: (batch_size, nheads, max_seqlen) 111 | """ 112 | lse = lse.unsqueeze(dim=-1) 113 | batch_size = len(cu_seqlens) - 1 114 | nheads = lse.shape[1] 115 | output = torch.empty( 116 | (batch_size, nheads, max_seqlen), 117 | dtype=lse.dtype, 118 | device=lse.device, 119 | ) 120 | 121 | grid = lambda META: (triton.cdiv(max_seqlen, 
META["BLOCK_M"]), batch_size, nheads) 122 | BLOCK_M = 4 123 | 124 | with torch.cuda.device(lse.device.index): 125 | unflatten_kernel[grid]( 126 | output, 127 | lse, 128 | cu_seqlens, 129 | # strides 130 | output.stride(0), 131 | output.stride(1), 132 | output.stride(2), 133 | lse.stride(0), 134 | lse.stride(1), 135 | BLOCK_M, 136 | ) 137 | return output 138 | -------------------------------------------------------------------------------- /yunchang/ring/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | 7 | __all__ = ["update_out_and_lse", "RingComm"] 8 | 9 | @torch.jit.script 10 | def _update_out_and_lse( 11 | out: torch.Tensor, 12 | lse: torch.Tensor, 13 | block_out: torch.Tensor, 14 | block_lse: torch.Tensor, 15 | ) -> Tuple[torch.Tensor, torch.Tensor]: 16 | 17 | block_out = block_out.to(torch.float32) 18 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 19 | 20 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) 21 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out 22 | # For additional context and discussion, please refer to: 23 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 24 | out = out - F.sigmoid(block_lse - lse) * (out - block_out) 25 | lse = lse - F.logsigmoid(lse - block_lse) 26 | 27 | return out, lse 28 | 29 | 30 | def update_out_and_lse( 31 | out: Optional[torch.Tensor], 32 | lse: Optional[torch.Tensor], 33 | block_out: torch.Tensor, 34 | block_lse: torch.Tensor, 35 | slice_=None, 36 | ) -> Tuple[torch.Tensor, torch.Tensor]: 37 | if out is None: 38 | if slice_ is not None: 39 | raise RuntimeError("first update_out_and_lse should not pass slice_ args") 40 | out = block_out.to(torch.float32) 41 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 42 | elif slice_ is not None: 43 | slice_out, slice_lse = out[slice_], lse[slice_] 44 | slice_out, slice_lse = _update_out_and_lse( 45 | slice_out, slice_lse, block_out, block_lse 46 | ) 47 | out[slice_], lse[slice_] = slice_out, slice_lse 48 | else: 49 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse) 50 | return out, lse 51 | 52 | 53 | @torch.jit.script 54 | def flatten_varlen_lse(lse, cu_seqlens): 55 | new_lse = [] 56 | for i in range(len(cu_seqlens) - 1): 57 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 58 | new_lse.append(lse[i, :, : end - start]) 59 | return torch.cat(new_lse, dim=1) 60 | 61 | 62 | @torch.jit.script 63 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 64 | num_seq = len(cu_seqlens) - 1 65 | num_head = lse.shape[-2] 66 | new_lse = torch.empty( 67 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device 68 | ) 69 | for i in range(num_seq): 70 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 71 | new_lse[i, : end - start] = lse[start:end] 72 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() 73 | 74 | 75 | class RingComm: 76 | def __init__(self, process_group: dist.ProcessGroup): 77 | self._process_group = process_group 78 | self._ops = [] 79 | self.rank = dist.get_rank(self._process_group) 80 | self.world_size = dist.get_world_size(self._process_group) 81 | self._reqs = None 82 | 83 | self.send_rank = (self.rank + 1) % self.world_size 84 | self.recv_rank = (self.rank - 1) % self.world_size 85 | 86 | if process_group is not None: 87 | self.send_rank = dist.get_global_rank(self._process_group, 
self.send_rank) 88 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) 89 | 90 | def send_recv( 91 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None 92 | ) -> torch.Tensor: 93 | if recv_tensor is None: 94 | res = torch.empty_like(to_send) 95 | # print(f"send_recv: empty_like {to_send.shape}") 96 | else: 97 | res = recv_tensor 98 | 99 | send_op = dist.P2POp( 100 | dist.isend, to_send, self.send_rank, group=self._process_group 101 | ) 102 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) 103 | self._ops.append(send_op) 104 | self._ops.append(recv_op) 105 | return res 106 | 107 | def commit(self): 108 | if self._reqs is not None: 109 | raise RuntimeError("commit called twice") 110 | self._reqs = dist.batch_isend_irecv(self._ops) 111 | 112 | def wait(self): 113 | if self._reqs is None: 114 | raise RuntimeError("wait called before commit") 115 | for req in self._reqs: 116 | req.wait() 117 | self._reqs = None 118 | self._ops = [] -------------------------------------------------------------------------------- /test/test_ulysses_attn_npu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch 3 | import torch.distributed as dist 4 | from yunchang import UlyssesAttention 5 | 6 | try: 7 | import torch_npu 8 | except ImportError: 9 | from flash_attn import flash_attn_func 10 | from yunchang.kernels import AttnType 11 | 12 | def log(msg, a, rank0_only=False): 13 | world_size = dist.get_world_size() 14 | rank = dist.get_rank() 15 | if rank0_only: 16 | if rank == 0: 17 | print( 18 | f"{msg}: " 19 | f"max {a.abs().max().item()}, " 20 | f"mean {a.abs().mean().item()}", 21 | flush=True, 22 | ) 23 | return 24 | 25 | for i in range(world_size): 26 | if i == rank: 27 | if rank == 0: 28 | print(f"{msg}:") 29 | print( 30 | f"[{rank}] " 31 | f"max {a.abs().max().item()}, " 32 | f"mean {a.abs().mean().item()}", 33 | flush=True, 34 | ) 35 | dist.barrier() 36 | 37 | 38 | if __name__ == "__main__": 39 | dist.init_process_group("hccl") 40 | 41 | rank = dist.get_rank() 42 | world_size = dist.get_world_size() 43 | dtype = torch.bfloat16 44 | device = torch.device(f"npu:{rank}") 45 | 46 | batch_size = 2 47 | seqlen = 3816 48 | nheads = 8 49 | d = 128 50 | dropout_p = 0 51 | causal = True 52 | deterministic = False 53 | 54 | assert seqlen % world_size == 0 55 | assert d % 8 == 0 56 | # assert batch_size == 1 57 | 58 | q = torch.randn( 59 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 60 | ) 61 | k = torch.randn( 62 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 63 | ) 64 | v = torch.randn( 65 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 66 | ) 67 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 68 | 69 | dist.broadcast(q, src=0) 70 | dist.broadcast(k, src=0) 71 | dist.broadcast(v, src=0) 72 | dist.broadcast(dout, src=0) 73 | 74 | local_q = q.chunk(world_size, dim=1)[rank].detach().clone() 75 | local_q.requires_grad = True 76 | local_k = k.chunk(world_size, dim=1)[rank].detach().clone() 77 | local_k.requires_grad = True 78 | local_v = v.chunk(world_size, dim=1)[rank].detach().clone() 79 | local_v.requires_grad = True 80 | 81 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 82 | 83 | # prcess_group == sequence_process_group 84 | sp_pg = None #dist.new_group(ranks=[i for i in range(world_size)]) 85 | 86 | dist_attn = 
UlyssesAttention(sp_pg, attn_type=AttnType.NPU) 87 | 88 | if rank == 0: 89 | print("#" * 30) 90 | print("# ds-ulysses forward:") 91 | print("#" * 30) 92 | 93 | local_out = dist_attn( 94 | local_q, 95 | local_k, 96 | local_v, 97 | dropout_p=dropout_p, 98 | causal=causal, 99 | window_size=(-1, -1), 100 | softcap=0.0, 101 | alibi_slopes=None, 102 | deterministic=deterministic, 103 | return_attn_probs=True, 104 | ) 105 | 106 | if rank == 0: 107 | print("#" * 30) 108 | print("# ds-ulysses backward:") 109 | print("#" * 30) 110 | 111 | local_out.backward(local_dout) 112 | 113 | dist.barrier() 114 | 115 | if rank == 0: 116 | print("#" * 30) 117 | print("# local forward:") 118 | print("#" * 30) 119 | # reference, a local flash attn 120 | 121 | softmax_scale = q.shape[-1] ** (-0.5) 122 | out_ref = torch_npu.npu_fusion_attention_v2(q, k, v, 123 | head_num = q.shape[-2], 124 | input_layout = "BSND", 125 | scale = softmax_scale, 126 | pre_tokens=65535, 127 | next_tokens=65535)[0] 128 | 129 | if rank == 0: 130 | print("#" * 30) 131 | print("# local forward:") 132 | print("#" * 30) 133 | 134 | out_ref.backward(dout) 135 | 136 | dist.barrier() 137 | 138 | # check correctness 139 | 140 | local_out_ref = out_ref.chunk(world_size, dim=1)[rank] 141 | 142 | log("out", local_out, rank0_only=True) 143 | log("out diff", local_out_ref - local_out) 144 | -------------------------------------------------------------------------------- /yunchang/globals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | class Singleton: 6 | _instance = None 7 | 8 | def __new__(cls, *args, **kwargs): 9 | if not cls._instance: 10 | cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs) 11 | return cls._instance 12 | 13 | 14 | class ProcessGroupSingleton(Singleton): 15 | def __init__(self): 16 | self.ULYSSES_PG = None 17 | self.RING_PG = None 18 | 19 | 20 | PROCESS_GROUP = ProcessGroupSingleton() 21 | 22 | def set_seq_parallel_pg( 23 | sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True 24 | ): 25 | """ 26 | sp_ulysses_degree x sp_ring_degree = seq_parallel_degree 27 | (ulysses_degree, dp_degree) 28 | """ 29 | sp_degree = sp_ring_degree * sp_ulysses_degree 30 | dp_degree = world_size // sp_degree 31 | 32 | assert ( 33 | world_size % sp_degree == 0 34 | ), f"world_size {world_size} % sp_degree {sp_ulysses_degree} == 0" 35 | 36 | num_ulysses_pgs = sp_ring_degree # world_size // sp_ulysses_degree 37 | num_ring_pgs = sp_ulysses_degree # world_size // sp_ring_degree 38 | 39 | if use_ulysses_low: 40 | for dp_rank in range(dp_degree): 41 | offset = dp_rank * sp_degree 42 | for i in range(num_ulysses_pgs): 43 | ulysses_ranks = list( 44 | range( 45 | i * sp_ulysses_degree + offset, 46 | (i + 1) * sp_ulysses_degree + offset, 47 | ) 48 | ) 49 | group = torch.distributed.new_group(ulysses_ranks) 50 | if rank in ulysses_ranks: 51 | ulyssess_pg = group 52 | 53 | for i in range(num_ring_pgs): 54 | ring_ranks = list(range(i + offset, sp_degree + offset, num_ring_pgs)) 55 | group = torch.distributed.new_group(ring_ranks) 56 | if rank in ring_ranks: 57 | ring_pg = group 58 | 59 | else: 60 | for dp_rank in range(dp_degree): 61 | offset = dp_rank * sp_degree 62 | for i in range(num_ring_pgs): 63 | ring_ranks = list( 64 | range( 65 | i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset 66 | ) 67 | ) 68 | group = torch.distributed.new_group(ring_ranks) 69 | if rank in ring_ranks: 70 | ring_pg = group 71 | 72 | for i in 
range(num_ulysses_pgs): 73 | ulysses_ranks = list( 74 | range(i + offset, sp_degree + offset, num_ulysses_pgs) 75 | ) 76 | group = torch.distributed.new_group(ulysses_ranks) 77 | if rank in ulysses_ranks: 78 | ulyssess_pg = group 79 | 80 | PROCESS_GROUP.ULYSSES_PG = ulyssess_pg 81 | PROCESS_GROUP.RING_PG = ring_pg 82 | 83 | # test if flash_attn is available 84 | try: 85 | import flash_attn 86 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 87 | HAS_FLASH_ATTN = True 88 | except ImportError: 89 | HAS_FLASH_ATTN = False 90 | 91 | try: 92 | from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper 93 | from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward 94 | from flash_attn_interface import flash_attn_func as flash3_attn_func 95 | HAS_FLASH_ATTN_HOPPER = True 96 | except ImportError: 97 | HAS_FLASH_ATTN_HOPPER = False 98 | 99 | try: 100 | from flashinfer.prefill import single_prefill_with_kv_cache 101 | HAS_FLASHINFER = True 102 | def get_cuda_arch(): 103 | major, minor = torch.cuda.get_device_capability() 104 | return f"{major}.{minor}" 105 | 106 | cuda_arch = get_cuda_arch() 107 | os.environ['TORCH_CUDA_ARCH_LIST'] = cuda_arch 108 | print(f"Set TORCH_CUDA_ARCH_LIST to {cuda_arch}") 109 | except ImportError: 110 | HAS_FLASHINFER = False 111 | 112 | try: 113 | import aiter 114 | from aiter import flash_attn_func as flash_attn_func_aiter 115 | HAS_AITER = True 116 | except ImportError: 117 | HAS_AITER = False 118 | 119 | try: 120 | import sageattention 121 | HAS_SAGE_ATTENTION = True 122 | except ImportError: 123 | HAS_SAGE_ATTENTION = False 124 | 125 | try: 126 | import spas_sage_attn 127 | HAS_SPARSE_SAGE_ATTENTION = True 128 | except ImportError: 129 | HAS_SPARSE_SAGE_ATTENTION = False 130 | 131 | try: 132 | import torch_npu 133 | HAS_NPU = True 134 | except ImportError: 135 | HAS_NPU = False -------------------------------------------------------------------------------- /test/test_ulysses_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from yunchang import UlyssesAttention 4 | 5 | from flash_attn import flash_attn_func 6 | from yunchang.kernels import AttnType 7 | 8 | def log(msg, a, rank0_only=False): 9 | world_size = dist.get_world_size() 10 | rank = dist.get_rank() 11 | if rank0_only: 12 | if rank == 0: 13 | print( 14 | f"{msg}: " 15 | f"max {a.abs().max().item()}, " 16 | f"mean {a.abs().mean().item()}", 17 | flush=True, 18 | ) 19 | return 20 | 21 | for i in range(world_size): 22 | if i == rank: 23 | if rank == 0: 24 | print(f"{msg}:") 25 | print( 26 | f"[{rank}] " 27 | f"max {a.abs().max().item()}, " 28 | f"mean {a.abs().mean().item()}", 29 | flush=True, 30 | ) 31 | dist.barrier() 32 | 33 | 34 | if __name__ == "__main__": 35 | dist.init_process_group("nccl") 36 | 37 | rank = dist.get_rank() 38 | world_size = dist.get_world_size() 39 | dtype = torch.bfloat16 40 | device = torch.device(f"cuda:{rank}") 41 | 42 | batch_size = 2 43 | seqlen = 3816 44 | nheads = 8 45 | d = 128 46 | dropout_p = 0 47 | causal = True 48 | deterministic = False 49 | 50 | assert seqlen % world_size == 0 51 | assert d % 8 == 0 52 | # assert batch_size == 1 53 | 54 | q = torch.randn( 55 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 56 | ) 57 | k = torch.randn( 58 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 59 | ) 60 | v = torch.randn( 61 | 
batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 62 | ) 63 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 64 | 65 | dist.broadcast(q, src=0) 66 | dist.broadcast(k, src=0) 67 | dist.broadcast(v, src=0) 68 | dist.broadcast(dout, src=0) 69 | 70 | local_q = q.chunk(world_size, dim=1)[rank].detach().clone() 71 | local_q.requires_grad = True 72 | local_k = k.chunk(world_size, dim=1)[rank].detach().clone() 73 | local_k.requires_grad = True 74 | local_v = v.chunk(world_size, dim=1)[rank].detach().clone() 75 | local_v.requires_grad = True 76 | 77 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 78 | 79 | # prcess_group == sequence_process_group 80 | sp_pg = None #dist.new_group(ranks=[i for i in range(world_size)]) 81 | 82 | dist_attn = UlyssesAttention(sp_pg, attn_type=AttnType.FA) 83 | 84 | if rank == 0: 85 | print("#" * 30) 86 | print("# ds-ulysses forward:") 87 | print("#" * 30) 88 | 89 | local_out = dist_attn( 90 | local_q, 91 | local_k, 92 | local_v, 93 | dropout_p=dropout_p, 94 | causal=causal, 95 | window_size=(-1, -1), 96 | softcap=0.0, 97 | alibi_slopes=None, 98 | deterministic=deterministic, 99 | return_attn_probs=True, 100 | ) 101 | 102 | if rank == 0: 103 | print("#" * 30) 104 | print("# ds-ulysses backward:") 105 | print("#" * 30) 106 | 107 | local_out.backward(local_dout) 108 | 109 | dist.barrier() 110 | 111 | if rank == 0: 112 | print("#" * 30) 113 | print("# local forward:") 114 | print("#" * 30) 115 | # reference, a local flash attn 116 | out_ref, _, _ = flash_attn_func( 117 | q, 118 | k, 119 | v, 120 | dropout_p=dropout_p, 121 | causal=causal, 122 | window_size=(-1, -1), 123 | softcap=0.0, 124 | alibi_slopes=None, 125 | deterministic=deterministic, 126 | return_attn_probs=True, 127 | ) 128 | if rank == 0: 129 | print("#" * 30) 130 | print("# local forward:") 131 | print("#" * 30) 132 | 133 | out_ref.backward(dout) 134 | 135 | dist.barrier() 136 | 137 | # check correctness 138 | 139 | local_out_ref = out_ref.chunk(world_size, dim=1)[rank] 140 | 141 | log("out", local_out, rank0_only=True) 142 | log("out diff", local_out_ref - local_out) 143 | 144 | local_dq_ref = q.grad.chunk(world_size, dim=1)[rank] 145 | log("load_dq", local_q.grad) 146 | log("dq diff", local_dq_ref - local_q.grad) 147 | 148 | local_dk_ref = k.grad.chunk(world_size, dim=1)[rank] 149 | log("load_dk", local_k.grad) 150 | log("dk diff", local_dk_ref - local_k.grad) 151 | 152 | local_dv_ref = v.grad.chunk(world_size, dim=1)[rank] 153 | log("load_dk", local_v.grad) 154 | log("dv diff", local_dv_ref - local_v.grad) 155 | -------------------------------------------------------------------------------- /yunchang/ring/ring_pytorch_attn.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py 2 | # Copyright 2024 The HuggingFace Inc. team and Jiarui Fang. 
3 | 4 | import math 5 | import torch 6 | import torch.nn.functional as F 7 | from typing import Any, Optional, Tuple 8 | from yunchang.kernels import select_flash_attn_impl, AttnType 9 | from .utils import RingComm, update_out_and_lse 10 | from yunchang.kernels.attention import pytorch_attn_forward, pytorch_attn_backward 11 | 12 | def ring_pytorch_attn_func( 13 | q, 14 | k, 15 | v, 16 | dropout_p=0.0, 17 | softmax_scale=None, 18 | causal=False, 19 | window_size=(-1, -1), 20 | softcap=0.0, 21 | alibi_slopes=None, 22 | deterministic=False, 23 | return_attn_probs=False, 24 | group=None, 25 | attn_type: AttnType = AttnType.FA, 26 | attn_processor=None, 27 | ): 28 | return RingAttentionFunc.apply(group, q, k, v, softmax_scale, causal) 29 | 30 | class RingAttentionFunc(torch.autograd.Function): 31 | 32 | @staticmethod 33 | def forward(ctx, group, q, k, v, sm_scale, is_causal): 34 | 35 | comm = RingComm(group) 36 | #TODO(fmom): add flex attention 37 | #TODO(fmom): add flash attention 38 | #TODO(fmom): Find a better to save these tensors without cloning 39 | k_og = k.clone() 40 | v_og = v.clone() 41 | out, lse = None, None 42 | next_k, next_v = None, None 43 | 44 | if sm_scale is None: 45 | sm_scale = q.shape[-1] ** -0.5 46 | 47 | for step in range(comm.world_size): 48 | if step + 1 != comm.world_size: 49 | next_k = comm.send_recv(k) 50 | next_v = comm.send_recv(v) 51 | comm.commit() 52 | 53 | if not is_causal or step <= comm.rank: 54 | block_out, block_lse = pytorch_attn_forward( 55 | q, k, v, softmax_scale = sm_scale, causal = is_causal and step == 0 56 | ) 57 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 58 | 59 | if step + 1 != comm.world_size: 60 | comm.wait() 61 | k = next_k 62 | v = next_v 63 | 64 | out = out.to(q.dtype) 65 | 66 | ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1)) 67 | ctx.sm_scale = sm_scale 68 | ctx.is_causal = is_causal 69 | ctx.group = group 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, dout, *args): 75 | 76 | 77 | q, k, v, out, softmax_lse = ctx.saved_tensors 78 | sm_scale = ctx.sm_scale 79 | is_causal = ctx.is_causal 80 | 81 | kv_comm = RingComm(ctx.group) 82 | d_kv_comm = RingComm(ctx.group) 83 | 84 | dq, dk, dv = None, None, None 85 | next_dk, next_dv = None, None 86 | 87 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 88 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 89 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 90 | 91 | next_dk, next_dv = None, None 92 | next_k, next_v = None, None 93 | 94 | for step in range(kv_comm.world_size): 95 | if step + 1 != kv_comm.world_size: 96 | next_k = kv_comm.send_recv(k) 97 | next_v = kv_comm.send_recv(v) 98 | kv_comm.commit() 99 | 100 | if step <= kv_comm.rank or not is_causal: 101 | bwd_causal = is_causal and step == 0 102 | 103 | block_dq_buffer, block_dk_buffer, block_dv_buffer = pytorch_attn_backward( 104 | dout, q, k, v, out, softmax_lse = softmax_lse, softmax_scale = sm_scale, causal = bwd_causal 105 | ) 106 | 107 | if dq is None: 108 | dq = block_dq_buffer.to(torch.float32) 109 | dk = block_dk_buffer.to(torch.float32) 110 | dv = block_dv_buffer.to(torch.float32) 111 | else: 112 | dq += block_dq_buffer 113 | d_kv_comm.wait() 114 | dk = block_dk_buffer + next_dk 115 | dv = block_dv_buffer + next_dv 116 | elif step != 0: 117 | d_kv_comm.wait() 118 | dk = next_dk 119 | dv = next_dv 120 | 121 | if step + 1 != kv_comm.world_size: 122 | kv_comm.wait() 123 | k = next_k 124 | v = next_v 125 | 126 | 
next_dk = d_kv_comm.send_recv(dk) 127 | next_dv = d_kv_comm.send_recv(dv) 128 | d_kv_comm.commit() 129 | 130 | d_kv_comm.wait() 131 | 132 | return dq, next_dk, next_dv, None, None 133 | -------------------------------------------------------------------------------- /yunchang/ulysses/attn_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation and Jiarui Fang 2 | # SPDX-License-Identifier: Apache-2.0 3 | # DeepSpeed Team & Jiarui Fang 4 | 5 | 6 | import torch 7 | 8 | from typing import Any 9 | from torch import Tensor 10 | from yunchang.kernels import AttnType, select_flash_attn_impl 11 | import torch.distributed as dist 12 | from yunchang.comm.all_to_all import SeqAllToAll4D 13 | 14 | 15 | class UlyssesAttention(torch.nn.Module): 16 | """Initialization. 17 | 18 | Arguments: 19 | local_attention (Module): local attention with q,k,v 20 | sequence_process_group (ProcessGroup): sequence parallel process group 21 | scatter_idx (int): scatter_idx for all2all comm 22 | gather_idx (int): gather_idx for all2all comm 23 | use_sync (bool): whether to synchronize after all-to-all. This flag can save cuda memory but will slow down the speed. 24 | attn_type (AttnType): attention type enum 25 | """ 26 | 27 | def __init__( 28 | self, 29 | sequence_process_group: dist.ProcessGroup = None, 30 | scatter_idx: int = 2, 31 | gather_idx: int = 1, 32 | use_sync: bool = False, 33 | attn_type : AttnType = AttnType.FA, 34 | ) -> None: 35 | 36 | super(UlyssesAttention, self).__init__() 37 | self.spg = sequence_process_group 38 | self.scatter_idx = scatter_idx 39 | self.gather_idx = gather_idx 40 | self.use_sync = use_sync 41 | self.attn_type = attn_type 42 | 43 | try: 44 | import torch_npu 45 | device = torch.device("npu") 46 | except: 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | gpu_name = torch.cuda.get_device_name(device) 49 | if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: 50 | self.attn_type = AttnType.TORCH 51 | self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd") 52 | 53 | def forward( 54 | self, 55 | query: Tensor, 56 | key: Tensor, 57 | value: Tensor, 58 | dropout_p=0.0, 59 | softmax_scale=None, 60 | causal=False, 61 | window_size=(-1, -1), 62 | softcap=0.0, 63 | alibi_slopes=None, 64 | deterministic=False, 65 | return_attn_probs=False, 66 | *args: Any 67 | ) -> Tensor: 68 | """forward 69 | 70 | Arguments: 71 | query (Tensor): query input to the layer 72 | key (Tensor): key input to the layer 73 | value (Tensor): value input to the layer 74 | args: other args 75 | 76 | Returns: 77 | * output (Tensor): context output 78 | """ 79 | # TODO Merge three alltoall calls into one 80 | # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
81 | # in shape : e.g., [s/p:h:] 82 | # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) 83 | 84 | # scatter 2, gather 1 85 | q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync) 86 | k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync) 87 | v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync) 88 | 89 | if softmax_scale is None: 90 | softmax_scale = q.shape[-1] ** -0.5 91 | 92 | if self.attn_type is AttnType.NPU: 93 | context_layer = self.attn_fn( 94 | q, 95 | k, 96 | v, 97 | head_num = q.shape[-2], 98 | input_layout = "BSND", 99 | scale = softmax_scale, 100 | pre_tokens=65535, 101 | next_tokens=65535, 102 | ) 103 | else: 104 | context_layer = self.attn_fn( 105 | q, 106 | k, 107 | v, 108 | dropout_p=dropout_p, 109 | softmax_scale = softmax_scale, 110 | causal=causal, 111 | window_size=window_size, 112 | softcap=softcap, 113 | alibi_slopes=alibi_slopes, 114 | deterministic=deterministic, 115 | return_attn_probs=return_attn_probs, 116 | ) 117 | 118 | if isinstance(context_layer, tuple): 119 | context_layer = context_layer[0] 120 | 121 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 122 | # scatter 1, gather 2 123 | output = SeqAllToAll4D.apply( 124 | self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync 125 | ) 126 | 127 | # out e.g., [s/p::h] 128 | return output 129 | 130 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, repeat 2 | import math 3 | 4 | import torch 5 | 6 | 7 | # adpated from flash-attention 8 | def construct_local_mask( 9 | seqlen_q, 10 | seqlen_k, 11 | window_size=(-1, -1), # -1 means infinite window size 12 | query_padding_mask=None, 13 | key_padding_mask=None, 14 | device=None, 15 | key_leftpad=None, 16 | ): 17 | row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") 18 | col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) 19 | if key_leftpad is not None: 20 | key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") 21 | col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) 22 | col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) 23 | sk = ( 24 | seqlen_k 25 | if key_padding_mask is None 26 | else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") 27 | ) 28 | sq = ( 29 | seqlen_q 30 | if query_padding_mask is None 31 | else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") 32 | ) 33 | if window_size[0] < 0: 34 | return col_idx > row_idx + sk - sq + window_size[1] 35 | else: 36 | sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk 37 | return torch.logical_or( 38 | col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), 39 | col_idx < row_idx + sk - sq - window_size[0], 40 | ) 41 | 42 | # adpated from flash-attention 43 | def attention_ref( 44 | q, 45 | k, 46 | v, 47 | query_padding_mask=None, 48 | key_padding_mask=None, 49 | attn_bias=None, 50 | dropout_p=0.0, 51 | dropout_mask=None, 52 | causal=False, 53 | window_size=(-1, -1), # -1 means infinite window size 54 | softcap=0.0, 55 | upcast=True, 56 | reorder_ops=False, 57 | key_leftpad=None, 58 | ): 59 | """ 60 | Arguments: 61 | q: (batch_size, seqlen_q, nheads, head_dim) 62 | k: (batch_size, seqlen_k, nheads_k, head_dim) 63 | v: (batch_size, 
seqlen_k, nheads_k, head_dim) 64 | query_padding_mask: (batch_size, seqlen_q) 65 | key_padding_mask: (batch_size, seqlen_k) 66 | attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) 67 | dropout_p: float 68 | dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) 69 | causal: whether to apply causal masking 70 | window_size: (int, int), left and right window size 71 | upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast 72 | output back to fp16/bf16. 73 | reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) 74 | without changing the math. This is to estimate the numerical error from operation 75 | reordering. 76 | Output: 77 | output: (batch_size, seqlen_q, nheads, head_dim) 78 | attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout 79 | """ 80 | if causal: 81 | window_size = (window_size[0], 0) 82 | dtype_og = q.dtype 83 | if upcast: 84 | q, k, v = q.float(), k.float(), v.float() 85 | seqlen_q, seqlen_k = q.shape[1], k.shape[1] 86 | k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) 87 | v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) 88 | d = q.shape[-1] 89 | if not reorder_ops: 90 | scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) 91 | else: 92 | scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) 93 | if softcap > 0: 94 | scores = scores / softcap 95 | scores = scores.tanh() 96 | scores = scores * softcap 97 | if key_padding_mask is not None: 98 | scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) 99 | if window_size[0] >= 0 or window_size[1] >= 0: 100 | local_mask = construct_local_mask( 101 | seqlen_q, 102 | seqlen_k, 103 | window_size, 104 | query_padding_mask, 105 | key_padding_mask, 106 | q.device, 107 | key_leftpad=key_leftpad, 108 | ) 109 | scores.masked_fill_(local_mask, float("-inf")) 110 | if attn_bias is not None: 111 | scores = scores + attn_bias 112 | attention = torch.softmax(scores, dim=-1).to(v.dtype) 113 | # Some rows might be completely masked out so we fill them with zero instead of NaN 114 | if window_size[0] >= 0 or window_size[1] >= 0: 115 | attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) 116 | # We want to mask here so that the attention matrix doesn't have any NaNs 117 | # Otherwise we'll get NaN in dV 118 | if query_padding_mask is not None: 119 | attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 120 | dropout_scaling = 1.0 / (1 - dropout_p) 121 | # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling 122 | # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) 123 | if dropout_mask is not None: 124 | attention_drop = attention.masked_fill(~dropout_mask, 0.0) 125 | else: 126 | attention_drop = attention 127 | output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) 128 | if query_padding_mask is not None: 129 | output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) 130 | return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) -------------------------------------------------------------------------------- /test/test_hybrid_qkvpacked_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from yunchang import ( 4 | LongContextAttentionQKVPacked, 5 | set_seq_parallel_pg, 6 | EXTRACT_FUNC_DICT, 7 | RING_IMPL_QKVPACKED_DICT 8 | 
) 9 | from yunchang.kernels import AttnType 10 | 11 | 12 | def log(msg, a, rank0_only=False): 13 | world_size = dist.get_world_size() 14 | rank = dist.get_rank() 15 | if rank0_only: 16 | if rank == 0: 17 | print( 18 | f"{msg}: " 19 | f"max {a.abs().max().item()}, " 20 | f"mean {a.abs().mean().item()}", 21 | flush=True, 22 | ) 23 | return 24 | 25 | for i in range(world_size): 26 | if i == rank: 27 | if rank == 0: 28 | print(f"{msg}:") 29 | print( 30 | f"[{rank}] " 31 | f"max {a.abs().max().item()}, " 32 | f"mean {a.abs().mean().item()}", 33 | flush=True, 34 | ) 35 | dist.barrier() 36 | 37 | import os 38 | 39 | def get_local_rank(): 40 | local_rank = int(os.getenv('LOCAL_RANK', '0')) 41 | return local_rank 42 | 43 | def test(ring_impl_type="zigzag"): 44 | 45 | rank = dist.get_rank() 46 | local_rank = get_local_rank() 47 | world_size = dist.get_world_size() 48 | dtype = torch.bfloat16 49 | device = torch.device(f"cuda:{local_rank}") 50 | print(f"rank {rank} local_rank {local_rank} world_size {world_size}") 51 | 52 | batch_size = 2 53 | seqlen = 1024 54 | nheads = 8 55 | d = 32 56 | dropout_p = 0.0 57 | causal = True 58 | deterministic = False 59 | 60 | assert seqlen % world_size == 0 61 | assert d % 8 == 0 62 | 63 | sp_ulysses_degree = 2 # min(world_size, nheads) 64 | sp_ring_degree = world_size // sp_ulysses_degree 65 | 66 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 67 | 68 | longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=ring_impl_type, 69 | attn_type=AttnType.FA) 70 | 71 | ## prepare input and output tensors 72 | 73 | # global tensors 74 | qkv = torch.randn( 75 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True 76 | ) 77 | 78 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 79 | 80 | with torch.no_grad(): 81 | dist.broadcast(qkv, src=0) 82 | dist.broadcast(dout, src=0) 83 | 84 | # sharded tensors for long context attn 85 | local_qkv = ( 86 | EXTRACT_FUNC_DICT[ring_impl_type]( 87 | qkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 88 | ) 89 | .detach() 90 | .clone() 91 | ) 92 | local_qkv.requires_grad = True 93 | 94 | local_dout = ( 95 | EXTRACT_FUNC_DICT[ring_impl_type]( 96 | dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 97 | ) 98 | .detach() 99 | .clone() 100 | ) 101 | # shared tensors for reference 102 | local_qkv_ref = local_qkv.detach().clone() 103 | local_qkv_ref.requires_grad = True 104 | 105 | dist.barrier() 106 | if rank == 0: 107 | print("#" * 30) 108 | print("# forward:") 109 | print("#" * 30) 110 | 111 | print(f"local_qkv shape {local_qkv.shape}") 112 | local_out = longctx_attn( 113 | local_qkv, 114 | dropout_p=dropout_p, 115 | causal=causal, 116 | window_size=(-1, -1), 117 | softcap=0.0, 118 | alibi_slopes=None, 119 | deterministic=deterministic, 120 | return_attn_probs=True, 121 | ) 122 | 123 | from flash_attn import flash_attn_qkvpacked_func 124 | # local_out = out.chunk(world_size, dim=1)[rank] 125 | # local_lse = lse.chunk(world_size, dim=-1)[rank] 126 | 127 | out, lse, _ = flash_attn_qkvpacked_func( 128 | qkv, 129 | dropout_p=dropout_p, 130 | causal=causal, 131 | window_size=(-1, -1), 132 | softcap=0.0, 133 | alibi_slopes=None, 134 | deterministic=deterministic, 135 | return_attn_probs=True, 136 | ) 137 | 138 | local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 139 | out, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 140 | ) 141 | 142 | log("out_ref", local_out_ref, rank0_only=True) 143 | 
log("out", local_out, rank0_only=True) 144 | 145 | # log("lse", lse, rank0_only=True) 146 | log("out diff", local_out - local_out_ref) 147 | # log("lse diff", local_lse - ring_lse) 148 | 149 | dist.barrier() 150 | 151 | # if rank == 0: 152 | # print(local_out_ref) 153 | # print(local_out) 154 | 155 | if rank == 0: 156 | print("#" * 30) 157 | print("# backward:") 158 | print("#" * 30) 159 | 160 | # long context attn backward 161 | local_out.backward(local_dout) 162 | local_dqkv = local_qkv.grad 163 | 164 | # local ring backward 165 | out.backward(dout) 166 | dqkv = qkv.grad 167 | 168 | local_dqkv_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 169 | dqkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 170 | ) 171 | 172 | log("load_dq", local_dqkv_ref) 173 | log("dq diff", local_dqkv - local_dqkv_ref) 174 | 175 | 176 | 177 | if __name__ == "__main__": 178 | dist.init_process_group("nccl") 179 | for ring_impl_type in ["basic", "zigzag"]: 180 | print(f"ring_impl_type: {ring_impl_type}") 181 | test(ring_impl_type) 182 | if dist.is_initialized(): 183 | dist.destroy_process_group() 184 | -------------------------------------------------------------------------------- /benchmark/benchmark_longctx_qkvpacked.py: -------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_varlen_qkvpacked_func 2 | import torch 3 | import torch.distributed as dist 4 | from yunchang import set_seq_parallel_pg, LongContextAttentionQKVPacked 5 | from yunchang.comm import EXTRACT_FUNC_DICT 6 | import torch.cuda 7 | 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser(description="args for benchmark.") 11 | 12 | parser.add_argument( 13 | "--ring_impl_type", 14 | type=str, 15 | default="basic", 16 | choices=["basic", "zigzag", "strip"], 17 | help="ring attn implementation type", 18 | ) 19 | parser.add_argument("--nheads", type=int, default=2, help="head number") 20 | parser.add_argument("--head_size", type=int, default=128, help="head number") 21 | parser.add_argument( 22 | "--seq_len", 23 | type=int, 24 | default=4 * 1024, 25 | help="local sequence length, the global sequence length is seq_len * world_size", 26 | ) 27 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 28 | parser.add_argument( 29 | "--fwd_only", action="store_true", help="benchmark forward pass only" 30 | ) 31 | parser.add_argument( 32 | "--use_ulysses_lowdim", 33 | action="store_true", 34 | default=True, 35 | help="ulysses process group on low dimension", 36 | ) 37 | parser.add_argument( 38 | "--ulysses_degree", 39 | type=int, 40 | default=1, 41 | help="ulysses attention sequence parallel degree", 42 | ) 43 | args = parser.parse_args() 44 | 45 | 46 | def color_print(text): 47 | print("\033[91m {}\033[00m".format(text)) 48 | 49 | import os 50 | def get_local_rank(): 51 | local_rank = int(os.getenv('LOCAL_RANK', '0')) 52 | return local_rank 53 | 54 | 55 | def benchmark(num_iter=100, forward_only=True, log=True): 56 | dtype = torch.bfloat16 57 | rank = dist.get_rank() 58 | local_rank = get_local_rank() 59 | world_size = dist.get_world_size() 60 | device = torch.device(f"cuda:{local_rank}") 61 | torch.cuda.set_device(device) 62 | 63 | batch_size = args.batch_size 64 | seqlen = args.seq_len 65 | nheads = args.nheads 66 | d = args.head_size 67 | 68 | dropout_p = 0 69 | causal = True 70 | deterministic = False 71 | 72 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 73 | assert d % 8 == 0 74 | 75 | qkv = torch.randn( 76 
| batch_size, 77 | seqlen * world_size, 78 | 3, 79 | nheads, 80 | d, 81 | device=device, 82 | dtype=dtype, 83 | requires_grad=True, 84 | ) 85 | dout = torch.randn( 86 | batch_size, seqlen * world_size, nheads, d, device=device, dtype=dtype 87 | ) 88 | 89 | sp_ulysses_degree = min(args.ulysses_degree, world_size) 90 | sp_ring_degree = world_size // sp_ulysses_degree 91 | 92 | set_seq_parallel_pg( 93 | sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim 94 | ) 95 | 96 | longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=args.ring_impl_type) 97 | 98 | # NOTE() using zigzag and stripe have a special layout. 99 | qkv = ( 100 | EXTRACT_FUNC_DICT[args.ring_impl_type]( 101 | qkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 102 | ) 103 | .detach() 104 | .clone() 105 | ) 106 | qkv.requires_grad = True 107 | dout = ( 108 | EXTRACT_FUNC_DICT[args.ring_impl_type]( 109 | dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 110 | ) 111 | .detach() 112 | .clone() 113 | ) 114 | 115 | out = longctx_attn( 116 | qkv, 117 | dropout_p=dropout_p, 118 | causal=causal, 119 | window_size=(-1, -1), 120 | softcap=0.0, 121 | alibi_slopes=None, 122 | deterministic=deterministic, 123 | return_attn_probs=False, 124 | ) 125 | out.backward(dout) 126 | 127 | begin = torch.cuda.Event(enable_timing=True) 128 | begin.record() 129 | 130 | if forward_only: 131 | with torch.no_grad(): 132 | for _ in range(num_iter): 133 | _ = longctx_attn( 134 | qkv, 135 | dropout_p=dropout_p, 136 | causal=causal, 137 | window_size=(-1, -1), 138 | softcap=0.0, 139 | alibi_slopes=None, 140 | deterministic=deterministic, 141 | return_attn_probs=False, 142 | ) 143 | 144 | else: 145 | for _ in range(num_iter): 146 | qkv.grad = None 147 | out = longctx_attn( 148 | qkv, 149 | dropout_p=dropout_p, 150 | causal=causal, 151 | window_size=(-1, -1), 152 | softcap=0.0, 153 | alibi_slopes=None, 154 | deterministic=deterministic, 155 | return_attn_probs=False, 156 | ) 157 | out.backward(dout) 158 | end = torch.cuda.Event(enable_timing=True) 159 | end.record() 160 | torch.cuda.synchronize(device=device) 161 | time = begin.elapsed_time(end) / 1000.0 162 | 163 | if rank == 0 and log: 164 | color_print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") 165 | 166 | 167 | if __name__ == "__main__": 168 | dist.init_process_group("nccl") 169 | rank = dist.get_rank() 170 | 171 | forward_only = args.fwd_only 172 | 173 | torch.cuda.empty_cache() 174 | if rank == 0: 175 | color_print(vars(args)) 176 | color_print( 177 | f"# long context attention qkvpacked {args.ring_impl_type}. 
ulysses_degree : {args.ulysses_degree} " 178 | f"fwd_only {forward_only} " 179 | f"use_ulysses_lowdim {args.use_ulysses_lowdim} " 180 | ) 181 | benchmark(forward_only=forward_only, log=False) 182 | benchmark(forward_only=forward_only, log=True) 183 | -------------------------------------------------------------------------------- /yunchang/ring/ring_npu_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from .utils import RingComm, update_out_and_lse 4 | from yunchang.kernels.attention import ( 5 | npu_fused_attn_forward, 6 | npu_fused_attn_backward, 7 | ) 8 | from datetime import datetime 9 | 10 | 11 | def ring_npu_flash_attn_forward( 12 | process_group, 13 | q: torch.Tensor, 14 | k: torch.Tensor, 15 | v: torch.Tensor, 16 | head_num: int=None, 17 | input_layout: str="BSND" 18 | ): 19 | comm = RingComm(process_group) 20 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()}, ring_npu_flash_attn_forward") 21 | # 单卡场景直接计算 22 | if comm.world_size == 1: 23 | return npu_fused_attn_forward(q, k, v, head_num, input_layout) 24 | 25 | attention_out,softmax_max, softmax_sum, scale_value = None,None,None,None 26 | 27 | next_k, next_v = None, None 28 | 29 | for step in range(comm.world_size): 30 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward step: {step}") 31 | # 非最后一步:发起下一个kv的通信(异步) 32 | if step + 1 != comm.world_size: 33 | next_k = comm.send_recv(k) 34 | next_v = comm.send_recv(v) 35 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward commit: {step}") 36 | comm.commit() 37 | 38 | # 当前step计算(仅当step <= 当前rank时处理本地kv) 39 | if step <= comm.rank: 40 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward calculation: {step}") 41 | attention_out, softmax_max, softmax_sum, scale_value = npu_fused_attn_forward(q, k, v, head_num, input_layout) 42 | 43 | # 非最后一步:等待通信完成,更新kv 44 | if step + 1 != comm.world_size: 45 | comm.wait() 46 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_forward wait: {step}") 47 | k = next_k 48 | v = next_v 49 | return attention_out, softmax_max, softmax_sum, scale_value 50 | 51 | 52 | def ring_npu_flash_attn_backward( 53 | process_group,q, k, v, grad_attention_out, head_num=None, input_layout="BSND", softmax_max=None,softmax_sum=None,attention_in=None, scale_value=None): 54 | kv_comm = RingComm(process_group) 55 | d_kv_comm = RingComm(process_group) 56 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()}, ring_npu_flash_attn_backward") 57 | 58 | # 初始化梯度张量(避免None,用0张量初始化) 59 | dq = torch.zeros_like(q, dtype=torch.float32) 60 | dk = torch.zeros_like(k, dtype=torch.float32) 61 | dv = torch.zeros_like(v, dtype=torch.float32) 62 | next_k, next_v = None, None 63 | next_dk, next_dv = None, None 64 | 65 | for step in range(kv_comm.world_size): 66 | # 1. 发起kv通信(获取下一个step的kv) 67 | if step + 1 != kv_comm.world_size: 68 | next_k = kv_comm.send_recv(k) 69 | next_v = kv_comm.send_recv(v) 70 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward commit: {step}") 71 | kv_comm.commit() 72 | 73 | # 2. 
计算当前step的梯度 74 | if step <= kv_comm.rank: 75 | grad_query, grad_key, grad_value = npu_fused_attn_backward( 76 | q, k, v, grad_attention_out, head_num, input_layout, softmax_max=softmax_max, softmax_sum=softmax_sum, attention_in=attention_in, scale_value=scale_value) 77 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward calculation: {step}") 78 | # 累加query梯度(每个rank只计算自己的q梯度) 79 | dq += grad_query.to(torch.float32) 80 | 81 | # 累加kv梯度:如果不是第一步,需要加上通信过来的梯度 82 | if step > 0: 83 | d_kv_comm.wait() # 等待上一轮dk/dv通信完成 84 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm wait: {step}") 85 | dk += grad_key.to(torch.float32) + next_dk 86 | dv += grad_value.to(torch.float32) + next_dv 87 | else: 88 | # 第一步直接赋值 89 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward dkdv: {step}") 90 | dk = grad_key.to(torch.float32) 91 | dv = grad_value.to(torch.float32) 92 | else: 93 | # step > 当前rank:仅接收上一轮的dk/dv 94 | if step > 0: 95 | d_kv_comm.wait() 96 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm wait to next_dk: {step}") 97 | dk = next_dk 98 | dv = next_dv 99 | 100 | # 3. 等待kv通信完成,更新kv 101 | if step + 1 != kv_comm.world_size: 102 | kv_comm.wait() 103 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward kv_comm wait for update: {step}") 104 | k = next_k 105 | v = next_v 106 | 107 | next_dk = d_kv_comm.send_recv(dk) 108 | next_dv = d_kv_comm.send_recv(dv) 109 | d_kv_comm.commit() 110 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm commit: {step}") 111 | 112 | # 等待最后一轮dk/dv通信完成 113 | d_kv_comm.wait() 114 | # print(f"{datetime.now()} current device is: {torch.cuda.current_device()},ring_npu_flash_attn_backward d_kv_comm wait for last: {step}") 115 | 116 | # 转换为输入 dtype 并返回 117 | return (dq.to(q.dtype), dk.to(q.dtype), dv.to(q.dtype)) 118 | 119 | class RingNpuFlashAttnFunc(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, group, q, k, v, head_num, input_layout="BSND"): 122 | # 前向传播逻辑 123 | attention_out,softmax_max, softmax_sum, scale = ring_npu_flash_attn_forward(group,q=q, k=k, v=v, head_num=head_num, input_layout=input_layout) 124 | # 保存中间结果,以便在反向传播中使用 125 | ctx.save_for_backward(q, k, v, attention_out,softmax_max, softmax_sum) 126 | ctx.head_num = head_num 127 | ctx.input_layout = input_layout 128 | ctx.group = group 129 | ctx.scale=scale 130 | 131 | return attention_out 132 | 133 | @staticmethod 134 | def backward(ctx, grad_attention_out): 135 | # 获取保存的中间结果 136 | q, k, v, attention_out,softmax_max, softmax_sum = ctx.saved_tensors 137 | # 反向传播逻辑 138 | # 这里假设有一个实现反向传播的函数 `npu_fusion_attention_backward` 139 | grad_query, grad_key, grad_value = ring_npu_flash_attn_backward(ctx.group,q, k, v, grad_attention_out, 140 | ctx.head_num, ctx.input_layout,softmax_max, softmax_sum, attention_out, ctx.scale) 141 | return None, grad_query, grad_key, grad_value,None,None 142 | 143 | def ring_npu_flash_attn_func( 144 | group, 145 | q: torch.Tensor, 146 | k: torch.Tensor, 147 | v: torch.Tensor, 148 | head_num: int=None, 149 | input_layout: str="BSND" 150 | ): 151 | head_num = q.shape[-2] 152 | return RingNpuFlashAttnFunc.apply( 153 | group, 154 | q, 155 | k, 156 | v, 157 | head_num, 158 | input_layout 159 | ) 
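# ---------------------------------------------------------------------------
# Schedule summary (English): ring_npu_flash_attn_forward runs world_size steps.
# At every step except the last, each rank posts asynchronous send/recv of its
# current K/V block to its ring neighbour, runs the fused NPU attention kernel
# on the block it currently holds (only for steps <= its rank, i.e. a causal
# schedule), then waits on the communication and swaps in the received K/V.
# The backward pass mirrors this with a second communicator (d_kv_comm) that
# circulates the partially accumulated dK/dV gradients around the ring.
#
# Illustrative call sketch (an assumption-laden example, not a tested recipe):
# it presumes torch_npu is installed, torch.distributed was initialized with an
# NPU-capable backend (e.g. "hccl" via torchrun, see scripts/run_npu.sh), and
# q/k/v are per-rank shards of shape (batch, local_seqlen, heads, head_dim):
#
#     out = ring_npu_flash_attn_func(dist.group.WORLD, q, k, v, input_layout="BSND")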
-------------------------------------------------------------------------------- /yunchang/hybrid/async_attn_layer.py: -------------------------------------------------------------------------------- 1 | from yunchang.comm.all_to_all import SeqAllToAll4D, SeqAllToAll5D 2 | 3 | import torch 4 | 5 | from typing import Any 6 | from torch import Tensor 7 | 8 | import torch.distributed as dist 9 | from .utils import RING_IMPL_DICT, RING_IMPL_QKVPACKED_DICT 10 | from yunchang.globals import PROCESS_GROUP 11 | 12 | 13 | class AsyncLongContextAttention(torch.nn.Module): 14 | """Initialization. 15 | 16 | Arguments: 17 | ulysses_pg (ProcessGroup): ulysses process group 18 | ring_pg (ProcessGroup): ring process group 19 | scatter_idx (int): scatter_idx for all2all comm 20 | gather_idx (int): gather_idx for all2all comm 21 | """ 22 | 23 | def __init__( 24 | self, 25 | scatter_idx: int = 2, 26 | gather_idx: int = 1, 27 | ring_impl_type: str = "basic", 28 | ) -> None: 29 | 30 | super(AsyncLongContextAttention, self).__init__() 31 | self.ring_pg = PROCESS_GROUP.RING_PG 32 | self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG 33 | 34 | self.stream = torch.cuda.Stream() 35 | self._async_op = True 36 | 37 | assert ( 38 | self.ulysses_pg is not None or self.ring_pg is not None 39 | ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}" 40 | self.scatter_idx = scatter_idx 41 | self.gather_idx = gather_idx 42 | self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type] 43 | 44 | def forward( 45 | self, 46 | query: Tensor, 47 | key: Tensor, 48 | value: Tensor, 49 | dropout_p=0.0, 50 | softmax_scale=None, 51 | causal=False, 52 | window_size=(-1, -1), 53 | softcap=0.0, 54 | alibi_slopes=None, 55 | deterministic=False, 56 | return_attn_probs=False, 57 | *args: Any, 58 | ) -> Tensor: 59 | """forward 60 | 61 | Arguments: 62 | query (Tensor): query input to the layer (bs, seqlen/P, hc, hs) 63 | key (Tensor): key input to the layer (bs, seqlen/P, hc_kv, hs) 64 | value (Tensor): value input to the layer (bs, seqlen/P, hc_kv, hs) 65 | args: other args 66 | 67 | Returns: 68 | * output (Tensor): context output 69 | """ 70 | 71 | # un*ud = hc 72 | 73 | ulysses_degree = dist.get_world_size(self.ulysses_pg) 74 | 75 | bs, shard_seqlen, hc, hs = query.shape 76 | bs, shard_seqlen, hc_kv, hs = key.shape 77 | seq_len = shard_seqlen * ulysses_degree 78 | un = hc // ulysses_degree 79 | un_kv = hc_kv // ulysses_degree 80 | 81 | assert un_kv == un, f"un_kv {un_kv} un {un}" 82 | 83 | qkv = torch.cat([query, key, value]).contiguous() 84 | # (3*bs, seqlen/P, hc, hs) -> (hc, seqlen/P, 3*bs, hs) -> (un, ud, seqlen/P, 3*bs, hs), where hc = un*ud 85 | qkv_list = torch.unbind( 86 | qkv.transpose(0, 2) 87 | .contiguous() 88 | .reshape(un, ulysses_degree, shard_seqlen, 3 * bs, hs) 89 | ) 90 | # 3xall-to-all output buffer 91 | qkv_trans_list = [ 92 | torch.zeros( 93 | ulysses_degree, 94 | 1, 95 | shard_seqlen, 96 | 3 * bs, 97 | hs, 98 | dtype=query.dtype, 99 | device=query.device, 100 | ) 101 | for i in range(len(qkv_list)) 102 | ] 103 | # last all-to-all buffter 104 | context_layer_list = [ 105 | torch.zeros( 106 | ulysses_degree, 107 | 1, 108 | shard_seqlen, 109 | bs, 110 | hs, 111 | dtype=query.dtype, 112 | device=query.device, 113 | ) 114 | for i in range(len(qkv_list)) 115 | ] 116 | 117 | comm_handle_list = [] 118 | 119 | # un * (ud, shard_seqlen, 3*bs, hs) 120 | for i, qkv in enumerate(qkv_list): 121 | with torch.cuda.stream(self.stream): 122 | ret = dist.all_to_all_single( 123 | qkv_trans_list[i], 124 | qkv, 
125 | group=self.ulysses_pg, 126 | async_op=self._async_op, 127 | ) 128 | comm_handle_list.append(ret) 129 | 130 | last_comm_handle_list = [] 131 | for i, qkv_trans in enumerate(qkv_trans_list): 132 | if comm_handle_list[i] is not None: 133 | comm_handle_list[i].wait() 134 | qkv_trans = ( 135 | qkv_trans.reshape(seq_len, 3 * bs, 1, hs) 136 | .transpose(0, 1) 137 | .contiguous() 138 | .reshape(3 * bs, seq_len, 1, hs) 139 | ) 140 | 141 | # qkv_trans = all_to_all_4D_async(qkv, qkv_trans_list[i], self.scatter_idx, self.gather_idx, self.ulysses_pg) 142 | qkv_trans = torch.chunk(qkv_trans, 3, dim=0) 143 | 144 | out = self.ring_attn_fn( 145 | qkv_trans[0], 146 | qkv_trans[1], 147 | qkv_trans[2], 148 | dropout_p=dropout_p, 149 | softmax_scale=softmax_scale, 150 | causal=causal, 151 | window_size=window_size, 152 | softcap=softcap, 153 | alibi_slopes=alibi_slopes, 154 | deterministic=deterministic, 155 | return_attn_probs=return_attn_probs, 156 | group=self.ring_pg, 157 | ) 158 | 159 | if type(out) == tuple: 160 | context_layer, _, _ = out 161 | else: 162 | context_layer = out 163 | 164 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 165 | # scatter 1, gather 2 166 | 167 | context_layer = ( 168 | context_layer.reshape(bs, ulysses_degree, shard_seqlen, 1, hs) 169 | .transpose(0, 3) 170 | .transpose(0, 1) 171 | .contiguous() 172 | .reshape(ulysses_degree, 1, shard_seqlen, bs, hs) 173 | ) 174 | with torch.cuda.stream(self.stream): 175 | ret = dist.all_to_all_single( 176 | context_layer_list[i], 177 | context_layer, 178 | group=self.ulysses_pg, 179 | async_op=self._async_op, 180 | ) 181 | last_comm_handle_list.append(ret) 182 | 183 | # hc = un * P 184 | # un x (hc = P, seq_len/P, bs, hs) -> (bs, seq_len, hc = P, hs) 185 | for i, ret in enumerate(last_comm_handle_list): 186 | if ret is not None: 187 | ret.wait() 188 | context_layer_list[i] = ( 189 | context_layer_list[i] 190 | .reshape(ulysses_degree, shard_seqlen, bs, hs) 191 | .transpose(0, 2) 192 | .contiguous() 193 | .reshape(bs, shard_seqlen, ulysses_degree, hs) 194 | ) 195 | 196 | output = torch.cat(context_layer_list, dim=2) 197 | return output 198 | 199 | def backward(self, *args, **kwargs): 200 | raise RuntimeError( 201 | "Backward computation is not allowed for AsyncLongContextAttention." 
202 | ) 203 | -------------------------------------------------------------------------------- /benchmark/benchmark_longctx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from yunchang import ( 4 | AsyncLongContextAttention, 5 | LongContextAttention, 6 | set_seq_parallel_pg, 7 | UlyssesAttention, 8 | ) 9 | from yunchang.comm import EXTRACT_FUNC_DICT 10 | import torch.cuda 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description="args for benchmark.") 14 | 15 | parser.add_argument( 16 | "--ring_impl_type", 17 | type=str, 18 | default="basic", 19 | choices=["basic", "zigzag", "strip"], 20 | help="ring attn implementation type", 21 | ) 22 | parser.add_argument("--nheads", type=int, default=2, help="head number") 23 | parser.add_argument("--head_size", type=int, default=128, help="head size") 24 | parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length") 25 | parser.add_argument("--group_num", type=int, default=1, help="group number") 26 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 27 | parser.add_argument( 28 | "--fwd_only", action="store_true", help="benchmark forward pass only" 29 | ) 30 | parser.add_argument( 31 | "--use_ulysses_lowdim", 32 | action="store_true", 33 | default=True, 34 | help="ulysses process group on low dimension", 35 | ) 36 | parser.add_argument( 37 | "--use_qkvpack", 38 | action="store_true", 39 | default=False, 40 | help="pack qkv before all-to-all", 41 | ) 42 | parser.add_argument( 43 | "--ulysses_degree", 44 | type=int, 45 | default=1, 46 | help="ulysses attention sequence parallel degree", 47 | ) 48 | parser.add_argument( 49 | "--use_profiler", 50 | action="store_true", 51 | default=False, 52 | help="use torch profiler", 53 | ) 54 | parser.add_argument( 55 | "--use_ulysses", 56 | action="store_true", 57 | default=False, 58 | help="use ulysses", 59 | ) 60 | parser.add_argument( 61 | "--attn_type", 62 | type=str, 63 | default="fa", 64 | choices=["fa", "fa3", "torch"], 65 | help="attention type", 66 | ) 67 | # decault causal=True for LLM. no_causal is for DiT. 
68 | parser.add_argument( 69 | "--no_causal", 70 | action="store_true", 71 | default=False, 72 | help="use no causal attention", 73 | ) 74 | 75 | args = parser.parse_args() 76 | 77 | 78 | def color_print(text): 79 | print("\033[91m {}\033[00m".format(text)) 80 | 81 | 82 | def init_prof(use_profiler): 83 | activities = [] 84 | activities.append(torch.profiler.ProfilerActivity.CPU) 85 | activities.append(torch.profiler.ProfilerActivity.CUDA) 86 | 87 | from contextlib import nullcontext 88 | 89 | ctx = ( 90 | torch.profiler.profile( 91 | activities=activities, 92 | schedule=torch.profiler.schedule(wait=0, warmup=2, active=4, repeat=1), 93 | on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile/"), 94 | record_shapes=True, 95 | with_stack=True, 96 | ) 97 | if use_profiler 98 | else nullcontext() 99 | ) 100 | return ctx 101 | 102 | import os 103 | 104 | def get_local_rank(): 105 | local_rank = int(os.getenv('LOCAL_RANK', '0')) 106 | return local_rank 107 | 108 | def benchmark(num_iter=10, forward_only=True, log=True, profile=False): 109 | dtype = torch.float16 110 | rank = dist.get_rank() 111 | local_rank = get_local_rank() 112 | world_size = dist.get_world_size() 113 | device = torch.device(f"cuda:{local_rank}") 114 | torch.cuda.set_device(device) 115 | 116 | batch_size = args.batch_size 117 | seqlen = args.seq_len 118 | nheads = args.nheads 119 | group_num = args.group_num 120 | d = args.head_size 121 | 122 | dropout_p = 0.0 123 | causal = not args.no_causal 124 | deterministic = False 125 | 126 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 127 | assert d % 8 == 0 128 | assert nheads % group_num == 0, f"nheads {nheads} group_num {group_num}" 129 | assert ( 130 | nheads // group_num % args.ulysses_degree == 0 131 | ), f"nheads {nheads}, group_num {group_num}, ulysses_degree {args.ulysses_degree}" 132 | 133 | q = torch.randn( 134 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 135 | ) 136 | k = torch.randn( 137 | batch_size, 138 | seqlen, 139 | nheads // group_num, 140 | d, 141 | device=device, 142 | dtype=dtype, 143 | requires_grad=True, 144 | ) 145 | v = torch.randn( 146 | batch_size, 147 | seqlen, 148 | nheads // group_num, 149 | d, 150 | device=device, 151 | dtype=dtype, 152 | requires_grad=True, 153 | ) 154 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 155 | 156 | sp_ulysses_degree = min(args.ulysses_degree, world_size) 157 | sp_ring_degree = world_size // sp_ulysses_degree 158 | 159 | set_seq_parallel_pg( 160 | sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim 161 | ) 162 | 163 | from yunchang.kernels import AttnType 164 | attn_type = AttnType.from_string(args.attn_type) 165 | if args.use_ulysses: 166 | longctx_attn = UlyssesAttention(attn_type=attn_type) 167 | else: 168 | longctx_attn = LongContextAttention(ring_impl_type=args.ring_impl_type, attn_type=attn_type) 169 | 170 | out = longctx_attn( 171 | q, 172 | k, 173 | v, 174 | dropout_p=dropout_p, 175 | causal=causal, 176 | window_size=(-1, -1), 177 | softcap=0.0, 178 | alibi_slopes=None, 179 | deterministic=deterministic, 180 | return_attn_probs=False, 181 | ) 182 | if not args.fwd_only: 183 | out.backward(dout) 184 | 185 | out = longctx_attn( 186 | q, 187 | k, 188 | v, 189 | dropout_p=dropout_p, 190 | causal=causal, 191 | window_size=(-1, -1), 192 | softcap=0.0, 193 | alibi_slopes=None, 194 | deterministic=deterministic, 195 | return_attn_probs=False, 196 | ) 197 | if not args.fwd_only: 198 | 
out.backward(dout) 199 | 200 | begin = torch.cuda.Event(enable_timing=True) 201 | begin.record() 202 | 203 | ctx = init_prof(profile) 204 | 205 | with ctx as prof: 206 | if forward_only: 207 | with torch.no_grad(): 208 | for _ in range(num_iter): 209 | _ = longctx_attn( 210 | q, 211 | k, 212 | v, 213 | dropout_p=dropout_p, 214 | causal=causal, 215 | window_size=(-1, -1), 216 | softcap=0.0, 217 | alibi_slopes=None, 218 | deterministic=deterministic, 219 | return_attn_probs=False, 220 | ) 221 | 222 | torch.cuda.synchronize(device=device) 223 | 224 | if profile: 225 | prof.step() 226 | else: 227 | for _ in range(num_iter): 228 | q.grad = None 229 | k.grad = None 230 | v.grad = None 231 | out = longctx_attn( 232 | q, 233 | k, 234 | v, 235 | dropout_p=dropout_p, 236 | causal=causal, 237 | window_size=(-1, -1), 238 | softcap=0.0, 239 | alibi_slopes=None, 240 | deterministic=deterministic, 241 | return_attn_probs=False, 242 | ) 243 | out.backward(dout) 244 | 245 | if profile: 246 | prof.step() 247 | 248 | end = torch.cuda.Event(enable_timing=True) 249 | end.record() 250 | 251 | torch.cuda.synchronize(device=device) 252 | elapse = begin.elapsed_time(end) / 1000.0 253 | 254 | if rank == 0 and log: 255 | color_print(f"{num_iter / elapse:.3f} iter/s, {elapse:.3f} sec") 256 | 257 | 258 | if __name__ == "__main__": 259 | dist.init_process_group("nccl") 260 | rank = dist.get_rank() 261 | 262 | forward_only = args.fwd_only 263 | 264 | torch.cuda.empty_cache() 265 | if rank == 0: 266 | color_print( 267 | f"ring_impl_type: {args.ring_impl_type}. " 268 | f"nheads: {args.nheads} head_size: {args.head_size} seq_len: {args.seq_len} " 269 | f"ulysses_degree : {args.ulysses_degree} fwd_only {forward_only} use_ulysses_lowdim {args.use_ulysses_lowdim}. " 270 | f"use_qkvpack: {args.use_qkvpack} " 271 | f"use_ulysses: {args.use_ulysses} " 272 | f"causal: {not args.no_causal} " 273 | f"attn_type: {args.attn_type} " 274 | ) 275 | torch.cuda.empty_cache() 276 | benchmark(forward_only=forward_only, log=False) 277 | benchmark(forward_only=forward_only, log=True, profile=args.use_profiler) 278 | dist.destroy_process_group() -------------------------------------------------------------------------------- /yunchang/hybrid/attn_layer.py: -------------------------------------------------------------------------------- 1 | from yunchang.comm.all_to_all import SeqAllToAll4D, SeqAllToAll5D 2 | 3 | import torch 4 | 5 | from typing import Any 6 | from torch import Tensor 7 | 8 | import torch.distributed as dist 9 | from .utils import RING_IMPL_DICT, RING_IMPL_QKVPACKED_DICT 10 | from yunchang.globals import PROCESS_GROUP, HAS_SPARSE_SAGE_ATTENTION 11 | from yunchang.kernels import AttnType 12 | 13 | 14 | class LongContextAttention(torch.nn.Module): 15 | """Initialization. 
16 | 17 | Arguments: 18 | ulysses_pg (ProcessGroup): ulysses process group 19 | ring_pg (ProcessGroup): ring process group 20 | scatter_idx (int): scatter_idx for all2all comm 21 | gather_idx (int): gather_idx for all2all comm 22 | use_sync (bool): whether to synchronize after all-to-all 23 | """ 24 | 25 | def __init__( 26 | self, 27 | scatter_idx: int = 2, 28 | gather_idx: int = 1, 29 | ring_impl_type: str = "basic", 30 | use_pack_qkv: bool = False, 31 | use_sync: bool = False, 32 | attn_type: AttnType = AttnType.FA, 33 | attn_processor: torch.nn.Module = None, 34 | ) -> None: 35 | 36 | super(LongContextAttention, self).__init__() 37 | self.ring_pg = PROCESS_GROUP.RING_PG 38 | self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG 39 | 40 | self.use_pack_qkv = use_pack_qkv 41 | self.use_sync = use_sync 42 | self.attn_type = attn_type 43 | assert ( 44 | self.ulysses_pg is not None or self.ring_pg is not None 45 | ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}" 46 | self.scatter_idx = scatter_idx 47 | self.gather_idx = gather_idx 48 | self.attn_processor = attn_processor 49 | self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type] 50 | 51 | if HAS_SPARSE_SAGE_ATTENTION: 52 | from spas_sage_attn.autotune import SparseAttentionMeansim 53 | if isinstance(attn_processor, SparseAttentionMeansim) and dist.get_world_size(self.ring_pg) > 1: 54 | raise RuntimeError("Sparse Sage attention does not support ring degree > 1.") 55 | 56 | 57 | def forward( 58 | self, 59 | query: Tensor, 60 | key: Tensor, 61 | value: Tensor, 62 | dropout_p=0.0, 63 | softmax_scale=None, 64 | causal=False, 65 | window_size=(-1, -1), 66 | softcap=0.0, 67 | alibi_slopes=None, 68 | deterministic=False, 69 | return_attn_probs=False, 70 | *args: Any, 71 | ) -> Tensor: 72 | """forward 73 | 74 | Arguments: 75 | query (Tensor): query input to the layer 76 | key (Tensor): key input to the layer 77 | value (Tensor): value input to the layer 78 | args: other args 79 | 80 | Returns: 81 | * output (Tensor): context output 82 | """ 83 | 84 | # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) 85 | # scatter 2, gather 1 86 | if self.use_pack_qkv: 87 | # (3*bs, seq_len/N, head_cnt, head_size) 88 | qkv = torch.cat([query, key, value]).contiguous() 89 | # (3*bs, seq_len, head_cnt/N, head_size) 90 | qkv = SeqAllToAll4D.apply( 91 | self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, use_sync=self.use_sync 92 | ) 93 | qkv = torch.chunk(qkv, 3, dim=0) 94 | out = self.ring_attn_fn( 95 | qkv[0], 96 | qkv[1], 97 | qkv[2], 98 | dropout_p=dropout_p, 99 | softmax_scale=softmax_scale, 100 | causal=causal, 101 | window_size=window_size, 102 | softcap=softcap, 103 | alibi_slopes=alibi_slopes, 104 | deterministic=deterministic, 105 | return_attn_probs=return_attn_probs, 106 | group=self.ring_pg, 107 | attn_type=self.attn_type, 108 | attn_processor=self.attn_processor, 109 | ) 110 | else: 111 | query_layer = SeqAllToAll4D.apply( 112 | self.ulysses_pg, query, self.scatter_idx, self.gather_idx, self.use_sync 113 | ) 114 | key_layer = SeqAllToAll4D.apply( 115 | self.ulysses_pg, key, self.scatter_idx, self.gather_idx, self.use_sync 116 | ) 117 | value_layer = SeqAllToAll4D.apply( 118 | self.ulysses_pg, value, self.scatter_idx, self.gather_idx, self.use_sync 119 | ) 120 | if self.attn_type is AttnType.NPU: 121 | out = self.ring_attn_fn( 122 | self.ring_pg, 123 | query_layer, 124 | key_layer, 125 | value_layer 126 | ) 127 | else: 128 | out = self.ring_attn_fn( 129 | query_layer, 130 |
key_layer, 131 | value_layer, 132 | dropout_p=dropout_p, 133 | softmax_scale=softmax_scale, 134 | causal=causal, 135 | window_size=window_size, 136 | softcap=softcap, 137 | alibi_slopes=alibi_slopes, 138 | deterministic=deterministic, 139 | return_attn_probs=return_attn_probs, 140 | group=self.ring_pg, 141 | attn_type=self.attn_type, 142 | attn_processor=self.attn_processor, 143 | ) 144 | 145 | if type(out) == tuple: 146 | context_layer, _, _ = out 147 | else: 148 | context_layer = out 149 | 150 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 151 | # scatter 1, gather 2 152 | output = SeqAllToAll4D.apply( 153 | self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync 154 | ) 155 | 156 | # out e.g., [s/p::h] 157 | return output 158 | 159 | 160 | class LongContextAttentionQKVPacked(torch.nn.Module): 161 | """Initialization. 162 | 163 | Arguments: 164 | ulysses_pg (ProcessGroup): ulysses process group 165 | ring_pg (ProcessGroup): ring process group 166 | scatter_idx (int): scatter_idx for all2all comm 167 | gather_idx (int): gather_idx for all2all comm 168 | use_sync (bool): whether to synchronize after all-to-all 169 | """ 170 | 171 | def __init__( 172 | self, 173 | scatter_idx: int = 3, 174 | gather_idx: int = 1, 175 | ring_impl_type: str = "basic", 176 | use_sync: bool = False, 177 | attn_type: AttnType = AttnType.FA, 178 | ) -> None: 179 | 180 | super(LongContextAttentionQKVPacked, self).__init__() 181 | 182 | self.ring_pg = PROCESS_GROUP.RING_PG 183 | self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG 184 | 185 | assert ( 186 | self.ulysses_pg is not None or self.ring_pg is not None 187 | ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}" 188 | self.scatter_idx = scatter_idx 189 | self.gather_idx = gather_idx 190 | self.use_sync = use_sync 191 | self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type] 192 | self.attn_type = attn_type 193 | 194 | def forward( 195 | self, 196 | qkv, 197 | dropout_p=0.0, 198 | softmax_scale=None, 199 | causal=False, 200 | window_size=(-1, -1), 201 | softcap=0.0, 202 | alibi_slopes=None, 203 | deterministic=False, 204 | return_attn_probs=False, 205 | *args: Any, 206 | ) -> Tensor: 207 | """forward 208 | 209 | Arguments: 210 | query (Tensor): query input to the layer 211 | key (Tensor): key input to the layer 212 | value (Tensor): value input to the layer 213 | args: other args 214 | 215 | Returns: 216 | * output (Tensor): context output 217 | """ 218 | 219 | # scatter 3, gather 1 220 | 221 | world_size = dist.get_world_size(self.ulysses_pg) 222 | 223 | if world_size > 1: 224 | qkv = SeqAllToAll5D.apply( 225 | self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, self.use_sync 226 | ) 227 | 228 | out = self.ring_attn_fn( 229 | qkv, 230 | dropout_p=dropout_p, 231 | softmax_scale=softmax_scale, 232 | causal=causal, 233 | window_size=window_size, 234 | softcap=softcap, 235 | alibi_slopes=alibi_slopes, 236 | deterministic=deterministic, 237 | return_attn_probs=return_attn_probs, 238 | group=self.ring_pg, 239 | attn_type=self.attn_type, 240 | ) 241 | 242 | # print(f"out {out.shape}") 243 | 244 | if type(out) == tuple: 245 | out = out[0] 246 | 247 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 248 | # scatter 1, gather 2 249 | 250 | if world_size > 1: 251 | out = SeqAllToAll4D.apply( 252 | self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1, self.use_sync 253 | ) 254 | # out e.g., [s/p::h] 255 | return out 256 | 
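# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API). A minimal example,
# modelled on the test/ and benchmark/ scripts in this repo, of how
# LongContextAttention is typically driven: launch with torchrun, flash-attn
# installed, and a ulysses degree that divides both the world size and the
# head count. The degrees, shapes and dtype below are placeholders, not
# recommendations.
def _long_context_attention_usage_sketch():
    import os
    from yunchang.globals import set_seq_parallel_pg

    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    torch.cuda.set_device(f"cuda:{local_rank}")

    sp_ulysses_degree = 2  # assumed to divide world_size and nheads
    sp_ring_degree = world_size // sp_ulysses_degree
    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

    attn = LongContextAttention(ring_impl_type="basic", attn_type=AttnType.FA)

    bs, local_seqlen, nheads, head_dim = 2, 1024, 8, 64  # per-rank sequence shard
    q = torch.randn(bs, local_seqlen, nheads, head_dim,
                    device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Output keeps the sharded layout (bs, local_seqlen, nheads, head_dim).
    # Note: the zigzag/stripe ring impls expect inputs resharded with
    # EXTRACT_FUNC_DICT (see the benchmarks); that step is skipped here.
    out = attn(q, k, v, causal=True)
    return out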
-------------------------------------------------------------------------------- /yunchang/ring/ring_flashinfer_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from .utils import RingComm, update_out_and_lse 5 | from yunchang.kernels import select_flash_attn_impl, AttnType 6 | 7 | 8 | def ring_flashinfer_attn_forward( 9 | process_group, 10 | q: torch.Tensor, 11 | k: torch.Tensor, 12 | v: torch.Tensor, 13 | softmax_scale, 14 | dropout_p=0, 15 | causal=True, 16 | window_size=(-1, -1), 17 | softcap=0.0, 18 | alibi_slopes=None, 19 | deterministic=False, 20 | attn_type: AttnType = AttnType.FLASHINFER, 21 | attn_processor=None, 22 | ): 23 | comm = RingComm(process_group) 24 | 25 | out = None 26 | lse = None 27 | 28 | next_k, next_v = None, None 29 | 30 | for step in range(comm.world_size): 31 | if step + 1 != comm.world_size: 32 | next_k: torch.Tensor = comm.send_recv(k) 33 | next_v: torch.Tensor = comm.send_recv(v) 34 | comm.commit() 35 | 36 | if not causal or step <= comm.rank: 37 | fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor) 38 | block_out, block_lse = fn( 39 | q, 40 | k, 41 | v, 42 | dropout_p=dropout_p, 43 | softmax_scale=softmax_scale, 44 | causal=causal and step == 0, 45 | window_size=window_size, 46 | softcap=softcap, 47 | alibi_slopes=alibi_slopes, 48 | return_softmax=True and dropout_p > 0, 49 | ) 50 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 51 | 52 | if step + 1 != comm.world_size: 53 | comm.wait() 54 | k = next_k 55 | v = next_v 56 | 57 | out = out.to(q.dtype) 58 | lse = lse.squeeze(dim=-1).transpose(1, 2) 59 | return out, lse 60 | 61 | 62 | def ring_flashinfer_attn_backward( 63 | process_group, 64 | dout, 65 | q, 66 | k, 67 | v, 68 | out, 69 | softmax_lse, 70 | softmax_scale, 71 | dropout_p=0, 72 | causal=True, 73 | window_size=(-1, -1), 74 | softcap=0.0, 75 | alibi_slopes=None, 76 | deterministic=False, 77 | attn_type: AttnType = AttnType.FLASHINFER, 78 | ): 79 | kv_comm = RingComm(process_group) 80 | d_kv_comm = RingComm(process_group) 81 | dq, dk, dv = None, None, None 82 | next_dk, next_dv = None, None 83 | 84 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 85 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 86 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 87 | 88 | next_dk, next_dv = None, None 89 | next_k, next_v = None, None 90 | 91 | for step in range(kv_comm.world_size): 92 | if step + 1 != kv_comm.world_size: 93 | next_k = kv_comm.send_recv(k) 94 | next_v = kv_comm.send_recv(v) 95 | kv_comm.commit() 96 | if step <= kv_comm.rank or not causal: 97 | bwd_causal = causal and step == 0 98 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 99 | fn( 100 | dout, 101 | q, 102 | k, 103 | v, 104 | out, 105 | softmax_lse, 106 | block_dq_buffer, 107 | block_dk_buffer, 108 | block_dv_buffer, 109 | dropout_p, 110 | softmax_scale, 111 | bwd_causal, 112 | window_size, 113 | softcap, 114 | alibi_slopes, 115 | deterministic, 116 | rng_state=None, 117 | ) 118 | 119 | if dq is None: 120 | dq = block_dq_buffer.to(torch.float32) 121 | dk = block_dk_buffer.to(torch.float32) 122 | dv = block_dv_buffer.to(torch.float32) 123 | else: 124 | dq += block_dq_buffer 125 | d_kv_comm.wait() 126 | dk = block_dk_buffer + next_dk 127 | dv = block_dv_buffer + next_dv 128 | elif step != 0: 129 | d_kv_comm.wait() 130 | dk = next_dk 131 | dv = next_dv 132 | 133 | 
if step + 1 != kv_comm.world_size: 134 | kv_comm.wait() 135 | k = next_k 136 | v = next_v 137 | 138 | next_dk = d_kv_comm.send_recv(dk) 139 | next_dv = d_kv_comm.send_recv(dv) 140 | d_kv_comm.commit() 141 | 142 | d_kv_comm.wait() 143 | 144 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) 145 | 146 | 147 | WORKSPACE_BUFFER = None 148 | PREFILL_WRAPPER = None 149 | 150 | 151 | class RingFlashInferAttnFunc(torch.autograd.Function): 152 | @staticmethod 153 | def forward( 154 | ctx, 155 | q, 156 | k, 157 | v, 158 | dropout_p, 159 | softmax_scale, 160 | causal, 161 | window_size, 162 | softcap, 163 | alibi_slopes, 164 | deterministic, 165 | return_softmax, 166 | group, 167 | attn_type, 168 | attn_processor, 169 | ): 170 | if softmax_scale is None: 171 | softmax_scale = q.shape[-1] ** (-0.5) 172 | 173 | assert alibi_slopes is None 174 | k = k.contiguous() 175 | v = v.contiguous() 176 | out, softmax_lse = ring_flashinfer_attn_forward( 177 | group, 178 | q, 179 | k, 180 | v, 181 | softmax_scale=softmax_scale, 182 | dropout_p=dropout_p, 183 | causal=causal, 184 | window_size=window_size, 185 | softcap=softcap, 186 | alibi_slopes=alibi_slopes, 187 | deterministic=False, 188 | attn_type=attn_type, 189 | attn_processor=attn_processor, 190 | ) 191 | # this should be out_padded 192 | ctx.save_for_backward(q, k, v, out, softmax_lse) 193 | ctx.dropout_p = dropout_p 194 | ctx.softmax_scale = softmax_scale 195 | ctx.causal = causal 196 | ctx.window_size = window_size 197 | ctx.softcap = softcap 198 | ctx.alibi_slopes = alibi_slopes 199 | ctx.deterministic = deterministic 200 | ctx.group = group 201 | ctx.attn_type = attn_type 202 | ctx.attn_processor = attn_processor 203 | return out if not return_softmax else (out, softmax_lse, None) 204 | 205 | @staticmethod 206 | def backward(ctx, dout, *args): 207 | q, k, v, out, softmax_lse = ctx.saved_tensors 208 | dq, dk, dv = ring_flashinfer_attn_backward( 209 | ctx.group, 210 | dout, 211 | q, 212 | k, 213 | v, 214 | out, 215 | softmax_lse, 216 | softmax_scale=ctx.softmax_scale, 217 | dropout_p=ctx.dropout_p, 218 | causal=ctx.causal, 219 | window_size=ctx.window_size, 220 | softcap=ctx.softcap, 221 | alibi_slopes=ctx.alibi_slopes, 222 | deterministic=ctx.deterministic, 223 | attn_type=ctx.attn_type, 224 | ) 225 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 226 | 227 | 228 | def ring_flashinfer_attn_qkvpacked_func( 229 | qkv, 230 | dropout_p=0.0, 231 | softmax_scale=None, 232 | causal=False, 233 | window_size=(-1, -1), 234 | softcap=0.0, 235 | alibi_slopes=None, 236 | deterministic=False, 237 | return_attn_probs=False, 238 | group=None, 239 | attn_type: AttnType = AttnType.FLASHINFER, 240 | ): 241 | return RingFlashInferAttnFunc.apply( 242 | qkv[:, :, 0], 243 | qkv[:, :, 1], 244 | qkv[:, :, 2], 245 | dropout_p, 246 | softmax_scale, 247 | causal, 248 | window_size, 249 | softcap, 250 | alibi_slopes, 251 | deterministic, 252 | return_attn_probs, 253 | group, 254 | attn_type, 255 | ) 256 | 257 | 258 | def ring_flashinfer_attn_kvpacked_func( 259 | q, 260 | kv, 261 | dropout_p=0.0, 262 | softmax_scale=None, 263 | causal=False, 264 | window_size=(-1, -1), 265 | softcap=0.0, 266 | alibi_slopes=None, 267 | deterministic=False, 268 | return_attn_probs=False, 269 | group=None, 270 | attn_type: AttnType = AttnType.FLASHINFER, 271 | ): 272 | return RingFlashInferAttnFunc.apply( 273 | q, 274 | kv[:, :, 0], 275 | kv[:, :, 1], 276 | dropout_p, 277 | softmax_scale, 278 | causal, 279 | window_size, 280 | softcap, 281 | 
alibi_slopes, 282 | deterministic, 283 | return_attn_probs, 284 | group, 285 | attn_type, 286 | ) 287 | 288 | 289 | def ring_flashinfer_attn_func( 290 | q, 291 | k, 292 | v, 293 | dropout_p=0.0, 294 | softmax_scale=None, 295 | causal=False, 296 | window_size=(-1, -1), 297 | softcap=0.0, 298 | alibi_slopes=None, 299 | deterministic=False, 300 | return_attn_probs=False, 301 | group=None, 302 | attn_type: AttnType = AttnType.FLASHINFER, 303 | attn_processor=None, 304 | ): 305 | return RingFlashInferAttnFunc.apply( 306 | q, 307 | k, 308 | v, 309 | dropout_p, 310 | softmax_scale, 311 | causal, 312 | window_size, 313 | softcap, 314 | alibi_slopes, 315 | deterministic, 316 | return_attn_probs, 317 | group, 318 | attn_type, 319 | attn_processor, 320 | ) 321 | -------------------------------------------------------------------------------- /yunchang/ring/ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | # from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 4 | from .utils import RingComm, update_out_and_lse 5 | from yunchang.kernels import select_flash_attn_impl, AttnType 6 | 7 | def ring_flash_attn_forward( 8 | process_group, 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | softmax_scale, 13 | dropout_p=0, 14 | causal=True, 15 | window_size=(-1, -1), 16 | softcap=0.0, 17 | alibi_slopes=None, 18 | deterministic=False, 19 | attn_type: AttnType = AttnType.FA, 20 | attn_processor=None, 21 | ): 22 | comm = RingComm(process_group) 23 | 24 | out = None 25 | lse = None 26 | 27 | next_k, next_v = None, None 28 | 29 | for step in range(comm.world_size): 30 | if step + 1 != comm.world_size: 31 | next_k: torch.Tensor = comm.send_recv(k) 32 | next_v: torch.Tensor = comm.send_recv(v) 33 | comm.commit() 34 | 35 | if not causal or step <= comm.rank: 36 | fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor) 37 | block_out, block_lse = fn( 38 | q, 39 | k, 40 | v, 41 | dropout_p=dropout_p, 42 | softmax_scale=softmax_scale, 43 | causal=causal and step == 0, 44 | window_size=window_size, 45 | softcap=softcap, 46 | alibi_slopes=alibi_slopes, 47 | return_softmax=True and dropout_p > 0, 48 | ) 49 | if attn_type == AttnType.SPARSE_SAGE: 50 | out, lse = block_out, block_lse 51 | else: 52 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 53 | 54 | if step + 1 != comm.world_size: 55 | comm.wait() 56 | k = next_k 57 | v = next_v 58 | 59 | out = out.to(q.dtype) 60 | if attn_type != AttnType.SPARSE_SAGE: 61 | lse = lse.squeeze(dim=-1).transpose(1, 2) 62 | return out, lse 63 | 64 | 65 | def ring_flash_attn_backward( 66 | process_group, 67 | dout, 68 | q, 69 | k, 70 | v, 71 | out, 72 | softmax_lse, 73 | softmax_scale, 74 | dropout_p=0, 75 | causal=True, 76 | window_size=(-1, -1), 77 | softcap=0.0, 78 | alibi_slopes=None, 79 | deterministic=False, 80 | attn_type: AttnType = AttnType.FA, 81 | ): 82 | kv_comm = RingComm(process_group) 83 | d_kv_comm = RingComm(process_group) 84 | dq, dk, dv = None, None, None 85 | next_dk, next_dv = None, None 86 | 87 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 88 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 89 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 90 | 91 | next_dk, next_dv = None, None 92 | next_k, next_v = None, None 93 | 94 | for step in range(kv_comm.world_size): 95 | if step + 1 != 
kv_comm.world_size: 96 | next_k = kv_comm.send_recv(k) 97 | next_v = kv_comm.send_recv(v) 98 | kv_comm.commit() 99 | if step <= kv_comm.rank or not causal: 100 | bwd_causal = causal and step == 0 101 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 102 | fn( 103 | dout, 104 | q, 105 | k, 106 | v, 107 | out, 108 | softmax_lse, 109 | block_dq_buffer, 110 | block_dk_buffer, 111 | block_dv_buffer, 112 | dropout_p, 113 | softmax_scale, 114 | bwd_causal, 115 | window_size, 116 | softcap, 117 | alibi_slopes, 118 | deterministic, 119 | rng_state=None, 120 | ) 121 | 122 | if dq is None: 123 | dq = block_dq_buffer.to(torch.float32) 124 | dk = block_dk_buffer.to(torch.float32) 125 | dv = block_dv_buffer.to(torch.float32) 126 | else: 127 | dq += block_dq_buffer 128 | d_kv_comm.wait() 129 | dk = block_dk_buffer + next_dk 130 | dv = block_dv_buffer + next_dv 131 | elif step != 0: 132 | d_kv_comm.wait() 133 | dk = next_dk 134 | dv = next_dv 135 | 136 | if step + 1 != kv_comm.world_size: 137 | kv_comm.wait() 138 | k = next_k 139 | v = next_v 140 | 141 | next_dk = d_kv_comm.send_recv(dk) 142 | next_dv = d_kv_comm.send_recv(dv) 143 | d_kv_comm.commit() 144 | 145 | d_kv_comm.wait() 146 | 147 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) 148 | 149 | 150 | class RingFlashAttnFunc(torch.autograd.Function): 151 | @staticmethod 152 | def forward( 153 | ctx, 154 | q, 155 | k, 156 | v, 157 | dropout_p, 158 | softmax_scale, 159 | causal, 160 | window_size, 161 | softcap, 162 | alibi_slopes, 163 | deterministic, 164 | return_softmax, 165 | group, 166 | attn_type, 167 | attn_processor, 168 | ): 169 | if softmax_scale is None: 170 | softmax_scale = q.shape[-1] ** (-0.5) 171 | 172 | assert alibi_slopes is None 173 | k = k.contiguous() 174 | v = v.contiguous() 175 | out, softmax_lse = ring_flash_attn_forward( 176 | group, 177 | q, 178 | k, 179 | v, 180 | softmax_scale=softmax_scale, 181 | dropout_p=dropout_p, 182 | causal=causal, 183 | window_size=window_size, 184 | softcap=softcap, 185 | alibi_slopes=alibi_slopes, 186 | deterministic=False, 187 | attn_type=attn_type, 188 | attn_processor=attn_processor, 189 | ) 190 | # this should be out_padded 191 | ctx.save_for_backward(q, k, v, out, softmax_lse) 192 | ctx.dropout_p = dropout_p 193 | ctx.softmax_scale = softmax_scale 194 | ctx.causal = causal 195 | ctx.window_size = window_size 196 | ctx.softcap = softcap 197 | ctx.alibi_slopes = alibi_slopes 198 | ctx.deterministic = deterministic 199 | ctx.group = group 200 | ctx.attn_type = attn_type 201 | ctx.attn_processor = attn_processor 202 | return out if not return_softmax else (out, softmax_lse, None) 203 | 204 | @staticmethod 205 | def backward(ctx, dout, *args): 206 | q, k, v, out, softmax_lse = ctx.saved_tensors 207 | dq, dk, dv = ring_flash_attn_backward( 208 | ctx.group, 209 | dout, 210 | q, 211 | k, 212 | v, 213 | out, 214 | softmax_lse, 215 | softmax_scale=ctx.softmax_scale, 216 | dropout_p=ctx.dropout_p, 217 | causal=ctx.causal, 218 | window_size=ctx.window_size, 219 | softcap=ctx.softcap, 220 | alibi_slopes=ctx.alibi_slopes, 221 | deterministic=ctx.deterministic, 222 | attn_type=ctx.attn_type, 223 | ) 224 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None 225 | 226 | 227 | def ring_flash_attn_qkvpacked_func( 228 | qkv, 229 | dropout_p=0.0, 230 | softmax_scale=None, 231 | causal=False, 232 | window_size=(-1, -1), 233 | softcap=0.0, 234 | alibi_slopes=None, 235 | deterministic=False, 236 | return_attn_probs=False, 237 | group=None, 238 | 
attn_type: AttnType = AttnType.FA, 239 | ): 240 | return RingFlashAttnFunc.apply( 241 | qkv[:, :, 0], 242 | qkv[:, :, 1], 243 | qkv[:, :, 2], 244 | dropout_p, 245 | softmax_scale, 246 | causal, 247 | window_size, 248 | softcap, 249 | alibi_slopes, 250 | deterministic, 251 | return_attn_probs, 252 | group, 253 | attn_type, 254 | ) 255 | 256 | 257 | def ring_flash_attn_kvpacked_func( 258 | q, 259 | kv, 260 | dropout_p=0.0, 261 | softmax_scale=None, 262 | causal=False, 263 | window_size=(-1, -1), 264 | softcap=0.0, 265 | alibi_slopes=None, 266 | deterministic=False, 267 | return_attn_probs=False, 268 | group=None, 269 | attn_type: AttnType = AttnType.FA, 270 | ): 271 | return RingFlashAttnFunc.apply( 272 | q, 273 | kv[:, :, 0], 274 | kv[:, :, 1], 275 | dropout_p, 276 | softmax_scale, 277 | causal, 278 | window_size, 279 | softcap, 280 | alibi_slopes, 281 | deterministic, 282 | return_attn_probs, 283 | group, 284 | attn_type, 285 | ) 286 | 287 | 288 | def ring_flash_attn_func( 289 | q, 290 | k, 291 | v, 292 | dropout_p=0.0, 293 | softmax_scale=None, 294 | causal=False, 295 | window_size=(-1, -1), 296 | softcap=0.0, 297 | alibi_slopes=None, 298 | deterministic=False, 299 | return_attn_probs=False, 300 | group=None, 301 | attn_type: AttnType = AttnType.FA, 302 | attn_processor=None, 303 | ): 304 | return RingFlashAttnFunc.apply( 305 | q, 306 | k, 307 | v, 308 | dropout_p, 309 | softmax_scale, 310 | causal, 311 | window_size, 312 | softcap, 313 | alibi_slopes, 314 | deterministic, 315 | return_attn_probs, 316 | group, 317 | attn_type, 318 | attn_processor, 319 | ) 320 | -------------------------------------------------------------------------------- /yunchang/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from .attention import ( 5 | flash_attn_forward_aiter, 6 | flash_attn_forward, 7 | flash_attn_backward, 8 | flash_attn3_func_forward, 9 | flash_attn3_func_backward, 10 | pytorch_attn_forward, 11 | pytorch_attn_backward, 12 | flashinfer_attn_forward, 13 | flashinfer_attn_backbward, 14 | npu_fused_attn_forward, 15 | npu_fused_attn_backward, 16 | HAS_FLASH_ATTN_HOPPER, 17 | ) 18 | from enum import Enum, auto 19 | 20 | from yunchang.globals import ( 21 | HAS_AITER, 22 | HAS_FLASH_ATTN, 23 | HAS_SAGE_ATTENTION, 24 | HAS_SPARSE_SAGE_ATTENTION, 25 | HAS_NPU, 26 | ) 27 | 28 | if HAS_FLASH_ATTN: 29 | from flash_attn import flash_attn_func 30 | 31 | if HAS_SAGE_ATTENTION: 32 | import sageattention 33 | 34 | if HAS_SPARSE_SAGE_ATTENTION: 35 | from spas_sage_attn.autotune import SparseAttentionMeansim 36 | 37 | 38 | class AttnType(Enum): 39 | AITER = "aiter" 40 | FA = "fa" 41 | FA3 = "fa3" 42 | FLASHINFER = "flashinfer" 43 | TORCH = "torch" 44 | SAGE_AUTO = "sage_auto" 45 | SAGE_FP16 = "sage_fp16" 46 | SAGE_FP16_TRITON = "sage_fp16_triton" 47 | SAGE_FP8 = "sage_fp8" 48 | SAGE_FP8_SM90 = "sage_fp8_sm90" 49 | SPARSE_SAGE = "sparse_sage" 50 | NPU = 'npu' 51 | 52 | @classmethod 53 | def from_string(cls, s: str): 54 | for member in cls: 55 | if member.value == s: 56 | return member 57 | raise ValueError(f"'{s}' is not a valid {cls.__name__}") 58 | 59 | 60 | def select_flash_attn_impl( 61 | impl_type: AttnType, stage: str = "fwd-bwd", attn_processor: torch.nn.Module = None 62 | ): 63 | if impl_type == AttnType.AITER: 64 | if stage == "fwd-only": 65 | return flash_attn_forward_aiter 66 | elif stage == "bwd-only": 67 | raise ValueError("Aiter does not support bwd-only stage.") 68 | elif stage == 
"fwd-bwd": 69 | raise ValueError("Aiter does not support fwd-bwd stage.") 70 | else: 71 | raise ValueError(f"Unknown stage: {stage}") 72 | 73 | elif impl_type == AttnType.FA: 74 | if stage == "fwd-only": 75 | return flash_attn_forward 76 | elif stage == "bwd-only": 77 | return flash_attn_backward 78 | elif stage == "fwd-bwd": 79 | assert HAS_FLASH_ATTN, "FlashAttention is not available" 80 | return flash_attn_func 81 | else: 82 | raise ValueError(f"Unknown stage: {stage}") 83 | 84 | elif impl_type == AttnType.FA3: 85 | if stage == "fwd-only": 86 | return flash_attn3_func_forward 87 | elif stage == "bwd-only": 88 | return flash_attn3_func_backward 89 | elif stage == "fwd-bwd": 90 | 91 | def fn( 92 | q, 93 | k, 94 | v, 95 | dropout_p=0.0, 96 | softmax_scale=None, 97 | causal=False, 98 | *args, 99 | **kwargs, 100 | ): 101 | assert ( 102 | HAS_FLASH_ATTN_HOPPER 103 | ), "FlashAttention3 is not available! install it from https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release" 104 | # (q, k, v, softmax_scale=None, causal=False, window_size=(-1, -1), 105 | # deterministic=False, descale_q=None, descale_k=None, descale_v=None, gqa_parallel=False) 106 | from .attention import flash3_attn_func 107 | 108 | assert softmax_scale is not None, f"softmax_scale is required for FA3" 109 | assert ( 110 | dropout_p == 0.0 111 | ), f"dropout_p: {dropout_p} is not supported for FA3" 112 | return flash3_attn_func( 113 | q, k, v, softmax_scale=softmax_scale, causal=causal 114 | ) 115 | 116 | return fn 117 | else: 118 | raise ValueError(f"Unknown stage: {stage}") 119 | 120 | elif impl_type == AttnType.FLASHINFER: 121 | if stage == "fwd-only": 122 | return flashinfer_attn_forward 123 | elif stage == "bwd-only": 124 | return flashinfer_attn_backbward 125 | elif stage == "fwd-bwd": 126 | raise ValueError("FlashInfer does not support fwd-bwd stage.") 127 | else: 128 | raise ValueError(f"Unknown stage: {stage}") 129 | 130 | elif impl_type == AttnType.TORCH: 131 | if stage == "fwd-only": 132 | return pytorch_attn_forward 133 | elif stage == "bwd-only": 134 | return pytorch_attn_backward 135 | elif stage == "fwd-bwd": 136 | from yunchang.ring.ring_pytorch_attn import pytorch_attn_func 137 | 138 | return pytorch_attn_func 139 | else: 140 | raise ValueError(f"Unknown stage: {stage}") 141 | 142 | elif impl_type == AttnType.SAGE_AUTO: 143 | if not HAS_SAGE_ATTENTION: 144 | raise ImportError("SageAttention is not available!") 145 | if stage == "fwd-only": 146 | return partial( 147 | sageattention.sageattn, 148 | tensor_layout="NHD", 149 | return_lse=True, 150 | ) 151 | else: 152 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 153 | 154 | elif impl_type == AttnType.SAGE_FP16: 155 | if not HAS_SAGE_ATTENTION: 156 | raise ImportError("SageAttention is not available!") 157 | 158 | if stage == "fwd-only": 159 | return partial( 160 | sageattention.sageattn_qk_int8_pv_fp16_cuda, 161 | pv_accum_dtype="fp32", 162 | tensor_layout="NHD", 163 | return_lse=True, 164 | ) 165 | else: 166 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 167 | 168 | elif impl_type == AttnType.SAGE_FP16_TRITON: 169 | if not HAS_SAGE_ATTENTION: 170 | raise ImportError("SageAttention is not available!") 171 | 172 | if stage == "fwd-only": 173 | return partial( 174 | sageattention.sageattn_qk_int8_pv_fp16_triton, 175 | tensor_layout="NHD", 176 | return_lse=True, 177 | ) 178 | else: 179 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 180 | 181 | elif impl_type == AttnType.SAGE_FP8: 182 | if not 
HAS_SAGE_ATTENTION: 183 | raise ImportError("SageAttention is not available!") 184 | if stage == "fwd-only": 185 | return partial( 186 | sageattention.sageattn_qk_int8_pv_fp8_cuda, 187 | pv_accum_dtype="fp32+fp32", 188 | tensor_layout="NHD", 189 | return_lse=True, 190 | ) 191 | else: 192 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 193 | 194 | elif impl_type == AttnType.SAGE_FP8_SM90: 195 | if not HAS_SAGE_ATTENTION: 196 | raise ImportError("SageAttention is not available!") 197 | if stage == "fwd-only": 198 | return partial( 199 | sageattention.sageattn_qk_int8_pv_fp8_cuda_sm90, 200 | pv_accum_dtype="fp32+fp32", 201 | tensor_layout="NHD", 202 | return_lse=True, 203 | ) 204 | else: 205 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 206 | 207 | elif impl_type == AttnType.SAGE_FP16_TRITON: 208 | if not HAS_SAGE_ATTENTION: 209 | raise ImportError("SageAttention is not available!") 210 | if stage == "fwd-only": 211 | return partial( 212 | sageattention.sageattn_qk_int8_pv_fp16_triton, 213 | pv_accum_dtype="fp32", 214 | tensor_layout="NHD", 215 | quantization_backend="cuda", 216 | return_lse=True, 217 | ) 218 | else: 219 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 220 | 221 | elif impl_type == AttnType.SPARSE_SAGE: 222 | if not HAS_SPARSE_SAGE_ATTENTION: 223 | raise ImportError("SparseSageAttention is not available!") 224 | if not isinstance(attn_processor, SparseAttentionMeansim): 225 | raise ImportError( 226 | "SparseSageAttention is only available with a SparseAttentionProcessor class passed in" 227 | ) 228 | if stage == "fwd-only": 229 | 230 | def fn(q, k, v, causal=False, softmax_scale=None, *args, **kwargs): 231 | return ( 232 | attn_processor( 233 | q, 234 | k, 235 | v, 236 | is_causal=causal, 237 | scale=softmax_scale, 238 | tensor_layout="NHD", 239 | ), 240 | None, 241 | ) 242 | 243 | return fn 244 | else: 245 | raise ValueError(f"Unknown/Unsupported stage: {stage}") 246 | 247 | elif impl_type == AttnType.NPU: 248 | if stage == "fwd-only": 249 | return npu_fused_attn_forward 250 | elif stage == "bwd-only": 251 | return npu_fused_attn_backward 252 | elif stage == "fwd-bwd": 253 | return npu_fused_attn_forward 254 | else: 255 | raise ValueError(f"Unknown stage: {stage}") 256 | 257 | elif attn_processor is not None: 258 | return attn_processor 259 | else: 260 | raise ValueError(f"Unknown flash attention implementation: {impl_type}") 261 | 262 | 263 | __all__ = [ 264 | "flash_attn_forward", 265 | "flash_attn_backward", 266 | "flash_attn3_func_forward", 267 | "flash_attn3_func_forward", 268 | "flashinfer_attn_forward", 269 | "flashinfer_attn_backbward", 270 | "AttnType", 271 | ] 272 | -------------------------------------------------------------------------------- /yunchang/ring/ring_flash_attn_varlen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from yunchang.globals import HAS_FLASH_ATTN 5 | 6 | if HAS_FLASH_ATTN: 7 | from flash_attn.flash_attn_interface import ( 8 | _flash_attn_varlen_forward, 9 | _flash_attn_varlen_backward, 10 | ) 11 | from .utils import ( 12 | RingComm, 13 | update_out_and_lse, 14 | ) 15 | 16 | try: 17 | from .triton_utils import ( 18 | flatten_varlen_lse, 19 | unflatten_varlen_lse, 20 | ) 21 | except: 22 | from .utils import ( 23 | flatten_varlen_lse, 24 | unflatten_varlen_lse, 25 | ) 26 | 27 | 28 | def ring_flash_attn_varlen_forward( 29 | process_group, 30 | q: torch.Tensor, 31 | k: torch.Tensor, 32 | v: torch.Tensor, 33 | 
cu_seqlens, 34 | max_seqlen, 35 | softmax_scale, 36 | dropout_p=0, 37 | causal=True, 38 | window_size=(-1, -1), 39 | softcap=0.0, 40 | alibi_slopes=None, 41 | deterministic=False, 42 | ): 43 | comm = RingComm(process_group) 44 | 45 | out = None 46 | lse = None 47 | next_k, next_v = None, None 48 | 49 | for step in range(comm.world_size): 50 | if step + 1 != comm.world_size: 51 | next_k: torch.Tensor = comm.send_recv(k) 52 | next_v: torch.Tensor = comm.send_recv(v) 53 | comm.commit() 54 | if not causal or step <= comm.rank: 55 | assert HAS_FLASH_ATTN, "FlashAttention is not available" 56 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( 57 | q, 58 | k, 59 | v, 60 | cu_seqlens, 61 | cu_seqlens, 62 | max_seqlen, 63 | max_seqlen, 64 | dropout_p, 65 | softmax_scale, 66 | causal=causal and step == 0, 67 | window_size=window_size, 68 | softcap=softcap, 69 | alibi_slopes=alibi_slopes, 70 | return_softmax=True and dropout_p > 0, 71 | ) 72 | block_lse = flatten_varlen_lse( 73 | block_lse, 74 | cu_seqlens=cu_seqlens, 75 | ) 76 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 77 | 78 | if step + 1 != comm.world_size: 79 | comm.wait() 80 | k = next_k 81 | v = next_v 82 | 83 | out = out.to(q.dtype) 84 | lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) 85 | return out, lse 86 | 87 | 88 | def ring_flash_attn_varlen_backward( 89 | process_group, 90 | dout, 91 | q, 92 | k, 93 | v, 94 | out, 95 | softmax_lse, 96 | cu_seqlens, 97 | max_seqlen, 98 | softmax_scale, 99 | dropout_p=0, 100 | causal=True, 101 | window_size=(-1, -1), 102 | softcap=0.0, 103 | alibi_slopes=None, 104 | deterministic=False, 105 | ): 106 | kv_comm = RingComm(process_group) 107 | d_kv_comm = RingComm(process_group) 108 | dq, dk, dv = None, None, None 109 | next_dk, next_dv = None, None 110 | 111 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 112 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 113 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 114 | 115 | next_dk, next_dv = None, None 116 | next_k, next_v = None, None 117 | for step in range(kv_comm.world_size): 118 | if step + 1 != kv_comm.world_size: 119 | next_k = kv_comm.send_recv(k) 120 | next_v = kv_comm.send_recv(v) 121 | kv_comm.commit() 122 | if step <= kv_comm.rank or not causal: 123 | bwd_causal = causal and step == 0 124 | assert HAS_FLASH_ATTN, "FlashAttention is not available" 125 | _flash_attn_varlen_backward( 126 | dout, 127 | q, 128 | k, 129 | v, 130 | out, 131 | softmax_lse, 132 | block_dq_buffer, 133 | block_dk_buffer, 134 | block_dv_buffer, 135 | cu_seqlens, 136 | cu_seqlens, 137 | max_seqlen, 138 | max_seqlen, 139 | dropout_p, 140 | softmax_scale, 141 | bwd_causal, 142 | window_size, 143 | softcap, 144 | alibi_slopes, 145 | deterministic, 146 | rng_state=None, 147 | ) 148 | 149 | if dq is None: 150 | dq = block_dq_buffer.to(torch.float32) 151 | dk = block_dk_buffer.to(torch.float32) 152 | dv = block_dv_buffer.to(torch.float32) 153 | else: 154 | dq += block_dq_buffer 155 | d_kv_comm.wait() 156 | dk = block_dk_buffer + next_dk 157 | dv = block_dv_buffer + next_dv 158 | elif step != 0: 159 | d_kv_comm.wait() 160 | dk = next_dk 161 | dv = next_dv 162 | 163 | if step + 1 != kv_comm.world_size: 164 | kv_comm.wait() 165 | k = next_k 166 | v = next_v 167 | 168 | next_dk = d_kv_comm.send_recv(dk) 169 | next_dv = d_kv_comm.send_recv(dv) 170 | d_kv_comm.commit() 171 | 172 | d_kv_comm.wait() 173 | 174 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), 
next_dv.to(q.dtype) 175 | 176 | 177 | class RingFlashAttnVarlenFunc(torch.autograd.Function): 178 | @staticmethod 179 | def forward( 180 | ctx, 181 | q, 182 | k, 183 | v, 184 | cu_seqlens, 185 | max_seqlen, 186 | dropout_p, 187 | softmax_scale, 188 | causal, 189 | window_size, 190 | softcap, 191 | alibi_slopes, 192 | deterministic, 193 | return_softmax, 194 | group, 195 | ): 196 | if softmax_scale is None: 197 | softmax_scale = q.shape[-1] ** (-0.5) 198 | 199 | assert alibi_slopes is None 200 | k = k.contiguous() 201 | v = v.contiguous() 202 | out, softmax_lse = ring_flash_attn_varlen_forward( 203 | group, 204 | q, 205 | k, 206 | v, 207 | cu_seqlens, 208 | max_seqlen, 209 | softmax_scale=softmax_scale, 210 | dropout_p=dropout_p, 211 | causal=causal, 212 | window_size=window_size, 213 | softcap=softcap, 214 | alibi_slopes=alibi_slopes, 215 | deterministic=False, 216 | ) 217 | # this should be out_padded 218 | ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) 219 | ctx.max_seqlen = max_seqlen 220 | ctx.dropout_p = dropout_p 221 | ctx.softmax_scale = softmax_scale 222 | ctx.causal = causal 223 | ctx.window_size = window_size 224 | ctx.softcap = softcap 225 | ctx.alibi_slopes = alibi_slopes 226 | ctx.deterministic = deterministic 227 | ctx.group = group 228 | return out if not return_softmax else (out, softmax_lse, None) 229 | 230 | @staticmethod 231 | def backward(ctx, dout, *args): 232 | q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors 233 | dq, dk, dv = ring_flash_attn_varlen_backward( 234 | ctx.group, 235 | dout, 236 | q, 237 | k, 238 | v, 239 | out, 240 | softmax_lse, 241 | cu_seqlens, 242 | ctx.max_seqlen, 243 | softmax_scale=ctx.softmax_scale, 244 | dropout_p=ctx.dropout_p, 245 | causal=ctx.causal, 246 | window_size=ctx.window_size, 247 | softcap=ctx.softcap, 248 | alibi_slopes=ctx.alibi_slopes, 249 | deterministic=ctx.deterministic, 250 | ) 251 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None 252 | 253 | 254 | def ring_flash_attn_varlen_qkvpacked_func( 255 | qkv, 256 | cu_seqlens, 257 | max_seqlen, 258 | dropout_p=0.0, 259 | softmax_scale=None, 260 | causal=False, 261 | window_size=(-1, -1), # -1 means infinite context window 262 | softcap=0.0, 263 | alibi_slopes=None, 264 | deterministic=False, 265 | return_attn_probs=False, 266 | group=None, 267 | ): 268 | return RingFlashAttnVarlenFunc.apply( 269 | qkv[:, 0], 270 | qkv[:, 1], 271 | qkv[:, 2], 272 | cu_seqlens, 273 | max_seqlen, 274 | dropout_p, 275 | softmax_scale, 276 | causal, 277 | window_size, 278 | softcap, 279 | alibi_slopes, 280 | deterministic, 281 | return_attn_probs, 282 | group, 283 | ) 284 | 285 | 286 | def ring_flash_attn_varlen_kvpacked_func( 287 | q, 288 | kv, 289 | cu_seqlens, 290 | max_seqlen, 291 | dropout_p=0.0, 292 | softmax_scale=None, 293 | causal=False, 294 | window_size=(-1, -1), # -1 means infinite context window 295 | softcap=0.0, 296 | alibi_slopes=None, 297 | deterministic=False, 298 | return_attn_probs=False, 299 | group=None, 300 | ): 301 | return RingFlashAttnVarlenFunc.apply( 302 | q, 303 | kv[:, 0], 304 | kv[:, 1], 305 | cu_seqlens, 306 | max_seqlen, 307 | dropout_p, 308 | softmax_scale, 309 | causal, 310 | window_size, 311 | softcap, 312 | alibi_slopes, 313 | deterministic, 314 | return_attn_probs, 315 | group, 316 | ) 317 | 318 | 319 | def ring_flash_attn_varlen_func( 320 | q, 321 | k, 322 | v, 323 | cu_seqlens, 324 | max_seqlen, 325 | dropout_p=0.0, 326 | softmax_scale=None, 327 | causal=False, 328 | window_size=(-1, -1), # -1 means infinite 
context window 329 | softcap=0.0, 330 | alibi_slopes=None, 331 | deterministic=False, 332 | return_attn_probs=False, 333 | group=None, 334 | ): 335 | return RingFlashAttnVarlenFunc.apply( 336 | q, 337 | k, 338 | v, 339 | cu_seqlens, 340 | max_seqlen, 341 | dropout_p, 342 | softmax_scale, 343 | causal, 344 | window_size, 345 | softcap, 346 | alibi_slopes, 347 | deterministic, 348 | return_attn_probs, 349 | group, 350 | ) 351 | -------------------------------------------------------------------------------- /test/test_hybrid_attn_npu.py: -------------------------------------------------------------------------------- 1 | import os 2 | from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT 3 | import torch 4 | import torch_npu 5 | from torch_npu.contrib import transfer_to_npu 6 | import torch.distributed as dist 7 | 8 | 9 | from yunchang.kernels import AttnType 10 | from test_utils import attention_ref 11 | import argparse 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description="Test hybrid attention with configurable sequence length" 17 | ) 18 | parser.add_argument( 19 | "--seqlen", type=int, default=1024, help="sequence length (default: 1024)" 20 | ) 21 | parser.add_argument( 22 | "--use_bwd", 23 | action="store_true", 24 | help="whether to test backward pass (default: False)", 25 | ) 26 | parser.add_argument( 27 | "--sp_ulysses_degree", 28 | type=int, 29 | default=None, 30 | help="sp_ulysses_degree (default: world_size)", 31 | ) 32 | parser.add_argument( 33 | "--ring_impl_type", 34 | type=str, 35 | default="basic_npu", 36 | choices=["basic_npu"], 37 | help="ring implementation type (default: basic_npu)", 38 | ) 39 | parser.add_argument( 40 | "--causal", 41 | action="store_true", 42 | help="whether to use causal attention (default: False)", 43 | ) 44 | parser.add_argument( 45 | "--attn_impl", 46 | type=str, 47 | default="npu", 48 | choices=[ 49 | "npu", 50 | ], 51 | help="attention implementation type (default: torch)", 52 | ) 53 | parser.add_argument( 54 | "--sparse_sage_l1", 55 | type=float, 56 | default=0.07, 57 | help="l1 for sparse sage attention (default: 0.07)", 58 | ) 59 | parser.add_argument( 60 | "--sparse_sage_pv_l1", 61 | type=float, 62 | default=0.08, 63 | help="pv_l1 for sparse sage attention (default: 0.08)", 64 | ) 65 | parser.add_argument( 66 | "--sparse_sage_tune_mode", 67 | action="store_true", 68 | default=False, 69 | help="enable tune mode for sparse sage attention (default: False)", 70 | ) 71 | parser.add_argument( 72 | "--sparse_sage_tune_path", 73 | type=str, 74 | default="./sparsesage_autotune.pt", 75 | help="path to the sparse sage autotune results (default: ./sparsesage_autotune.pt)", 76 | ) 77 | return parser.parse_args() 78 | 79 | 80 | def log(msg, a, rank0_only=False): 81 | world_size = dist.get_world_size() 82 | rank = dist.get_rank() 83 | if rank0_only: 84 | if rank == 0: 85 | print( 86 | f"[Rank#0] {msg}: " 87 | f"max {a.abs().max().item()}, " 88 | f"mean {a.abs().mean().item()}", 89 | flush=True, 90 | ) 91 | return 92 | 93 | for i in range(world_size): 94 | if i == rank: 95 | if rank == 0: 96 | print(f"{msg}:") 97 | print( 98 | f"[Rank#{rank}] " 99 | f"max {a.abs().max().item()}, " 100 | f"mean {a.abs().mean().item()}", 101 | flush=True, 102 | ) 103 | dist.barrier() 104 | 105 | 106 | # test it with: 107 | # torchrun --nproc_per_node=4 test/test_hybrid_attn.py 108 | if __name__ == "__main__": 109 | args = parse_args() 110 | 111 | torch.random.manual_seed(0) 112 | 113 | 
dist.init_process_group("hccl") 114 | 115 | rank = dist.get_rank() 116 | world_size = dist.get_world_size() 117 | 118 | # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon 119 | dtype = torch.bfloat16 120 | device = torch.device(f"npu:{rank}") 121 | 122 | batch_size = 1 123 | seqlen = args.seqlen 124 | nheads = 32 125 | d = 2048 // 32 126 | dropout_p = 0 127 | causal = args.causal 128 | deterministic = False 129 | 130 | use_bwd = args.use_bwd 131 | 132 | assert seqlen % world_size == 0 133 | assert d % 8 == 0 134 | 135 | ring_impl_type = args.ring_impl_type 136 | 137 | # Prepare inputs 138 | q = torch.randn( 139 | batch_size, 140 | seqlen, 141 | nheads, 142 | d, 143 | device=device, 144 | dtype=dtype, 145 | requires_grad=True if use_bwd else False, 146 | ) 147 | k = torch.randn( 148 | batch_size, 149 | seqlen, 150 | nheads, 151 | d, 152 | device=device, 153 | dtype=dtype, 154 | requires_grad=True if use_bwd else False, 155 | ) 156 | v = torch.randn( 157 | batch_size, 158 | seqlen, 159 | nheads, 160 | d, 161 | device=device, 162 | dtype=dtype, 163 | requires_grad=True if use_bwd else False, 164 | ) 165 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 166 | 167 | dist.broadcast(q, src=0) 168 | dist.broadcast(k, src=0) 169 | dist.broadcast(v, src=0) 170 | dist.broadcast(dout, src=0) 171 | 172 | # prepare process group for hybrid sequence parallelism 173 | use_ring_low_dim = True 174 | 175 | sp_ulysses_degree = ( 176 | args.sp_ulysses_degree if args.sp_ulysses_degree is not None else world_size 177 | ) 178 | sp_ring_degree = world_size // sp_ulysses_degree 179 | 180 | print( 181 | f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}" 182 | ) 183 | 184 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 185 | 186 | # Use EXTRACT_FUNC_DICT to shard the tensors 187 | local_q = ( 188 | EXTRACT_FUNC_DICT[ring_impl_type]( 189 | q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 190 | ) 191 | .detach() 192 | .clone() 193 | ) 194 | 195 | local_k = ( 196 | EXTRACT_FUNC_DICT[ring_impl_type]( 197 | k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 198 | ) 199 | .detach() 200 | .clone() 201 | ) 202 | 203 | local_v = ( 204 | EXTRACT_FUNC_DICT[ring_impl_type]( 205 | v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 206 | ) 207 | .detach() 208 | .clone() 209 | ) 210 | 211 | if use_bwd: 212 | local_q.requires_grad = True 213 | local_k.requires_grad = True 214 | local_v.requires_grad = True 215 | 216 | # Map argument to AttnType enum 217 | attn_impl_map = { 218 | "npu": AttnType.NPU, 219 | } 220 | 221 | usp_attn = LongContextAttention( 222 | ring_impl_type=ring_impl_type, 223 | attn_type=attn_impl_map[args.attn_impl], 224 | ) 225 | 226 | if rank == 0: 227 | print("#" * 30) 228 | print("# ds-ulysses forward:") 229 | print("#" * 30) 230 | 231 | # common test parameters 232 | window_size = (-1, -1) 233 | alibi_slopes, attn_bias = None, None 234 | dropout_mask = None 235 | 236 | print(f"before usp attn forward: {local_q.shape} {local_k.shape} {local_v.shape}") 237 | 238 | # usp attn forward 239 | local_out = usp_attn( 240 | local_q, 241 | local_k, 242 | local_v 243 | ) 244 | 245 | # extract local dout 246 | local_dout = ( 247 | EXTRACT_FUNC_DICT[ring_impl_type]( 248 | dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 249 | ) 250 | .detach() 251 | .clone() 252 | ) 253 | 254 | 
max_memory = torch.cuda.max_memory_allocated(device) / ( 255 | 1024 * 1024 256 | ) # Convert to MB 257 | print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB") 258 | torch.cuda.reset_peak_memory_stats(device) # Reset stats 259 | 260 | if rank == 0: 261 | print("#" * 30) 262 | print("# ds-ulysses backward:") 263 | print("#" * 30) 264 | 265 | # usp attn backward 266 | if use_bwd: 267 | local_out.backward(local_dout) 268 | 269 | dist.barrier() 270 | 271 | if rank == 0: 272 | print("#" * 30) 273 | print("# local forward:") 274 | print("#" * 30) 275 | # reference, a local flash attn 276 | softmax_scale = q.shape[-1] ** -0.5 277 | out_ref = torch_npu.npu_fusion_attention_v2(q, k, v, 278 | head_num = q.shape[-2], 279 | input_layout = "BSND", 280 | scale = softmax_scale, 281 | pre_tokens=65535, 282 | next_tokens=65535)[0] 283 | if rank == 0: 284 | print("#" * 30) 285 | print("# local forward:") 286 | print("#" * 30) 287 | 288 | if use_bwd: 289 | out_ref.backward(dout) 290 | 291 | dist.barrier() 292 | 293 | # check correctness 294 | # When checking correctness, use EXTRACT_FUNC_DICT for reference outputs 295 | local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 296 | out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 297 | ) 298 | 299 | log("local (rank) out", local_out, rank0_only=True) 300 | log("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out) 301 | 302 | # log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref) 303 | 304 | torch.testing.assert_close(local_out, local_out_ref, atol=1e-1, rtol=0) 305 | # torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0) 306 | 307 | if use_bwd: 308 | local_dq_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 309 | q.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 310 | ) 311 | log("load_dq", local_q.grad) 312 | log("dq diff", local_dq_ref - local_q.grad) 313 | 314 | local_dk_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 315 | k.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 316 | ) 317 | log("load_dk", local_k.grad) 318 | log("dk diff", local_dk_ref - local_k.grad) 319 | 320 | local_dv_ref = EXTRACT_FUNC_DICT[ring_impl_type]( 321 | v.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree 322 | ) 323 | log("load_dv", local_v.grad) 324 | log("dv diff", local_dv_ref - local_v.grad) 325 | 326 | if dist.is_initialized(): 327 | dist.destroy_process_group() 328 | -------------------------------------------------------------------------------- /yunchang/comm/all_to_all.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation and Jiarui Fang 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | 6 | import torch 7 | 8 | from typing import Any, Tuple 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | import torch.distributed as dist 13 | 14 | 15 | def all_to_all_4D( 16 | input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False 17 | ) -> torch.tensor: 18 | """ 19 | all-to-all for QKV 20 | 21 | Args: 22 | input (torch.tensor): a tensor sharded along dim scatter dim 23 | scatter_idx (int): default 1 24 | gather_idx (int): default 2 25 | group : torch process group 26 | use_sync (bool): whether to synchronize after all-to-all 27 | 28 | Returns: 29 | torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) 30 | """ 31 | assert ( 32 | input.dim() == 4 33 | ), f"input must 
be 4D tensor, got {input.dim()} and shape {input.shape}" 34 | 35 | seq_world_size = dist.get_world_size(group) 36 | 37 | if scatter_idx == 2 and gather_idx == 1: 38 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) 39 | bs, shard_seqlen, hc, hs = input.shape 40 | seqlen = shard_seqlen * seq_world_size 41 | shard_hc = hc // seq_world_size 42 | 43 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 44 | # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) 45 | input_t = ( 46 | input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) 47 | .transpose(0, 2) 48 | .contiguous() 49 | ) 50 | 51 | output = torch.empty_like(input_t) 52 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 53 | # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head 54 | 55 | if seq_world_size > 1: 56 | dist.all_to_all_single(output, input_t, group=group) 57 | if use_sync: 58 | torch.cuda.synchronize() 59 | else: 60 | output = input_t 61 | # if scattering the seq-dim, transpose the heads back to the original dimension 62 | output = output.reshape(seqlen, bs, shard_hc, hs) 63 | 64 | # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) 65 | output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) 66 | 67 | return output 68 | 69 | elif scatter_idx == 1 and gather_idx == 2: 70 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) 71 | bs, seqlen, shard_hc, hs = input.shape 72 | hc = shard_hc * seq_world_size 73 | shard_seqlen = seqlen // seq_world_size 74 | seq_world_size = dist.get_world_size(group) 75 | 76 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
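# Illustrative shapes only (the concrete numbers are assumptions, not taken from this repo):
# with P = 2 ranks, bs = 1, seqlen = 8, hc = 8 heads, hs = 64, each rank enters this branch with a
# head-sharded (1, 8, 4, 64) tensor and leaves all_to_all_4D with a sequence-sharded
# (1, 4, 8, 64) tensor, i.e. all heads but only seqlen/P tokens per rank.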
77 | # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) 78 | input_t = ( 79 | input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) 80 | .transpose(0, 3) 81 | .transpose(0, 1) 82 | .contiguous() 83 | .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) 84 | ) 85 | 86 | output = torch.empty_like(input_t) 87 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 88 | # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head 89 | if seq_world_size > 1: 90 | dist.all_to_all_single(output, input_t, group=group) 91 | if use_sync: 92 | torch.cuda.synchronize() 93 | else: 94 | output = input_t 95 | 96 | # if scattering the seq-dim, transpose the heads back to the original dimension 97 | output = output.reshape(hc, shard_seqlen, bs, hs) 98 | 99 | # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) 100 | output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) 101 | 102 | return output 103 | else: 104 | raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") 105 | 106 | 107 | class SeqAllToAll4D(torch.autograd.Function): 108 | @staticmethod 109 | def forward( 110 | ctx: Any, 111 | group: dist.ProcessGroup, 112 | input: Tensor, 113 | scatter_idx: int, 114 | gather_idx: int, 115 | use_sync: bool = False, 116 | ) -> Tensor: 117 | 118 | ctx.group = group 119 | ctx.scatter_idx = scatter_idx 120 | ctx.gather_idx = gather_idx 121 | ctx.use_sync = use_sync 122 | return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) 123 | 124 | @staticmethod 125 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: 126 | return ( 127 | None, 128 | SeqAllToAll4D.apply( 129 | ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync 130 | ), 131 | None, 132 | None, 133 | None, 134 | ) 135 | 136 | 137 | def all_to_all_5D( 138 | input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False 139 | ) -> torch.tensor: 140 | """ 141 | all-to-all for QKV 142 | forward (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs) 143 | 144 | Args: 145 | input (torch.tensor): a tensor sharded along dim scatter dim 146 | scatter_idx (int): default 1 147 | gather_idx (int): default 2 148 | group : torch process group 149 | use_sync: whether to synchronize after all-to-all 150 | 151 | Returns: 152 | torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs) 153 | """ 154 | assert ( 155 | input.dim() == 5 156 | ), f"input must be 5D tensor, got {input.dim()} and shape {input.shape}" 157 | 158 | seq_world_size = dist.get_world_size(group) 159 | 160 | if scatter_idx == 3 and gather_idx == 1: 161 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, 3, hc, hs) output: (bs, seqlen, 3, hc/P, hs) 162 | bs, shard_seqlen, t_cnt, hc, hs = input.shape 163 | 164 | assert t_cnt == 3 165 | seqlen = shard_seqlen * seq_world_size 166 | shard_hc = hc // seq_world_size 167 | 168 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
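# Illustrative shapes only (numbers are assumptions, not taken from this repo): with P = 2 ranks,
# bs = 1, seqlen = 8, hc = 8 heads, hs = 64, each rank enters this branch with a sequence-sharded
# packed-QKV tensor of shape (1, 4, 3, 8, 64) and leaves with a head-sharded (1, 8, 3, 4, 64)
# tensor, i.e. the full sequence but only hc/P heads per rank.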
169 | # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> (P, seq_len/P, 3, bs, hc/P, hs) 170 | input_t = ( 171 | input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs) 172 | .transpose(0, 3) 173 | .contiguous() 174 | ) 175 | 176 | output = torch.empty_like(input_t) 177 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 178 | # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head 179 | if seq_world_size > 1: 180 | dist.all_to_all_single(output, input_t, group=group) 181 | if use_sync: 182 | torch.cuda.synchronize() 183 | else: 184 | output = input_t 185 | 186 | # if scattering the seq-dim, transpose the heads back to the original dimension 187 | output = output.reshape(seqlen, 3, bs, shard_hc, hs) 188 | 189 | # (seq_len, 3, bs, hc/P, hs) -trans-> (bs, seq_len, 3, hc/P, hs) 190 | output = output.transpose(0, 2).transpose(1, 2).contiguous() 191 | 192 | return output.reshape(bs, seqlen, 3, shard_hc, hs).contiguous() 193 | elif scatter_idx == 1 and gather_idx == 3: 194 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) 195 | bs, seqlen, _, shard_hc, hs = input.shape 196 | hc = shard_hc * seq_world_size 197 | shard_seqlen = seqlen // seq_world_size 198 | seq_world_size = dist.get_world_size(group) 199 | 200 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 201 | # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs) 202 | input_t = ( 203 | input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs) 204 | .transpose(0, 4) 205 | .transpose(0, 1) 206 | .contiguous() 207 | .reshape(seq_world_size, shard_hc, shard_seqlen, 3, bs, hs) 208 | ) 209 | 210 | output = torch.empty_like(input_t) 211 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 212 | # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head 213 | if seq_world_size > 1: 214 | dist.all_to_all_single(output, input_t, group=group) 215 | if use_sync: 216 | torch.cuda.synchronize() 217 | else: 218 | output = input_t 219 | 220 | # if scattering the seq-dim, transpose the heads back to the original dimension 221 | output = output.reshape(hc, shard_seqlen, 3, bs, hs) 222 | 223 | # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) 224 | output = output.transpose(0, 3).contiguous() 225 | 226 | return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous() 227 | else: 228 | raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3") 229 | 230 | 231 | class SeqAllToAll5D(torch.autograd.Function): 232 | @staticmethod 233 | def forward( 234 | ctx: Any, 235 | group: dist.ProcessGroup, 236 | input: Tensor, 237 | scatter_idx: int = 3, 238 | gather_idx: int = 1, 239 | use_sync: bool = False, 240 | ) -> Tensor: 241 | 242 | ctx.group = group 243 | ctx.scatter_idx = scatter_idx 244 | ctx.gather_idx = gather_idx 245 | ctx.use_sync = use_sync 246 | 247 | return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) 248 | 249 | @staticmethod 250 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: 251 | return ( 252 | None, 253 | SeqAllToAll5D.apply( 254 | ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync 255 | ), 256 | 
None, 257 | None, 258 | None, 259 | ) 260 | -------------------------------------------------------------------------------- /yunchang/ring/zigzag_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .utils import RingComm, update_out_and_lse 3 | from yunchang.kernels import AttnType, select_flash_attn_impl 4 | 5 | def zigzag_ring_flash_attn_forward( 6 | process_group, 7 | q: torch.Tensor, 8 | k: torch.Tensor, 9 | v: torch.Tensor, 10 | softmax_scale, 11 | dropout_p=0, 12 | causal=True, 13 | window_size=(-1, -1), 14 | softcap=0.0, 15 | alibi_slopes=None, 16 | deterministic=False, 17 | attn_type: AttnType = AttnType.FA, 18 | ): 19 | assert causal == True, "zigzag ring is meaningless for causal=False" 20 | comm = RingComm(process_group) 21 | 22 | block_seq_len = q.shape[1] // 2 23 | q1 = q[:, block_seq_len:] 24 | 25 | out = None 26 | lse = None 27 | next_k, next_v = None, None 28 | 29 | def forward(q, k, v, causal): 30 | fn = select_flash_attn_impl(attn_type, stage="fwd-only") 31 | block_out, block_lse = fn( 32 | q, 33 | k, 34 | v, 35 | dropout_p, 36 | softmax_scale, 37 | causal=causal, 38 | window_size=window_size, 39 | softcap=softcap, 40 | alibi_slopes=alibi_slopes, 41 | return_softmax=True and dropout_p > 0, 42 | ) 43 | return block_out, block_lse 44 | 45 | for step in range(comm.world_size): 46 | if step + 1 != comm.world_size: 47 | next_k: torch.Tensor = comm.send_recv(k) 48 | next_v: torch.Tensor = comm.send_recv(v) 49 | comm.commit() 50 | 51 | if step == 0: 52 | block_out, block_lse = forward(q, k, v, causal=True) 53 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 54 | elif step <= comm.rank: 55 | k0 = k[:, :block_seq_len] 56 | v0 = v[:, :block_seq_len] 57 | block_out, block_lse = forward(q, k0, v0, causal=False) 58 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 59 | else: 60 | block_out, block_lse = forward(q1, k, v, causal=False) 61 | out, lse = update_out_and_lse( 62 | out, 63 | lse, 64 | block_out, 65 | block_lse, 66 | slice_=(slice(None), slice(block_seq_len, None)), 67 | ) 68 | 69 | if step + 1 != comm.world_size: 70 | comm.wait() 71 | k = next_k 72 | v = next_v 73 | 74 | out = out.to(q.dtype) 75 | lse = lse.squeeze(dim=-1).transpose(1, 2) 76 | return out, lse 77 | 78 | 79 | def zigzag_ring_flash_attn_backward( 80 | process_group, 81 | dout, 82 | q, 83 | k, 84 | v, 85 | out, 86 | softmax_lse, 87 | softmax_scale, 88 | dropout_p=0, 89 | causal=True, 90 | window_size=(-1, -1), 91 | softcap=0.0, 92 | alibi_slopes=None, 93 | deterministic=False, 94 | attn_type: AttnType = AttnType.FA, 95 | ): 96 | assert causal == True, "zigzag ring is meaningless for causal=False" 97 | kv_comm = RingComm(process_group) 98 | d_kv_comm = RingComm(process_group) 99 | dq, dk, dv = None, None, None 100 | next_dk, next_dv = None, None 101 | next_k, next_v = None, None 102 | dk_comm_buffer, dv_comm_buffer = None, None 103 | 104 | dout1 = dout.chunk(2, dim=1)[1] 105 | q1 = q.chunk(2, dim=1)[1] 106 | out1 = out.chunk(2, dim=1)[1] 107 | softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() 108 | block_seq_len = q.shape[1] // 2 109 | 110 | # repeatly allocating buffer may be slow... 
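# The three gradient buffers below are therefore allocated once at full local size and reused on
# every ring step; half-block steps simply write into their leading slices
# (dq_buffer[:, :seqlen_q], dk_buffer[:, :seqlen_kv], dv_buffer[:, :seqlen_kv]).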
111 | dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 112 | dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 113 | dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 114 | 115 | def backward(dout, q, k, v, out, softmax_lse, causal): 116 | seqlen_q = q.shape[1] 117 | seqlen_kv = k.shape[1] 118 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 119 | fn( 120 | dout, 121 | q, 122 | k, 123 | v, 124 | out, 125 | softmax_lse, 126 | dq_buffer[:, :seqlen_q], 127 | dk_buffer[:, :seqlen_kv], 128 | dv_buffer[:, :seqlen_kv], 129 | dropout_p, 130 | softmax_scale, 131 | causal, 132 | window_size, 133 | softcap, 134 | alibi_slopes, 135 | deterministic, 136 | rng_state=None, 137 | ) 138 | 139 | for step in range(kv_comm.world_size): 140 | if step + 1 != kv_comm.world_size: 141 | next_k = kv_comm.send_recv(k) 142 | next_v = kv_comm.send_recv(v) 143 | kv_comm.commit() 144 | 145 | if step == 0: 146 | backward(dout, q, k, v, out, softmax_lse, causal=True) 147 | dq = dq_buffer.to(torch.float32) 148 | dk = dk_buffer.to(torch.float32) 149 | dv = dv_buffer.to(torch.float32) 150 | else: 151 | if step <= kv_comm.rank: 152 | k0 = k[:, :block_seq_len] 153 | v0 = v[:, :block_seq_len] 154 | backward(dout, q, k0, v0, out, softmax_lse, causal=False) 155 | dq += dq_buffer 156 | else: 157 | backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) 158 | # always use the first half in dq_buffer. 159 | dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] 160 | 161 | d_kv_comm.wait() 162 | dk_comm_buffer, dv_comm_buffer = dk, dv 163 | dk, dv = next_dk, next_dv 164 | 165 | if step <= kv_comm.rank: 166 | dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] 167 | dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] 168 | else: 169 | dk += dk_buffer 170 | dv += dv_buffer 171 | 172 | if step + 1 != kv_comm.world_size: 173 | kv_comm.wait() 174 | k = next_k 175 | v = next_v 176 | 177 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) 178 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) 179 | d_kv_comm.commit() 180 | 181 | d_kv_comm.wait() 182 | 183 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 184 | 185 | 186 | class ZigZagRingFlashAttnFunc(torch.autograd.Function): 187 | @staticmethod 188 | def forward( 189 | ctx, 190 | q, 191 | k, 192 | v, 193 | dropout_p, 194 | softmax_scale, 195 | causal, 196 | window_size, 197 | softcap, 198 | alibi_slopes, 199 | deterministic, 200 | return_softmax, 201 | group, 202 | attn_type, 203 | ): 204 | if softmax_scale is None: 205 | softmax_scale = q.shape[-1] ** (-0.5) 206 | 207 | assert alibi_slopes is None 208 | k = k.contiguous() 209 | v = v.contiguous() 210 | out, softmax_lse = zigzag_ring_flash_attn_forward( 211 | group, 212 | q, 213 | k, 214 | v, 215 | softmax_scale=softmax_scale, 216 | dropout_p=dropout_p, 217 | causal=causal, 218 | window_size=window_size, 219 | softcap=softcap, 220 | alibi_slopes=alibi_slopes, 221 | deterministic=False, 222 | attn_type=attn_type, 223 | ) 224 | # this should be out_padded 225 | ctx.save_for_backward(q, k, v, out, softmax_lse) 226 | ctx.dropout_p = dropout_p 227 | ctx.softmax_scale = softmax_scale 228 | ctx.causal = causal 229 | ctx.window_size = window_size 230 | ctx.softcap = softcap 231 | ctx.alibi_slopes = alibi_slopes 232 | ctx.deterministic = deterministic 233 | ctx.group = group 234 | ctx.attn_type = attn_type 235 | return out if not return_softmax else (out, softmax_lse, None) 236 | 237 | @staticmethod 238 | def backward(ctx, dout, *args): 239 | q, k, v, out, softmax_lse 
= ctx.saved_tensors 240 | dq, dk, dv = zigzag_ring_flash_attn_backward( 241 | ctx.group, 242 | dout, 243 | q, 244 | k, 245 | v, 246 | out, 247 | softmax_lse, 248 | softmax_scale=ctx.softmax_scale, 249 | dropout_p=ctx.dropout_p, 250 | causal=ctx.causal, 251 | window_size=ctx.window_size, 252 | softcap=ctx.softcap, 253 | alibi_slopes=ctx.alibi_slopes, 254 | deterministic=ctx.deterministic, 255 | attn_type=ctx.attn_type, 256 | ) 257 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 258 | 259 | 260 | def zigzag_ring_flash_attn_qkvpacked_func( 261 | qkv, 262 | dropout_p=0.0, 263 | softmax_scale=None, 264 | causal=False, 265 | window_size=(-1, -1), 266 | softcap=0.0, 267 | alibi_slopes=None, 268 | deterministic=False, 269 | return_attn_probs=False, 270 | group=None, 271 | attn_type: AttnType = AttnType.FA, 272 | ): 273 | return ZigZagRingFlashAttnFunc.apply( 274 | qkv[:, :, 0], 275 | qkv[:, :, 1], 276 | qkv[:, :, 2], 277 | dropout_p, 278 | softmax_scale, 279 | causal, 280 | window_size, 281 | softcap, 282 | alibi_slopes, 283 | deterministic, 284 | return_attn_probs, 285 | group, 286 | attn_type, 287 | ) 288 | 289 | 290 | def zigzag_ring_flash_attn_kvpacked_func( 291 | q, 292 | kv, 293 | dropout_p=0.0, 294 | softmax_scale=None, 295 | causal=False, 296 | window_size=(-1, -1), 297 | softcap=0.0, 298 | alibi_slopes=None, 299 | deterministic=False, 300 | return_attn_probs=False, 301 | group=None, 302 | attn_type: AttnType = AttnType.FA, 303 | ): 304 | return ZigZagRingFlashAttnFunc.apply( 305 | q, 306 | kv[:, :, 0], 307 | kv[:, :, 1], 308 | dropout_p, 309 | softmax_scale, 310 | causal, 311 | window_size, 312 | softcap, 313 | alibi_slopes, 314 | deterministic, 315 | return_attn_probs, 316 | group, 317 | attn_type, 318 | ) 319 | 320 | 321 | def zigzag_ring_flash_attn_func( 322 | q, 323 | k, 324 | v, 325 | dropout_p=0.0, 326 | softmax_scale=None, 327 | causal=False, 328 | window_size=(-1, -1), 329 | softcap=0.0, 330 | alibi_slopes=None, 331 | deterministic=False, 332 | return_attn_probs=False, 333 | group=None, 334 | attn_type: AttnType = AttnType.FA, 335 | attn_processor=None, 336 | ): 337 | return ZigZagRingFlashAttnFunc.apply( 338 | q, 339 | k, 340 | v, 341 | dropout_p, 342 | softmax_scale, 343 | causal, 344 | window_size, 345 | softcap, 346 | alibi_slopes, 347 | deterministic, 348 | return_attn_probs, 349 | group, 350 | attn_type, 351 | ) 352 | -------------------------------------------------------------------------------- /yunchang/ring/stripe_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from yunchang.kernels import select_flash_attn_impl, AttnType 3 | from .utils import RingComm, update_out_and_lse 4 | 5 | 6 | def stripe_flash_attn_forward( 7 | process_group, 8 | q: torch.Tensor, 9 | k: torch.Tensor, 10 | v: torch.Tensor, 11 | softmax_scale, 12 | dropout_p=0, 13 | causal=True, 14 | window_size=(-1, -1), 15 | softcap=0.0, 16 | alibi_slopes=None, 17 | deterministic=False, 18 | attn_type: AttnType = AttnType.FA, 19 | ): 20 | assert ( 21 | causal 22 | ), "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" 23 | comm = RingComm(process_group) 24 | 25 | out = None 26 | lse = None 27 | 28 | next_k, next_v = None, None 29 | 30 | for step in range(comm.world_size): 31 | if step + 1 != comm.world_size: 32 | next_k: torch.Tensor = comm.send_recv(k) 33 | next_v: torch.Tensor = comm.send_recv(v) 34 | comm.commit() 35 | 36 | if step <= comm.rank: 37 | fn = 
select_flash_attn_impl(attn_type, stage="fwd-only") 38 | block_out, block_lse = fn( 39 | q, 40 | k, 41 | v, 42 | dropout_p, 43 | softmax_scale, 44 | causal=causal, 45 | window_size=window_size, 46 | softcap=softcap, 47 | alibi_slopes=alibi_slopes, 48 | return_softmax=True and dropout_p > 0, 49 | ) 50 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 51 | else: 52 | fn = select_flash_attn_impl(attn_type, stage="fwd-only") 53 | block_out, block_lse = fn( 54 | q[:, 1:], 55 | k[:, :-1], 56 | v[:, :-1], 57 | dropout_p, 58 | softmax_scale, 59 | causal=causal, 60 | window_size=window_size, 61 | softcap=softcap, 62 | alibi_slopes=alibi_slopes, 63 | return_softmax=True and dropout_p > 0, 64 | ) 65 | out, lse = update_out_and_lse( 66 | out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None)) 67 | ) 68 | 69 | if step + 1 != comm.world_size: 70 | comm.wait() 71 | k = next_k 72 | v = next_v 73 | 74 | out = out.to(q.dtype) 75 | lse = lse.squeeze(dim=-1).transpose(1, 2) 76 | return out, lse 77 | 78 | 79 | def stripe_flash_attn_backward( 80 | process_group, 81 | dout, 82 | q, 83 | k, 84 | v, 85 | out, 86 | softmax_lse, 87 | softmax_scale, 88 | dropout_p=0, 89 | causal=True, 90 | window_size=(-1, -1), 91 | softcap=0.0, 92 | alibi_slopes=None, 93 | deterministic=False, 94 | attn_type: AttnType = AttnType.FA, 95 | ): 96 | assert ( 97 | causal 98 | ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" 99 | kv_comm = RingComm(process_group) 100 | d_kv_comm = RingComm(process_group) 101 | dq, dk, dv = None, None, None 102 | next_dk, next_dv = None, None 103 | next_k, next_v = None, None 104 | dk_comm_buffer, dv_comm_buffer = None, None 105 | 106 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 107 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 108 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 109 | for step in range(kv_comm.world_size): 110 | if step + 1 != kv_comm.world_size: 111 | next_k = kv_comm.send_recv(k) 112 | next_v = kv_comm.send_recv(v) 113 | kv_comm.commit() 114 | 115 | shift_causal = step > kv_comm.rank 116 | softmax_lse_1 = None 117 | if not shift_causal: 118 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 119 | fn( 120 | dout, 121 | q, 122 | k, 123 | v, 124 | out, 125 | softmax_lse, 126 | block_dq_buffer, 127 | block_dk_buffer, 128 | block_dv_buffer, 129 | dropout_p, 130 | softmax_scale, 131 | causal, 132 | window_size, 133 | softcap, 134 | alibi_slopes, 135 | deterministic, 136 | rng_state=None, 137 | ) 138 | else: 139 | if softmax_lse_1 is None: 140 | # lazy init, since the last rank does not need softmax_lse_1 141 | softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() 142 | fn = select_flash_attn_impl(attn_type, stage="bwd-only") 143 | fn( 144 | dout[:, 1:], 145 | q[:, 1:], 146 | k[:, :-1], 147 | v[:, :-1], 148 | out[:, 1:], 149 | softmax_lse_1, 150 | block_dq_buffer[:, 1:], 151 | block_dk_buffer[:, :-1], 152 | block_dv_buffer[:, :-1], 153 | dropout_p, 154 | softmax_scale, 155 | causal, 156 | window_size, 157 | softcap, 158 | alibi_slopes, 159 | deterministic, 160 | rng_state=None, 161 | ) 162 | 163 | if dq is None: 164 | dq = block_dq_buffer.to(torch.float32) 165 | dk = block_dk_buffer.to(torch.float32) 166 | dv = block_dv_buffer.to(torch.float32) 167 | else: 168 | if not shift_causal: 169 | dq += block_dq_buffer 170 | else: 171 | dq[:, 1:] += block_dq_buffer[:, 1:] 172 | d_kv_comm.wait() 173 | dk_comm_buffer, dv_comm_buffer = dk, dv 174 | dk 
= next_dk 175 | dv = next_dv 176 | 177 | if not shift_causal: 178 | dk = block_dk_buffer + dk 179 | dv = block_dv_buffer + dv 180 | else: 181 | dk[:, :-1] += block_dk_buffer[:, :-1] 182 | dv[:, :-1] += block_dv_buffer[:, :-1] 183 | 184 | if step + 1 != kv_comm.world_size: 185 | kv_comm.wait() 186 | k = next_k 187 | v = next_v 188 | 189 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) 190 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) 191 | d_kv_comm.commit() 192 | 193 | d_kv_comm.wait() 194 | 195 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 196 | 197 | 198 | class StripeFlashAttnFunc(torch.autograd.Function): 199 | @staticmethod 200 | def forward( 201 | ctx, 202 | q, 203 | k, 204 | v, 205 | dropout_p, 206 | softmax_scale, 207 | causal, 208 | window_size, 209 | softcap, 210 | alibi_slopes, 211 | deterministic, 212 | return_softmax, 213 | group, 214 | attn_type: AttnType = AttnType.FA, 215 | ): 216 | if softmax_scale is None: 217 | softmax_scale = q.shape[-1] ** (-0.5) 218 | 219 | assert alibi_slopes is None 220 | k = k.contiguous() 221 | v = v.contiguous() 222 | out, softmax_lse = stripe_flash_attn_forward( 223 | group, 224 | q, 225 | k, 226 | v, 227 | softmax_scale=softmax_scale, 228 | dropout_p=dropout_p, 229 | causal=causal, 230 | window_size=window_size, 231 | softcap=softcap, 232 | alibi_slopes=alibi_slopes, 233 | deterministic=False, 234 | attn_type=attn_type, 235 | ) 236 | # this should be out_padded 237 | ctx.save_for_backward(q, k, v, out, softmax_lse) 238 | ctx.dropout_p = dropout_p 239 | ctx.softmax_scale = softmax_scale 240 | ctx.causal = causal 241 | ctx.window_size = window_size 242 | ctx.softcap = softcap 243 | ctx.alibi_slopes = alibi_slopes 244 | ctx.deterministic = deterministic 245 | ctx.group = group 246 | ctx.attn_type = attn_type 247 | return out if not return_softmax else (out, softmax_lse, None) 248 | 249 | @staticmethod 250 | def backward(ctx, dout, *args): 251 | q, k, v, out, softmax_lse = ctx.saved_tensors 252 | dq, dk, dv = stripe_flash_attn_backward( 253 | ctx.group, 254 | dout, 255 | q, 256 | k, 257 | v, 258 | out, 259 | softmax_lse, 260 | softmax_scale=ctx.softmax_scale, 261 | dropout_p=ctx.dropout_p, 262 | causal=ctx.causal, 263 | window_size=ctx.window_size, 264 | softcap=ctx.softcap, 265 | alibi_slopes=ctx.alibi_slopes, 266 | deterministic=ctx.deterministic, 267 | attn_type=ctx.attn_type, 268 | ) 269 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 270 | 271 | 272 | def stripe_flash_attn_qkvpacked_func( 273 | qkv, 274 | dropout_p=0.0, 275 | softmax_scale=None, 276 | causal=False, 277 | window_size=(-1, -1), # -1 means infinite context window 278 | softcap=0.0, 279 | alibi_slopes=None, 280 | deterministic=False, 281 | return_attn_probs=False, 282 | group=None, 283 | attn_type: AttnType = AttnType.FA, 284 | ): 285 | return StripeFlashAttnFunc.apply( 286 | qkv[:, :, 0], 287 | qkv[:, :, 1], 288 | qkv[:, :, 2], 289 | dropout_p, 290 | softmax_scale, 291 | causal, 292 | window_size, 293 | softcap, 294 | alibi_slopes, 295 | deterministic, 296 | return_attn_probs, 297 | group, 298 | attn_type, 299 | ) 300 | 301 | 302 | def stripe_flash_attn_kvpacked_func( 303 | q, 304 | kv, 305 | dropout_p=0.0, 306 | softmax_scale=None, 307 | causal=False, 308 | window_size=(-1, -1), # -1 means infinite context window 309 | softcap=0.0, 310 | alibi_slopes=None, 311 | deterministic=False, 312 | return_attn_probs=False, 313 | group=None, 314 | attn_type: AttnType = AttnType.FA, 315 | ): 316 | return StripeFlashAttnFunc.apply( 
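# kv is expected packed as (bs, seqlen, 2, nheads, head_dim); index 0 is K, index 1 is V.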
317 | q, 318 | kv[:, :, 0], 319 | kv[:, :, 1], 320 | dropout_p, 321 | softmax_scale, 322 | causal, 323 | window_size, 324 | softcap, 325 | alibi_slopes, 326 | deterministic, 327 | return_attn_probs, 328 | group, 329 | attn_type, 330 | ) 331 | 332 | 333 | def stripe_flash_attn_func( 334 | q, 335 | k, 336 | v, 337 | dropout_p=0.0, 338 | softmax_scale=None, 339 | causal=False, 340 | window_size=(-1, -1), # -1 means infinite context window 341 | softcap=0.0, 342 | alibi_slopes=None, 343 | deterministic=False, 344 | return_attn_probs=False, 345 | group=None, 346 | attn_type: AttnType = AttnType.FA, 347 | attn_processor=None, 348 | ): 349 | return StripeFlashAttnFunc.apply( 350 | q, 351 | k, 352 | v, 353 | dropout_p, 354 | softmax_scale, 355 | causal, 356 | window_size, 357 | softcap, 358 | alibi_slopes, 359 | deterministic, 360 | return_attn_probs, 361 | group, 362 | attn_type, 363 | ) 364 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [2024] [feifeibear (Jiarui Fang)] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YunChang: A Unified Sequence Parallel (USP) Attention for Long Context LLM Model Training and Inference. 2 | 3 | [\[Tech Report\] USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) 4 | 5 | 6 |
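For orientation, the sketch below shows one hypothetical way to drive the `stripe_flash_attn_func` wrapper from `stripe_flash_attn.py` under `torchrun`. The import path, the `(batch, local_seqlen, heads, head_dim)` tensor layout, and the reliance on the default process group are assumptions for illustration only, not a statement of this repository's documented API.

```python
# Hypothetical usage sketch for the stripe_flash_attn_func wrapper shown above.
# Assumptions: the import path, the (batch, local_seqlen, heads, head_dim) layout,
# and that each rank already holds a striped shard of the global sequence.
import torch
import torch.distributed as dist

from yunchang.ring.stripe_flash_attn import stripe_flash_attn_func  # assumed import path


def main():
    # Launched with: torchrun --nproc_per_node <num_gpus> this_script.py
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Random tensors stand in for one rank's striped shard of the sequence;
    # real inputs would come from a stripe-wise split of the full sequence.
    batch, local_seqlen, heads, head_dim = 1, 1024, 8, 64
    q = torch.randn(batch, local_seqlen, heads, head_dim,
                    device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    # group=None falls back to the default process group; the remaining arguments
    # keep the defaults from the function signature. With return_attn_probs=False
    # a single output tensor is expected (assumption).
    out = stripe_flash_attn_func(q, k, v, causal=True)
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

In practice each rank's `q`/`k`/`v` would be a striped shard of a real sequence rather than random data, and `group`/`attn_type` would be set to match the surrounding sequence-parallel configuration.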