├── .gitignore ├── LICENSE ├── README.md ├── setup.py ├── tests ├── precision_test.py ├── run_speed_tests.py ├── speed_test.py ├── train.py └── varlen_test.py └── wind_rwkv ├── __init__.py └── rwkv7 ├── __init__.py ├── backstepping_longhead ├── __init__.py ├── backstepping_longhead.cpp ├── backstepping_longhead.cu └── backstepping_longhead.py ├── backstepping_smallhead ├── __init__.py ├── backstepping_smallhead.cpp ├── backstepping_smallhead.cu └── backstepping_smallhead.py ├── chunked_cuda ├── __init__.py ├── chunked_cuda.cpp ├── chunked_cuda.cu ├── chunked_cuda.py └── tile.cuh ├── chunked_cuda_varlen ├── __init__.py ├── chunked_cuda_varlen.cpp ├── chunked_cuda_varlen.cu ├── chunked_cuda_varlen.py └── tile.cuh └── triton_bighead.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Johan Sokrates Wind 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wind RWKV 2 | A repository with optimized kernels for [RWKV](https://github.com/BlinkDL/RWKV-LM/) language models. Currently focused on RWKV-7. 3 | 4 | ## Kernel benchmarks for RWKV-7 5 | The kernels were timed using [tests/speed_test.py](tests/speed_test.py) with modeldim 4096 and varying (batch size, head size, sequence length) as labeled in the table. 6 | 7 | ### H100 8 | | Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM[^1] | Typical error | 9 | |:----------------------------|------:|-------:|-------:|--------:|--------:|-----:| 10 | | Chunked bf16 | 8 ms | 11 ms | 54 ms | 224 ms | 5 - 8 GB | 5e-3 | 11 | | Backstepping fp32 longhead | 23 ms | 46 ms | 80 ms | 124 ms | 8 - 14 GB | 9e-5 | 12 | | Backstepping fp32 smallhead | 17 ms | 101 ms | 862 ms | 1802 ms | 7 - 13 GB | 9e-5 | 13 | | Triton bighead fp32 | 66 ms | 87 ms | 168 ms | 1175 ms | 6 - 12 GB | 5e-5 | 14 | | Triton bighead bf16 | [^2]| 29 ms | 59 ms | 358 ms | 6 - 12 GB | 5e-3 | 15 | | FLA chunk_rwkv7 | 64 ms | 62 ms | 89 ms | 93 ms |12 - 13 GB | 4e-3 | 16 | [^1]: Smallest peak VRAM was typically for (8,64,4096) and largest for (8,256,4096). 
17 | [^2]: Triton fails to compile the kernel, only seen on H100. 18 | 19 | ### MI300X 20 | | Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM[^1] | Typical error | 21 | |:----------------------------|------:|-------:|-------:|--------:|--------:|-----:| 22 | | Backstepping fp32 longhead | 29 ms | 39 ms | 75 ms | 162 ms | 8 - 14 GB | 9e-5 | 23 | | Backstepping fp32 smallhead |251 ms | 757 ms |2706 ms |15025 ms | 7 - 13 GB | 9e-5 | 24 | | Triton bighead fp32 | 67 ms | 100 ms | 287 ms | 2073 ms | 6 - 12 GB | 5e-5 | 25 | | Triton bighead bf16 | 42 ms | 72 ms | 198 ms | 1453 ms | 6 - 12 GB | 5e-3 | 26 | | FLA chunk_rwkv7 | 52 ms | 61 ms | 98 ms | 202 ms |12 - 13 GB | 4e-3 | 27 | 28 | ## Kernel descriptions 29 | The RWKV-7 kernels all compute the following: 30 | ```python 31 | def naive(r,w,k,v,a,b,s): 32 | y = th.empty_like(v) 33 | for t in range(w.shape[1]): 34 | s = s * th.exp(-th.exp(w[:,t,:,None,:])) + s @ a[:,t,:,:,None] * b[:,t,:,None,:] + v[:,t,:,:,None] * k[:,t,:,None,:] 35 | y[:,t,:,:,None] = s @ r[:,t,:,:,None] 36 | return y, s 37 | ``` 38 | Here `r`,`w`,`k`,`v`,`a` and `b` have shape [batch size, sequence length, num heads, head size], while the initial state `s` has shape [batch size, num heads, head size, head size]. All inputs and outputs are bfloat16 precision. 39 | 40 | ### [Chunked bf16](wind_rwkv/rwkv7/chunked_cuda/chunked_cuda.cu) 41 | This is the fastest kernel when applicable. It processes the sequence in chunks of length 16 (chunked formulation) and uses Ampere (CUDA SM80+, i.e., A100 and later) instructions for fast bfloat16 matmuls. 42 | ### [Backstepping fp32 smallhead](wind_rwkv/rwkv7/backstepping_smallhead/backstepping_smallhead.cu) 43 | This is essentially the [official](https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/cuda/wkv7_cuda.cu) kernel which was used to train the [RWKV-7 World models](https://huggingface.co/BlinkDL/rwkv-7-world). Calculates gradients by iterating the state backwards in time (max 15 steps). This makes the code simple, but requires 32-bit floats and limits the decay to ca. 0.5. 44 | ### [Backstepping fp32 longhead](wind_rwkv/rwkv7/backstepping_longhead/backstepping_longhead.cu) 45 | Backstepping fp32 smallhead becomes very slow for large head sizes, since the full state is kept in registers, which overflow into global memory. To fix this, backstepping fp32 longhead uses the observation that the columns of the state are essentially updated independently. So it processes blocks of 64 or 32 columns independently. This increases parallelization and keeps less state in shared memory at a time, while keeping most of the simplicity of backstepping fp32 smallhead. 46 | ### [Triton bighead](wind_rwkv/rwkv7/triton_bighead.py) 47 | A simple chunked kernel written in triton. The kernel stores intermediate states in global memory instead of shared memory, so it handles large head sizes (like 1024) without crashing. It takes a flag to choose fp32 or bf16 precision[^3] which affects all matmuls inside the triton kernel. 48 | ### [FLA chunk_rwkv7](https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/rwkv7/chunk.py) 49 | RWKV-7 triton kernel from [Flash Linear Attention](https://github.com/fla-org/flash-linear-attention). Chunked implementation with partial sequence length parallelization. 50 | 51 | [^3]: The kernel also supports tf32 precision for matmuls, but tf32 seems to run into bugs in the triton language, so I didn't expose it.
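The "backstepping" kernels get their name from inverting the state recurrence in the backward pass: since `sa = s @ a` is stored during the forward pass, each step of `naive` above can be undone exactly. A minimal single-head sketch of that idea (an illustration, not the kernels' actual code) shows why the decay must stay bounded away from zero and why the state is kept in fp32:

```python
# Illustration only (single head): undo one step of the naive() recurrence above.
# Forward step:  s_next = s * decay + (s @ a) b^T + v k^T,  with decay = exp(-exp(w)).
import torch as th

def backstep(s_next, w_t, k_t, v_t, b_t, sa_t):
    # sa_t = s @ a was saved in the forward pass, so the update can be inverted exactly.
    decay = th.exp(-th.exp(w_t))                # [head size]; >= ~0.5 with w = -softplus(.)-0.5
    s = s_next - sa_t[:, None] * b_t[None, :]   # remove the (s @ a) b^T term
    s = s - v_t[:, None] * k_t[None, :]         # remove the v k^T term
    return s / decay                            # dividing by the decay amplifies rounding error,
                                                # hence the fp32 state and the bound on the decay
```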
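For reference, a minimal usage sketch of the chunked kernel, pieced together from `tests/precision_test.py` and `tests/train.py` below (the shapes and input parameterization follow those test scripts):

```python
# Minimal usage sketch based on tests/precision_test.py; see the tests for the other kernels.
import torch as th
from wind_rwkv.rwkv7 import load_chunked_cuda, attn_chunked_cuda

B, T, H, C = 2, 128, 8, 128                      # batch, seqlen (multiple of 16), heads, head size
load_chunked_cuda(C)                             # builds/loads the CUDA kernel for this head size

r, w, k, v, a, b = th.randn(6, B, T, H, C, dtype=th.bfloat16, device='cuda')
w = -th.nn.functional.softplus(w) - 0.5          # decay parameterization used by the tests
a = th.nn.functional.normalize(a, p=2, dim=-1)
b = -a * th.sigmoid(b)
s0 = th.zeros(B, H, C, C, dtype=th.bfloat16, device='cuda')

y, sT = attn_chunked_cuda(r, w, k, v, a, b, s0)  # y: [B,T,H,C], sT is the final state [B,H,C,C]
```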
52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='wind_rwkv', 5 | version='0.2', 6 | description='Optimized kernels for RWKV models', 7 | author='Johan Sokrates Wind', 8 | author_email='johanswi@math.uio.no', 9 | url='https://github.com/johanwind/wind_rwkv', 10 | packages=find_packages(), 11 | package_data={ 12 | "wind_rwkv.rwkv7.backstepping_longhead": ["*.cu", "*.cpp"], 13 | "wind_rwkv.rwkv7.backstepping_smallhead": ["*.cu", "*.cpp"], 14 | "wind_rwkv.rwkv7.chunked_cuda": ["*.cu", "*.cpp", "*.cuh"], 15 | }, 16 | license='MIT', 17 | classifiers=[ 18 | 'Programming Language :: Python :: 3', 19 | 'License :: OSI Approved :: MIT License', 20 | 'Operating System :: OS Independent', 21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence' 22 | ], 23 | python_requires='>=3.7', 24 | install_requires=[ 25 | 'triton>=3.0', 26 | 'ninja' 27 | ] 28 | ) 29 | -------------------------------------------------------------------------------- /tests/precision_test.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from wind_rwkv.rwkv7 import * 3 | 4 | def naive(r,w,k,v,a,b,s0): 5 | if s0 is None: s0 = th.zeros(w.shape[0],w.shape[2],w.shape[3],w.shape[3], device=w.device) 6 | dtype = w.dtype 7 | r,w,k,v,a,b,s = [i.double() for i in [r,w,k,v,a,b,s0]] 8 | y = th.empty_like(v) 9 | for t in range(w.shape[1]): 10 | s = s * th.exp(-th.exp(w[:,t,:,None,:])) + s @ a[:,t,:,:,None] * b[:,t,:,None,:] + v[:,t,:,:,None] * k[:,t,:,None,:] 11 | y[:,t,:,:,None] = s @ r[:,t,:,:,None] 12 | return y.to(dtype), s.to(dtype) 13 | 14 | def grad_check(f1, f2, params, backward = True, aux=()): 15 | if backward: params = [p.clone().requires_grad_() for p in params] 16 | y1 = f1(*params,*aux) 17 | y2 = f2(*params,*aux) 18 | def rel(a,b): return (a-b).norm()/max(b.norm(),1e-30) 19 | print('Forward rel. error'+'s'*(len(y1)>1)) 20 | for a,b in zip(y1,y2): 21 | print(f'{rel(a,b):.2e} ({b.norm():.0e})') 22 | 23 | if not backward: return 24 | 25 | dy = tuple(th.randn_like(i) for i in y1) 26 | d1 = th.autograd.grad(y1, params, grad_outputs=dy) 27 | for p in params: 28 | if p.grad is not None: 29 | p.grad.random_() # So th.empty doesn't recover the gradient 30 | p.grad = None 31 | d2 = th.autograd.grad(y2, params, grad_outputs=dy) 32 | print('Gradient rel. 
errors') 33 | for a,b in zip(d1,d2): 34 | print(f'{rel(a,b):.2e} ({b.norm():.0e})') 35 | 36 | batchsz = 2 37 | modeldim = 1024 38 | headsz = 128 39 | seqlen = 128 40 | def gen_rwkv7_data(): 41 | q,w,k,v,a,b = th.randn(6, batchsz, seqlen, modeldim//headsz, headsz, dtype = th.bfloat16, device = 'cuda') 42 | w = -th.nn.functional.softplus(w)-0.5 43 | a = th.nn.functional.normalize(a, p=2, dim=-1) 44 | b = -a*th.sigmoid(b) 45 | s0 = th.randn(batchsz, modeldim//headsz, headsz, headsz, dtype = th.bfloat16, device = 'cuda') 46 | return q,w,k,v,a,b,s0 47 | 48 | th.manual_seed(0) 49 | params = gen_rwkv7_data() 50 | 51 | if 0: 52 | print('FLA chunk_rwkv7') 53 | from fla.ops.rwkv7 import chunk_rwkv7 54 | def attn_fla(r,w,k,v,a,b,s0): 55 | y,sT = chunk_rwkv7(r,-w.exp(),k,v,a,b, initial_state=s0.mT) 56 | return y, sT.mT 57 | grad_check(attn_fla, naive, params) 58 | 59 | print('Triton bighead bf16') 60 | grad_check(attn_triton_bighead_bf16, naive, params) 61 | print('Triton bighead fp16') 62 | grad_check(attn_triton_bighead_fp16, naive, params) 63 | print('Triton bighead fp32') 64 | grad_check(attn_triton_bighead_fp32, naive, params) 65 | 66 | print('Chunked cuda') 67 | load_chunked_cuda(headsz) 68 | grad_check(attn_chunked_cuda, naive, params) 69 | print('Chunked cuda varlen') 70 | load_chunked_cuda_varlen(headsz) 71 | def wrap_varlen(r,w,k,v,a,b,s0): 72 | B,T,H,C = r.shape 73 | r,w,k,v,a,b = [i.view(B*T,H,C) for i in [r,w,k,v,a,b]] 74 | cu_seqlens = th.arange(B+1, device=w.device)*T 75 | y,sT = attn_chunked_cuda_varlen(r,w,k,v,a,b,s0,cu_seqlens) 76 | return y.view(B,T,H,C), sT 77 | grad_check(wrap_varlen, naive, params) 78 | 79 | print('Backstepping smallhead fp32') 80 | load_backstepping_smallhead(headsz) 81 | grad_check(attn_backstepping_smallhead, naive, params) 82 | 83 | print('Backstepping longhead fp32') 84 | load_backstepping_longhead(headsz) 85 | grad_check(attn_backstepping_longhead, naive, params) 86 | -------------------------------------------------------------------------------- /tests/run_speed_tests.py: -------------------------------------------------------------------------------- 1 | import subprocess, sys 2 | for alg in ['smallhead', 'longhead', 'chunked', 'bighead_fp32', 'bighead_bf16', 'fla']: 3 | print(alg) 4 | for forward in [False,True]: 5 | for (batchsz,modeldim,headsz,seqlen) in [(8,4096,64,4096), (8,4096,128,4096), (8,4096,256,4096), (1,4096,256,4096*8)]: 6 | out = subprocess.run(f"python speed_test.py --alg {alg} --batchsz {batchsz} --modeldim {modeldim} --headsz {headsz} --seqlen {seqlen} "+"--forward"*forward, shell=True, capture_output=True, text=True).stdout 7 | try: 8 | out = out.split('\n') 9 | gb = float(out[1].split()[2]) 10 | ms = float(out[2].split()[1]) 11 | except Exception: 12 | ms = gb = float('nan') 13 | print(out, file=sys.stderr) 14 | print((batchsz,modeldim,headsz,seqlen,forward), ms, gb) 15 | -------------------------------------------------------------------------------- /tests/speed_test.py: -------------------------------------------------------------------------------- 1 | import argparse, triton, torch as th 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--batchsz', type=int, default=8) 5 | parser.add_argument('--modeldim', type=int, default=4096) 6 | parser.add_argument('--headsz', type=int, default=64) 7 | parser.add_argument('--seqlen', type=int, default=4096) 8 | parser.add_argument('--alg', type=str, default='smallhead') 9 | parser.add_argument('--forward', action=argparse.BooleanOptionalAction) # Forward pass only 10 | 
cmd_args = parser.parse_args() 11 | 12 | def gen_rwkv7_data(): 13 | q,w,k,v,a,b = th.randn(6, cmd_args.batchsz, cmd_args.seqlen, cmd_args.modeldim//cmd_args.headsz, cmd_args.headsz, dtype = th.bfloat16, device = 'cuda') 14 | w = -th.nn.functional.softplus(w)-0.5 15 | a = th.nn.functional.normalize(a, p=2, dim=-1) 16 | b = -a*th.sigmoid(b) 17 | s0 = th.randn(cmd_args.batchsz, cmd_args.modeldim//cmd_args.headsz, cmd_args.headsz, cmd_args.headsz, dtype = th.bfloat16, device = 'cuda') 18 | return q,w,k,v,a,b,s0 19 | 20 | def benchmark(f, params): 21 | if not cmd_args.forward: 22 | for p in params: p.requires_grad_() 23 | dy = ds = None 24 | def wrap(): 25 | y,s = f(*params) 26 | if cmd_args.forward: return 27 | nonlocal dy,ds 28 | if dy is None: dy,ds = th.randn_like(y),th.randn_like(s) 29 | return th.autograd.grad((y,s), params, grad_outputs=(dy,ds)) 30 | 31 | wrap() # Warmup (compile triton) 32 | th.cuda.synchronize() 33 | th.cuda.reset_peak_memory_stats() 34 | wrap() # Measure memory 35 | th.cuda.synchronize() 36 | print(f'Peak VRAM {th.cuda.max_memory_allocated()/2**30:.2f} GB') 37 | ms, min_ms, max_ms = triton.testing.do_bench(wrap, quantiles=[0.5,0.2,0.8], warmup=1000,rep=2000) 38 | print('Time', f'{ms:.2f} ms ({min_ms:.2f} - {max_ms:.2f})') 39 | 40 | params = gen_rwkv7_data() 41 | 42 | if cmd_args.alg != 'fla': 43 | from wind_rwkv.rwkv7 import * 44 | if cmd_args.alg == 'smallhead': 45 | print('Backstepping smallhead fp32') 46 | load_backstepping_smallhead(cmd_args.headsz) 47 | benchmark(attn_backstepping_smallhead, params) 48 | elif cmd_args.alg == 'longhead': 49 | print('Backstepping longhead fp32') 50 | nheads = cmd_args.modeldim//cmd_args.headsz 51 | load_backstepping_longhead(cmd_args.headsz, cmd_args.batchsz * nheads) 52 | benchmark(attn_backstepping_longhead, params) 53 | elif cmd_args.alg == 'chunked': 54 | print('Chunked cuda') 55 | load_chunked_cuda(cmd_args.headsz) 56 | benchmark(attn_chunked_cuda, params) 57 | elif cmd_args.alg == 'chunked_varlen': 58 | print('Chunked cuda varlen') 59 | load_chunked_cuda_varlen(cmd_args.headsz) 60 | def wrap_varlen(r,w,k,v,a,b,s0): 61 | B,T,H,C = r.shape 62 | r,w,k,v,a,b = [i.view(B*T,H,C) for i in [r,w,k,v,a,b]] 63 | cu_seqlens = th.arange(B+1, device=w.device)*T 64 | y,sT = attn_chunked_cuda_varlen(r,w,k,v,a,b,s0,cu_seqlens) 65 | return y.view(B,T,H,C), sT 66 | benchmark(wrap_varlen, params) 67 | elif cmd_args.alg == 'bighead_bf16': 68 | print('Triton bighead bf16') 69 | benchmark(attn_triton_bighead_bf16, params) 70 | elif cmd_args.alg == 'bighead_fp16': 71 | print('Triton bighead fp16') 72 | benchmark(attn_triton_bighead_fp16, params) 73 | elif cmd_args.alg == 'bighead_fp32': 74 | print('Triton bighead fp32') 75 | benchmark(attn_triton_bighead_fp32, params) 76 | else: 77 | print('Unknown alg', cmd_args.alg) 78 | else: 79 | assert cmd_args.alg == 'fla' 80 | print('FLA chunk_rwkv7') 81 | from fla.ops.rwkv7 import chunk_rwkv7 82 | def attn_fla(r,w,k,v,a,b,s0): 83 | return chunk_rwkv7(r,-w.exp(),k,v,a,b, initial_state=s0) 84 | benchmark(attn_fla, params) 85 | -------------------------------------------------------------------------------- /tests/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Johan Sokrates Wind 2 | 3 | # Based on https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v7/train_temp 4 | 5 | """ 6 | # Install wind_rwkv (from root wind_rwkv directory) 7 | pip install -e . 
8 | 9 | # Install requirements 10 | pip install deepspeed ninja wandb 11 | 12 | # Download data, we use minipile (1498226207 tokens, around 3GB) 13 | mkdir -p data 14 | wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx 15 | wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin 16 | 17 | # Run on a single gpu with ~8GB RAM 18 | torchrun train.py --micro_bsz 12 19 | # Run on 4 gpus without gradient checkpointing 20 | torchrun --nproc-per-node=4 train.py --grad_cp 0 21 | 22 | # First run creates out/L12-D768/rwkv-init.pth, subsequent runs will continue from latest checkpoint in out/L12-D768/ 23 | 24 | # out/L12-D768/train_log.txt losses should be similar to 25 | 0 5.056944 157.1097 0.00059976 2025-06-09 00:01:03.903503 26 | 1 4.016493 55.5061 0.00059901 2025-06-09 00:04:43.806584 27 | 2 3.750670 42.5496 0.00059775 2025-06-09 00:08:23.082158 28 | 3 3.630432 37.7291 0.00059600 2025-06-09 00:12:02.702873 29 | 4 3.553571 34.9379 0.00059374 2025-06-09 00:15:42.810617 30 | 5 3.486632 32.6757 0.00059099 2025-06-09 00:19:22.601675 31 | 6 3.434177 31.0059 0.00058775 2025-06-09 00:23:01.326041 32 | 7 3.381845 29.4250 0.00058403 2025-06-09 00:26:41.194270 33 | 8 3.338046 28.1640 0.00057984 2025-06-09 00:30:21.751441 34 | 9 3.293576 26.9390 0.00057517 2025-06-09 00:34:01.008750 35 | ... 36 | """ 37 | 38 | import os, struct, math, tqdm, datetime, time, argparse 39 | 40 | import torch 41 | from torch import nn 42 | import torch.nn.functional as F 43 | 44 | import deepspeed, logging 45 | deepspeed.utils.logger.setLevel(logging.WARNING) 46 | 47 | from wind_rwkv.rwkv7 import load_chunked_cuda, attn_chunked_cuda 48 | 49 | 50 | # Parse arguments 51 | 52 | def parse_args(): 53 | parser = argparse.ArgumentParser() 54 | 55 | parser.add_argument("--data_file", default="data/minipile", type=str) 56 | parser.add_argument("--proj_dir", default="auto", type=str) 57 | parser.add_argument("--wandb", default="Test", type=str) # wandb project name. 
if "" then don't use wandb 58 | 59 | parser.add_argument("--vocab_size", default=65536, type=int) 60 | parser.add_argument("--n_layer", default=12, type=int) 61 | parser.add_argument("--n_embd", default=768, type=int) 62 | parser.add_argument("--head_size", default=64, type=int) # can try larger values for larger models 63 | parser.add_argument("--dim_ffn", default=0, type=int) 64 | 65 | parser.add_argument("--micro_bsz", default=16, type=int) # micro batch size (batch size per GPU) 66 | parser.add_argument("--ctx_len", default=512, type=int) 67 | 68 | parser.add_argument("--epoch_save", default=10, type=int) # save the model every [epoch_save] "epochs" 69 | parser.add_argument("--samples_per_epoch", default=40320, type=int) 70 | 71 | parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 72 | parser.add_argument("--lr_final", default=6e-5, type=float) 73 | parser.add_argument("--warmup_steps", default=10, type=int) # try 10 if you load a model 74 | parser.add_argument("--beta1", default=0.9, type=float) 75 | parser.add_argument("--beta2", default=0.99, type=float) 76 | parser.add_argument("--adam_eps", default=1e-18, type=float) 77 | parser.add_argument("--grad_cp", default=1, type=int) # gradient checkpt: saves VRAM, but slower 78 | parser.add_argument("--weight_decay", default=1e-3, type=float) # try 0.1 79 | parser.add_argument("--grad_clip", default=1.0, type=float) # reduce it to 0.7 / 0.5 / 0.3 / 0.2 for problematic samples 80 | 81 | parser.add_argument("--torch_compile", default=1, type=int) 82 | parser.add_argument("--ds_bucket_mb", default=2, type=int) # deepspeed bucket size in MB. 200 seems enough 83 | parser.add_argument("--local_rank", default=0, type=int) 84 | 85 | args = parser.parse_args() 86 | 87 | if args.proj_dir == "auto": 88 | args.proj_dir = f"out/L{args.n_layer}-D{args.n_embd}" 89 | if not args.dim_ffn: 90 | args.dim_ffn = args.n_embd*4 91 | 92 | assert all(i%32 == 0 for i in [args.n_embd, args.dim_ffn]) 93 | 94 | args.global_rank = int(os.environ["RANK"]) 95 | args.world_size = int(os.environ["WORLD_SIZE"]) 96 | 97 | args.total_bsz = args.micro_bsz * args.world_size 98 | assert args.samples_per_epoch % args.total_bsz == 0 99 | 100 | args.timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") 101 | 102 | return args 103 | 104 | 105 | # Model definition 106 | 107 | def new_param(*shape): return nn.Parameter(torch.empty(*shape)) 108 | 109 | class RWKV_Tmix_x070(nn.Module): 110 | def __init__(self, args): 111 | super().__init__() 112 | C,N = args.n_embd, args.head_size 113 | self.n_head = C//N 114 | 115 | for p in "x_r x_w x_k x_v x_a x_g".split(): 116 | setattr(self, p, new_param(1,1,C)) 117 | 118 | dims = [max(32, 32*round(fac*C**p/32)) for fac, p in zip([1.8,1.8,1.3,0.6], [0.5,0.5,0.5,0.8])] 119 | for c, D in zip("wavg", dims): 120 | setattr(self, f"{c}1", new_param(C,D)) 121 | setattr(self, f"{c}2", new_param(D,C)) 122 | if c != "g": 123 | setattr(self, f"{c}0", new_param(1,1,C)) 124 | 125 | self.k_k = new_param(1,1,C) 126 | self.k_a = new_param(1,1,C) 127 | self.r_k = new_param(C//N,N) 128 | 129 | self.receptance, self.key, self.value, self.output = [nn.Linear(C, C, bias=False) for i in range(4)] 130 | self.ln_x = nn.GroupNorm(C//N, C, eps=64e-5) 131 | 132 | load_chunked_cuda(args.head_size) 133 | 134 | def forward(self, x, v0): 135 | B,T,C = x.shape 136 | H = self.n_head 137 | 138 | last_x = F.pad(x, (0,0,1,-1)) 139 | xr,xw,xk,xv,xa,xg = [x + m * (last_x - x) for m in [self.x_r, 
self.x_w, self.x_k, self.x_v, self.x_a, self.x_g]] 140 | 141 | r = self.receptance(xr) 142 | w = -F.softplus(-self.w0 - (xw @ self.w1).tanh() @ self.w2) - 0.5 143 | k = self.key(xk) 144 | v = self.value(xv) 145 | if v0 is None: 146 | v0 = v # store first layer's v 147 | else: 148 | v = v + (v0 - v) * (self.v0 + xv @ self.v1 @ self.v2).sigmoid() 149 | a = (self.a0 + xa @ self.a1 @ self.a2).sigmoid() 150 | g = (xg @ self.g1).sigmoid() @ self.g2 151 | 152 | kk = k * self.k_k 153 | k = k * (1 + (a-1) * self.k_a) 154 | 155 | r,w,k,v,kk,a = [i.reshape(B,T,H,-1) for i in [r,w,k,v,kk,a]] 156 | 157 | kk = F.normalize(kk, dim=-1) 158 | x = attn_chunked_cuda(r, w, k, v, kk, -kk*a)[0] 159 | x = self.ln_x(x.view(B*T, C)).view(B,T,C) 160 | 161 | x = x + ((r * k * self.r_k).sum(-1,True) * v).view(B,T,C) 162 | x = self.output(x * g) 163 | return x, v0 164 | 165 | class RWKV_CMix_x070(nn.Module): 166 | def __init__(self, args): 167 | super().__init__() 168 | self.x_k = new_param(1,1,args.n_embd) 169 | self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) 170 | self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) 171 | 172 | def forward(self, x): 173 | k = x + (F.pad(x, (0,0,1,-1)) - x) * self.x_k 174 | return self.value(self.key(k).relu()**2) 175 | 176 | 177 | class Block(nn.Module): 178 | def __init__(self, args): 179 | super().__init__() 180 | self.ln1 = nn.LayerNorm(args.n_embd) 181 | self.ln2 = nn.LayerNorm(args.n_embd) 182 | self.att = RWKV_Tmix_x070(args) 183 | self.ffn = RWKV_CMix_x070(args) 184 | 185 | def forward(self, x, v0): 186 | x_attn, v0 = self.att(self.ln1(x), v0) 187 | x = x + x_attn 188 | x = x + self.ffn(self.ln2(x)) 189 | return x, v0 190 | 191 | class RWKV(nn.Module): 192 | def __init__(self, args): 193 | super().__init__() 194 | self.emb = nn.Embedding(args.vocab_size, args.n_embd) 195 | self.blocks = nn.ModuleList([Block(args) for i in range(args.n_layer)]) 196 | self.blocks[0].ln0 = nn.LayerNorm(args.n_embd) 197 | self.ln_out = nn.LayerNorm(args.n_embd) 198 | self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False) 199 | 200 | def forward(self, tokens): 201 | x = self.blocks[0].ln0(self.emb(tokens)) 202 | v0 = None 203 | for block in self.blocks: 204 | if args.grad_cp: 205 | x, v0 = deepspeed.checkpointing.checkpoint(block, x, v0) 206 | else: 207 | x, v0 = block(x, v0) 208 | return self.head(self.ln_out(x)) 209 | 210 | 211 | # Sample initial weights 212 | 213 | def sample_initial_weights(model, args): 214 | W = model.state_dict() 215 | 216 | scale = 0.5*max(args.vocab_size / args.n_embd, 1)**0.5 217 | nn.init.orthogonal_(W["head.weight"], gain=scale) 218 | nn.init.uniform_(W["emb.weight"], a=-1e-4, b=1e-4) 219 | 220 | L,C,N = args.n_layer, args.n_embd, args.head_size 221 | for i in range(L): 222 | n = torch.arange(C) 223 | 224 | ffn = f"blocks.{i}.ffn." 225 | W[ffn+"x_k"][:] = 1-(n/C)**((1-i/L)**4) 226 | nn.init.orthogonal_(W[ffn+"key.weight"]) 227 | W[ffn+"value.weight"][:] = 0 228 | 229 | att = f"blocks.{i}.att." 
230 | for c,p in zip("rwkvag", [0.2,0.9,0.7,0.7,0.9,0.2]): 231 | W[att+"x_"+c][:] = 1-(n/C)**(p*(1-i/L)) 232 | 233 | linear = n/(C-1)-0.5 234 | zigzag = (z := (n%N)*2/(N-1)-1) * z.abs() 235 | W[att+"k_k"][:] = 0.71 - linear*0.1 236 | W[att+"k_a"][:] = 1.02 237 | W[att+"r_k"][:] =-0.04 238 | W[att+"w0"][:] = 6*(n/(C-1))**(1+(i/(L-1))**0.3) - 6 + zigzag*2.5 + 0.5 239 | W[att+"a0"][:] =-0.19 + zigzag*0.3 + linear*0.4 240 | W[att+"v0"][:] = 0.73 - linear*0.4 241 | 242 | for c in "wvag": 243 | W[att+c+"1"][:] = 0 244 | nn.init.orthogonal_(W[att+c+"2"], gain=0.1) 245 | 246 | W[att+"ln_x.weight"][:] = ((1+i)/L)**0.7 247 | nn.init.orthogonal_(W[att+"receptance.weight"]) 248 | nn.init.orthogonal_(W[att+"key.weight"], gain=0.1) 249 | nn.init.orthogonal_(W[att+"value.weight"]) 250 | W[att+"output.weight"][:] = 0 251 | W = {k:v.bfloat16() for k,v in W.items()} 252 | return W 253 | 254 | 255 | # Load dataset 256 | 257 | class BinIdxDataset(torch.utils.data.Dataset): 258 | def __init__(self, args): 259 | self.args = args 260 | 261 | path = args.data_file 262 | with open(path+".idx", "rb") as stream: 263 | assert stream.read(9) == b"MMIDIDX\x00\x00" # File format magic 264 | assert struct.unpack("0) 389 | for local_step in prog_bar: 390 | cur_step = epoch_steps * epoch + local_step 391 | 392 | lr = lr_schedule(cur_step) 393 | for param_group in opt.param_groups: 394 | param_group["lr"] = lr * param_group["lr_scale"] 395 | opt.param_groups[2]["weight_decay"] = args.weight_decay 396 | 397 | # Load batch 398 | tokens, targets = map(torch.stack, zip(*[dataset[cur_step*args.micro_bsz+i] for i in range(args.micro_bsz)])) 399 | 400 | # Update step 401 | logits = model(tokens.cuda()) 402 | loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.flatten().cuda()) 403 | loss = L2Wrap.apply(loss, logits) 404 | model.backward(loss) 405 | model.step() 406 | 407 | all_loss = loss.detach().clone()/args.world_size 408 | torch.distributed.reduce(all_loss, 0, op=torch.distributed.ReduceOp.SUM) 409 | 410 | # Logging 411 | if args.global_rank == 0: 412 | tokens_per_step = args.ctx_len * args.total_bsz 413 | 414 | now = time.time_ns() 415 | it_s = 1e9/(now-last_time) 416 | kt_s = tokens_per_step * it_s / 1000 417 | last_time = now 418 | 419 | loss_sum += all_loss.item() 420 | loss_cnt += 1 421 | epoch_loss = loss_sum/loss_cnt 422 | 423 | info = {"loss": epoch_loss, "lr": lr, "last it/s": it_s, "Kt/s": kt_s} 424 | for k in info: info[k] = f'{info[k]:<5.4g}'.replace(' ','0') 425 | prog_bar.set_postfix(info) 426 | 427 | if args.wandb: 428 | info = {"loss": loss, "lr": lr, "wd": args.weight_decay, "Gtokens": cur_step * tokens_per_step / 1e9, "kt/s": kt_s} 429 | wandb.log(info, step=cur_step) 430 | 431 | # Save final model 432 | dataset_steps = len(dataset.data) / (args.ctx_len * args.total_bsz) 433 | if (cur_step+1)*tokens_per_step >= len(dataset.data): 434 | torch.save(model.module.state_dict(), f"{args.proj_dir}/rwkv-final.pth") 435 | exit(0) 436 | 437 | if args.global_rank == 0: 438 | # Save checkpoints 439 | if ((args.epoch_save and (epoch-start_epoch) % args.epoch_save == 0) or epoch == epochs-1): 440 | torch.save(model.module.state_dict(), f"{args.proj_dir}/rwkv-{epoch}.pth") 441 | 442 | # Logging 443 | print(f"{epoch} {epoch_loss:.6f} {math.exp(epoch_loss):.4f} {lr:.8f} {datetime.datetime.now()}", file=log, flush=True) 444 | loss_cnt = loss_sum = 0 445 | -------------------------------------------------------------------------------- /tests/varlen_test.py: 
-------------------------------------------------------------------------------- 1 | import torch as th 2 | from wind_rwkv.rwkv7 import * 3 | 4 | def naive(r,w,k,v,a,b,s0, cu_seqlens): 5 | T,H,C = r.shape 6 | B = len(cu_seqlens)-1 7 | if s0 is None: s0 = th.zeros(B,H,C,C, device=w.device) 8 | dtype = w.dtype 9 | r,w,k,v,a,b,s0 = [i.double() for i in [r,w,k,v,a,b,s0]] 10 | y = th.empty_like(v) 11 | sT = th.empty_like(s0) 12 | bi = 0 13 | for t in range(T): 14 | if t == cu_seqlens[bi]: 15 | s = s0[bi] 16 | bi += 1 17 | s = s * th.exp(-th.exp(w[t,:,None,:])) + s @ a[t,:,:,None] * b[t,:,None,:] + v[t,:,:,None] * k[t,:,None,:] 18 | y[t,:,:,None] = s @ r[t,:,:,None] 19 | if t+1 == cu_seqlens[bi]: 20 | sT[bi-1] = s 21 | return y.to(dtype), sT.to(dtype) 22 | 23 | def grad_check(f1, f2, params, backward = True, aux=()): 24 | if backward: params = [p.clone().requires_grad_() for p in params] 25 | y1 = f1(*params,*aux) 26 | y2 = f2(*params,*aux) 27 | def rel(a,b): return (a-b).norm()/max(b.norm(),1e-30) 28 | print('Forward rel. error'+'s'*(len(y1)>1)) 29 | for a,b in zip(y1,y2): 30 | print(f'{rel(a,b):.2e} ({b.norm():.0e})') 31 | 32 | if not backward: return 33 | 34 | dy = tuple(th.randn_like(i) for i in y1) 35 | d1 = th.autograd.grad(y1, params, grad_outputs=dy) 36 | for p in params: 37 | if p.grad is not None: 38 | p.grad.random_() # So th.empty doesn't recover the gradient 39 | p.grad = None 40 | d2 = th.autograd.grad(y2, params, grad_outputs=dy) 41 | print('Gradient rel. errors') 42 | for a,b in zip(d1,d2): 43 | print(f'{rel(a,b):.2e} ({b.norm():.0e})') 44 | 45 | cu_seqlens = th.tensor([0,16,48,64], device='cuda') 46 | modeldim = 1024 47 | headsz = 128 48 | 49 | def gen_rwkv7_data(): 50 | q,w,k,v,a,b = th.randn(6, cu_seqlens[-1], modeldim//headsz, headsz, dtype = th.bfloat16, device = 'cuda') 51 | w = -th.nn.functional.softplus(w)-0.5 52 | a = th.nn.functional.normalize(a, p=2, dim=-1) 53 | b = -a*th.sigmoid(b) 54 | s0 = th.randn(len(cu_seqlens)-1, modeldim//headsz, headsz, headsz, dtype = th.bfloat16, device = 'cuda') 55 | return q,w,k,v,a,b,s0 56 | 57 | th.manual_seed(0) 58 | params = gen_rwkv7_data() 59 | 60 | if 0: 61 | print('FLA chunk_rwkv7') 62 | from fla.ops.rwkv7 import chunk_rwkv7 63 | def attn_fla(r,w,k,v,a,b,s0, cu_seqlens): 64 | r,w,k,v,a,b = [i.unsqueeze(0) for i in [r,w,k,v,a,b]] 65 | y,sT = chunk_rwkv7(r,-w.exp(),k,v,a,b, initial_state=s0.mT, cu_seqlens=cu_seqlens) 66 | return y.squeeze(0), sT.mT 67 | grad_check(attn_fla, naive, params, aux=(cu_seqlens,)) 68 | 69 | print('Chunked cuda varlen') 70 | load_chunked_cuda_varlen(headsz) 71 | grad_check(attn_chunked_cuda_varlen, naive, params, aux=(cu_seqlens,)) 72 | -------------------------------------------------------------------------------- /wind_rwkv/__init__.py: -------------------------------------------------------------------------------- 1 | import wind_rwkv.rwkv7 2 | __version__ = '0.2' 3 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/__init__.py: -------------------------------------------------------------------------------- 1 | from .triton_bighead import attn_triton_bighead, attn_triton_bighead_bf16, attn_triton_bighead_fp16, attn_triton_bighead_fp32 2 | from .chunked_cuda.chunked_cuda import load_chunked_cuda, attn_chunked_cuda 3 | from .chunked_cuda_varlen.chunked_cuda_varlen import load_chunked_cuda_varlen, attn_chunked_cuda_varlen 4 | from .backstepping_smallhead.backstepping_smallhead import load_backstepping_smallhead, attn_backstepping_smallhead 5 | from 
.backstepping_longhead.backstepping_longhead import load_backstepping_longhead, attn_backstepping_longhead 6 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_longhead/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johanwind/wind_rwkv/acc0488e8c86ee5e7f3184ae4a9c1d97e1e14fff/wind_rwkv/rwkv7/backstepping_longhead/__init__.py -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_longhead/backstepping_longhead.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | struct __nv_bfloat16; 4 | using bf = __nv_bfloat16; 5 | 6 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, float*s, float*sa, bf*sT); 7 | 8 | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &s0, torch::Tensor &y, c10::optional s, c10::optional sa, torch::Tensor &sT) { 9 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 10 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)s0.data_ptr(), (bf*)y.data_ptr(), s.has_value() ? (float*)s.value().data_ptr() : NULL, sa.has_value() ? (float*)sa.value().data_ptr() : NULL, (bf*)sT.data_ptr()); 11 | } 12 | 13 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dsT, float*dw, float*dq, float*dk, bf*dv, float*da, float*db, bf*ds0); 14 | 15 | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy, 16 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dsT, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db, torch::Tensor &ds0) { 17 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 18 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), 19 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dsT.data_ptr(), (float*)dw.data_ptr(), (float*)dq.data_ptr(), (float*)dk.data_ptr(), (bf*)dv.data_ptr(), (float*)da.data_ptr(), (float*)db.data_ptr(), (bf*)ds0.data_ptr()); 20 | } 21 | 22 | TORCH_LIBRARY(wind_backstepping_longhead, m) { 23 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor s0, Tensor(a!) y, Tensor? s, Tensor? sa, Tensor(d!) sT) -> ()"); 24 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor dy, Tensor s, Tensor sa, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) da, Tensor(f!) db, Tensor(g!) 
ds0) -> ()"); 25 | } 26 | 27 | TORCH_LIBRARY_IMPL(wind_backstepping_longhead, CUDA, m) { 28 | m.impl("forward", &forward); 29 | m.impl("backward", &backward); 30 | } 31 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_longhead/backstepping_longhead.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include 4 | #include 5 | 6 | using bf = __nv_bfloat16; 7 | #if defined AMD 8 | #define to_float(x) (x) 9 | #define to_bf(x) (x) 10 | #else 11 | #define to_float(x) __bfloat162float(x) 12 | #define to_bf(x) __float2bfloat16_rn(x) 13 | #endif 14 | 15 | typedef bf * __restrict__ F_; 16 | 17 | constexpr int K = _K_; // Value dim chunksize 18 | 19 | // sum "val" in groups of _C_/K threads with stride K. Expects "share" to be shared memory of size _C_ 20 | __device__ inline float sum_reduce(float val, float*share) { 21 | constexpr int ni = K, nj = _C_/K; 22 | int i = threadIdx.x%K, j = threadIdx.x/K; 23 | __syncthreads(); 24 | share[j+i*nj] = val; 25 | __syncthreads(); 26 | if (j == 0) { 27 | float sum = 0; 28 | #pragma unroll 29 | for (int l = 0; l < nj; l++) { 30 | sum += share[l+i*nj]; 31 | } 32 | share[i*nj] = sum; 33 | } 34 | __syncthreads(); 35 | val = share[i*nj]; 36 | __syncthreads(); 37 | return val; 38 | } 39 | 40 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ s0_, bf* y_, float* s_, float* sa_, bf* sT_) { 41 | constexpr int C = _C_; 42 | int bb = blockIdx.z, hh = blockIdx.y, basei = blockIdx.x*K, i = threadIdx.x, rowi_ = basei+i%K, basej = i/K*K; 43 | 44 | float state[K]; 45 | for (int j = 0; j < K; j++) { 46 | state[j] = to_float(s0_[bb*H*C*C + hh*C*C + rowi_*C + basej+j]); 47 | } 48 | __shared__ float q[C], k[C], w[C], a[C], b[C], share[C]; 49 | 50 | for (int t = 0; t < T; t++) { 51 | int ind = bb*T*H*C + t*H*C + hh * C + i; 52 | __syncthreads(); 53 | q[i] = to_float(q_[ind]); 54 | w[i] = __expf(-__expf(to_float(w_[ind]))); 55 | k[i] = to_float(k_[ind]); 56 | a[i] = to_float(a_[ind]); 57 | b[i] = to_float(b_[ind]); 58 | __syncthreads(); 59 | 60 | float sa = 0; 61 | #pragma unroll 62 | for (int j = 0; j < K; j++) { 63 | sa += state[j] * a[basej+j]; 64 | } 65 | int vind = bb*T*H*C + t*H*C + hh*C + rowi_; 66 | sa = sum_reduce(sa, share); 67 | if (basej == 0 && sa_ != NULL) sa_[vind] = sa; 68 | 69 | float v = to_float(v_[vind]), y = 0; 70 | #pragma unroll 71 | for (int j = 0; j < K; j++) { 72 | float& s = state[j]; 73 | int j_ = basej+j; 74 | s = s * w[j_] + sa * b[j_] + v * k[j_]; 75 | y += s * q[j_]; 76 | } 77 | y = sum_reduce(y, share); 78 | if (basej == 0) y_[vind] = to_bf(y); 79 | 80 | if ((t+1)%_CHUNK_LEN_ == 0 && s_ != NULL) { 81 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + basej*C + rowi_; 82 | #pragma unroll 83 | for (int j = 0; j < K; j++) { 84 | s_[base + j*C] = state[j]; 85 | } 86 | } 87 | } 88 | for (int j = 0; j < K; j++) { 89 | sT_[bb*H*C*C + hh*C*C + rowi_*C + basej+j] = to_bf(state[j]); 90 | } 91 | } 92 | 93 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, F_ dsT_, float* dw_, float* dq_, float* dk_, bf* dv_, float* da_, float* db_, bf* ds0_) { 94 | constexpr int C = _C_; 95 | int bb = blockIdx.z, hh = blockIdx.y, basei = blockIdx.x*K, i = threadIdx.x, rowi = i%K, basej = i/K*K; 96 | 97 | float stateT[K], dstate[K], dstateT[K]; 98 | for (int j = 0; j < K; j++) 
{ 99 | dstate[j] = to_float(dsT_[(bb*H+hh)*C*C + (basei+rowi)*C + basej+j]); 100 | dstateT[j] = to_float(dsT_[(bb*H+hh)*C*C + (basei+j)*C + i]); 101 | } 102 | __shared__ float w[C], q[C], k[C], a[C], b[C], v[K], dy[K], sa[K], dSb_shared[K], share[C]; 103 | float qi, wi, ki, ai, bi; 104 | 105 | for (int t = T-1; t >= 0; t--) { 106 | int ind = bb*T*H*C + t*H*C + hh * C + i; 107 | __syncthreads(); 108 | q[i] = qi = to_float(q_[ind]); 109 | float wi_fac = -__expf(to_float(w_[ind])); 110 | w[i] = wi = __expf(wi_fac); 111 | k[i] = ki = to_float(k_[ind]); 112 | a[i] = ai = to_float(a_[ind]); 113 | b[i] = bi = to_float(b_[ind]); 114 | if (i < K) { 115 | int vind = ind + basei; 116 | v[i] = to_float(v_[vind]); 117 | dy[i] = to_float(dy_[vind]); 118 | sa[i] = sa_[vind]; 119 | } 120 | __syncthreads(); 121 | 122 | if ((t+1)%_CHUNK_LEN_ == 0) { 123 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C + basei; 124 | #pragma unroll 125 | for (int j = 0; j < K; j++) { 126 | stateT[j] = s_[base + j]; 127 | } 128 | } 129 | 130 | float dq = 0, iwi = 1.f/wi, dw = 0, dk = 0, db = 0; 131 | #pragma unroll 132 | for (int j = 0; j < K; j++) { 133 | dq += stateT[j]*dy[j]; 134 | stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi; 135 | dstateT[j] += qi * dy[j]; 136 | dw += dstateT[j] * stateT[j]; 137 | dk += dstateT[j] * v[j]; 138 | db += dstateT[j] * sa[j]; 139 | } 140 | atomicAdd(dq_+ind, dq); 141 | atomicAdd(dw_+ind, dw * wi * wi_fac); 142 | atomicAdd(dk_+ind, dk); 143 | atomicAdd(db_+ind, db); 144 | 145 | float dv = 0, dSb = 0, dyi = dy[rowi]; 146 | #pragma unroll 147 | for (int j = 0; j < K; j++) { 148 | dstate[j] += dyi * q[basej+j]; 149 | dv += dstate[j] * k[basej+j]; 150 | dSb += dstate[j] * b[basej+j]; 151 | } 152 | dv = sum_reduce(dv, share); 153 | dSb = sum_reduce(dSb, share); 154 | if (basej == 0) { 155 | dv_[bb*T*H*C + t*H*C + hh*C + basei+rowi] = to_bf(dv); 156 | dSb_shared[rowi] = dSb; 157 | } 158 | __syncthreads(); 159 | 160 | float da = 0; 161 | #pragma unroll 162 | for (int j = 0; j < K; j++) { 163 | da += stateT[j]*dSb_shared[j]; 164 | dstate[j] = dstate[j] * w[basej+j] + dSb * a[basej+j]; 165 | dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j]; 166 | } 167 | atomicAdd(da_+ind, da); 168 | } 169 | for (int j = 0; j < K; j++) { 170 | ds0_[(bb*H+hh)*C*C + (basei+rowi)*C + basej+j] = to_bf(dstate[j]); 171 | } 172 | } 173 | 174 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, float*s, float*sa, bf*sT) { 175 | static_assert(_C_%K == 0, "_C_ must be divisible by 64"); 176 | forward_kernel<<>>(T,H,w,q,k,v,a,b,s0,y,s,sa,sT); 177 | } 178 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dsT, float*dw, float*dq, float*dk, bf*dv, float*da, float*db, bf*ds0) { 179 | assert(T%_CHUNK_LEN_ == 0); 180 | backward_kernel<<>>(T,H,w,q,k,v,a,b,dy,s,sa,dsT,dw,dq,dk,dv,da,db,ds0); 181 | } 182 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_longhead/backstepping_longhead.py: -------------------------------------------------------------------------------- 1 | import os, torch as th 2 | from torch.utils.cpp_extension import load 3 | 4 | CHUNK_LEN = 16 5 | 6 | class RWKV7_longhead(th.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, q,w,k,v,a,b,s0): 9 | B,T,H,C = w.shape 10 | assert T%CHUNK_LEN == 0 11 | if not th.compiler.is_compiling(): 12 | assert hasattr(th.ops.wind_backstepping_longhead, 'forward'), 'Requires a loaded kernel from 
load_backstepping_longhead(head_size)' 13 | assert all(i.dtype==th.bfloat16 for i in [w,q,k,v,a,b,s0]) 14 | assert all(i.is_contiguous() for i in [w,q,k,v,a,b,s0]) 15 | assert all(i.shape == w.shape for i in [w,q,k,v,a,b]) 16 | assert list(s0.shape) == [B,H,C,C] 17 | B,T,H,C = w.shape 18 | y = th.empty_like(v) 19 | sT = th.empty_like(s0) 20 | if any(i.requires_grad for i in [w,q,k,v,a,b,s0]): 21 | s = th.empty(B,H,T//CHUNK_LEN,C,C, dtype=th.float32,device=w.device) 22 | sa = th.empty(B,T,H,C, dtype=th.float32,device=w.device) 23 | else: 24 | s = sa = None 25 | th.ops.wind_backstepping_longhead.forward(w,q,k,v,a,b, s0,y,s,sa,sT) 26 | ctx.save_for_backward(w,q,k,v,a,b,s,sa) 27 | return y, sT 28 | @staticmethod 29 | def backward(ctx, dy, dsT): 30 | w,q,k,v,a,b,s,sa = ctx.saved_tensors 31 | B,T,H,C = w.shape 32 | if not th.compiler.is_compiling(): 33 | assert all(i.dtype==th.bfloat16 for i in [dy,dsT]) 34 | assert all(i.is_contiguous() for i in [dy,dsT]) 35 | 36 | dv,ds0 = [th.empty_like(x) for x in [v,dsT]] 37 | dw,dq,dk,da,db = [th.zeros(B,T,H,C, device=w.device) for i in range(5)] 38 | th.ops.wind_backstepping_longhead.backward(w,q,k,v,a,b, dy,s,sa,dsT, dw,dq,dk,dv,da,db,ds0) 39 | return dq,dw,dk,dv,da,db,ds0 40 | 41 | def attn_backstepping_longhead(r,w,k,v,a,b, s0 = None): 42 | B,T,H,C = w.shape 43 | if s0 is None: s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device) 44 | return RWKV7_longhead.apply(r,w,k,v,a,b, s0) 45 | 46 | def load_backstepping_longhead(head_size, batchsz_times_heads_estimate = 8*64): 47 | if hasattr(th.ops.wind_backstepping_longhead, 'forward'): return 48 | device_props = th.cuda.get_device_properties(th.cuda.current_device()) 49 | if 'AMD' in device_props.name: 50 | value_chunk_size = 16 51 | CUDA_FLAGS = [f'-D_C_={head_size}', f'-D_K_={value_chunk_size}', f'-D_CHUNK_LEN_={CHUNK_LEN}', '-O3', '-ffast-math', '-DAMD'] 52 | else: 53 | value_chunk_size = 64 54 | if th.cuda.get_device_properties(th.cuda.current_device()).multi_processor_count >= batchsz_times_heads_estimate * head_size / 32: 55 | value_chunk_size = 32 56 | CUDA_FLAGS = ['-res-usage', f'-D_C_={head_size} -D_K_={value_chunk_size}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 57 | path = os.path.dirname(__file__) 58 | load(name="wind_backstepping_longhead", sources=[os.path.join(path,'backstepping_longhead.cu'), os.path.join(path,'backstepping_longhead.cpp')], is_python_module=False, verbose=False, extra_cuda_cflags=CUDA_FLAGS) 59 | assert hasattr(th.ops.wind_backstepping_longhead, 'forward') 60 | 61 | def attn_backstepping_longhead_wrap(r,w,k,v,a,b, head_size): 62 | B,T,HC = w.shape 63 | C = head_size 64 | H = HC//C 65 | r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] 66 | s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device) 67 | return attn_backstepping_longhead(r,w,k,v,a,b,s0)[0].view(B,T,HC) 68 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_smallhead/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johanwind/wind_rwkv/acc0488e8c86ee5e7f3184ae4a9c1d97e1e14fff/wind_rwkv/rwkv7/backstepping_smallhead/__init__.py -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_smallhead/backstepping_smallhead.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | struct __nv_bfloat16; 4 | using bf 
= __nv_bfloat16; 5 | 6 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, float*s, float*sa, bf*sT); 7 | 8 | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &s0, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa, torch::Tensor &sT) { 9 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 10 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)s0.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)sT.data_ptr()); 11 | } 12 | 13 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*da, bf*db, bf*ds0); 14 | 15 | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy, 16 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dsT, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db, torch::Tensor &ds0) { 17 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 18 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), 19 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dsT.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr(), (bf*)ds0.data_ptr()); 20 | } 21 | 22 | TORCH_LIBRARY(wind_backstepping_smallhead, m) { 23 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor s0, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa, Tensor(d!) sT) -> ()"); 24 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor dy, Tensor s, Tensor sa, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) da, Tensor(f!) db, Tensor(g!) 
ds0) -> ()"); 25 | } 26 | 27 | TORCH_LIBRARY_IMPL(wind_backstepping_smallhead, CUDA, m) { 28 | m.impl("forward", &forward); 29 | m.impl("backward", &backward); 30 | } 31 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_smallhead/backstepping_smallhead.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using bf = __nv_bfloat16; 5 | #if defined AMD 6 | #define to_float(x) (x) 7 | #define to_bf(x) (x) 8 | #else 9 | #define to_float(x) __bfloat162float(x) 10 | #define to_bf(x) __float2bfloat16_rn(x) 11 | #endif 12 | 13 | typedef bf * __restrict__ F_; 14 | 15 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ s0_, bf* y_, float* s_, float* sa_, bf* sT_) { 16 | constexpr int C = _C_; 17 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 18 | 19 | float state[C]; 20 | for (int j = 0; j < C; j++) { 21 | state[j] = to_float(s0_[bb*H*C*C + hh*C*C + i * C + j]); 22 | } 23 | __shared__ float q[C], k[C], w[C], a[C], b[C]; 24 | 25 | for (int t = 0; t < T; t++) { 26 | int ind = bb*T*H*C + t*H*C + hh * C + i; 27 | __syncthreads(); 28 | q[i] = to_float(q_[ind]); 29 | w[i] = __expf(-__expf(to_float(w_[ind]))); 30 | k[i] = to_float(k_[ind]); 31 | a[i] = to_float(a_[ind]); 32 | b[i] = to_float(b_[ind]); 33 | __syncthreads(); 34 | 35 | float sa = 0; 36 | #pragma unroll 37 | for (int j = 0; j < C; j++) { 38 | sa += a[j] * state[j]; 39 | } 40 | sa_[ind] = sa; 41 | 42 | float v = to_float(v_[ind]); 43 | float y = 0; 44 | #pragma unroll 45 | for (int j = 0; j < C; j++) { 46 | float& s = state[j]; 47 | s = s * w[j] + sa * b[j] + k[j] * v; 48 | y += s * q[j]; 49 | } 50 | y_[ind] = to_bf(y); 51 | 52 | if ((t+1)%_CHUNK_LEN_ == 0) { 53 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i; 54 | #pragma unroll 55 | for (int j = 0; j < C; j++) { 56 | s_[base + j*C] = state[j]; 57 | } 58 | } 59 | } 60 | for (int j = 0; j < C; j++) { 61 | sT_[bb*H*C*C + hh*C*C + i * C + j] = to_bf(state[j]); 62 | } 63 | } 64 | 65 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, F_ dsT_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_, bf* ds0_) { 66 | constexpr int C = _C_; 67 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 68 | 69 | float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0}; 70 | for (int j = 0; j < C; j++) { 71 | dstate[j] = to_float(dsT_[bb*H*C*C + hh*C*C + i * C + j]); 72 | dstateT[j] = to_float(dsT_[bb*H*C*C + hh*C*C + j * C + i]); 73 | } 74 | __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C]; 75 | float qi, wi, ki, ai, bi, dyi; 76 | 77 | for (int t = T-1; t >= 0; t--) { 78 | int ind = bb*T*H*C + t*H*C + hh * C + i; 79 | __syncthreads(); 80 | q[i] = qi = to_float(q_[ind]); 81 | float wi_fac = -__expf(to_float(w_[ind])); 82 | w[i] = wi = __expf(wi_fac); 83 | k[i] = ki = to_float(k_[ind]); 84 | a[i] = ai = to_float(a_[ind]); 85 | b[i] = bi = to_float(b_[ind]); 86 | v[i] = to_float(v_[ind]); 87 | dy[i] = dyi = to_float(dy_[ind]); 88 | sa[i] = sa_[ind]; 89 | __syncthreads(); 90 | 91 | if ((t+1)%_CHUNK_LEN_ == 0) { 92 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C; 93 | #pragma unroll 94 | for (int j = 0; j < C; j++) { 95 | stateT[j] = s_[base + j]; 96 | } 97 | } 98 | 99 | float dq = 0; 100 | #pragma unroll 101 | for (int j = 0; j < C; j++) { 102 | dq += 
stateT[j]*dy[j]; 103 | } 104 | dq_[ind] = to_bf(dq); 105 | 106 | float iwi = 1.0f/wi; 107 | #pragma unroll 108 | for (int j = 0; j < C; j++) { 109 | stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi; 110 | dstate[j] += dyi * q[j]; 111 | dstateT[j] += qi * dy[j]; 112 | } 113 | 114 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0; 115 | #pragma unroll 116 | for (int j = 0; j < C; j++) { 117 | dw += dstateT[j]*stateT[j]; 118 | dk += dstateT[j]*v[j]; 119 | dv += dstate[j]*k[j]; 120 | dSb += dstate[j]*b[j]; 121 | db += dstateT[j]*sa[j]; 122 | } 123 | dw_[ind] = to_bf(dw * wi * wi_fac); 124 | dk_[ind] = to_bf(dk); 125 | dv_[ind] = to_bf(dv); 126 | db_[ind] = to_bf(db); 127 | 128 | __syncthreads(); 129 | dSb_shared[i] = dSb; 130 | __syncthreads(); 131 | 132 | float da = 0; 133 | #pragma unroll 134 | for (int j = 0; j < C; j++) { 135 | da += stateT[j]*dSb_shared[j]; 136 | } 137 | da_[ind] = to_bf(da); 138 | 139 | #pragma unroll 140 | for (int j = 0; j < C; j++) { 141 | dstate[j] = dstate[j]*w[j] + dSb * a[j]; 142 | dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j]; 143 | } 144 | } 145 | for (int j = 0; j < C; j++) { 146 | ds0_[bb*H*C*C + hh*C*C + i * C + j] = to_bf(dstate[j]); 147 | } 148 | } 149 | 150 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, float*s, float*sa, bf*sT) { 151 | forward_kernel<<>>(T,H,w,q,k,v,a,b,s0,y,s,sa,sT); 152 | } 153 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*da, bf*db, bf*ds0) { 154 | assert(T%_CHUNK_LEN_ == 0); 155 | backward_kernel<<>>(T,H,w,q,k,v,a,b,dy,s,sa,dsT,dw,dq,dk,dv,da,db,ds0); 156 | } 157 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/backstepping_smallhead/backstepping_smallhead.py: -------------------------------------------------------------------------------- 1 | import os, torch as th 2 | from torch.utils.cpp_extension import load 3 | 4 | CHUNK_LEN = 16 5 | 6 | class RWKV7_smallhead(th.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, q,w,k,v,a,b,s0): 9 | B,T,H,C = w.shape 10 | assert T%CHUNK_LEN == 0 11 | if not th.compiler.is_compiling(): 12 | assert hasattr(th.ops.wind_backstepping_smallhead, 'forward'), 'Requires a loaded kernel from load_backstepping_smallhead(head_size)' 13 | assert all(i.dtype==th.bfloat16 for i in [w,q,k,v,a,b,s0]) 14 | assert all(i.is_contiguous() for i in [w,q,k,v,a,b,s0]) 15 | assert all(i.shape == w.shape for i in [w,q,k,v,a,b]) 16 | assert list(s0.shape) == [B,H,C,C] 17 | B,T,H,C = w.shape 18 | y = th.empty_like(v) 19 | sT = th.empty_like(s0) 20 | s = th.empty(B,H,T//CHUNK_LEN,C,C, dtype=th.float32,device=w.device) 21 | sa = th.empty(B,T,H,C, dtype=th.float32,device=w.device) 22 | th.ops.wind_backstepping_smallhead.forward(w,q,k,v,a,b, s0,y,s,sa,sT) 23 | ctx.save_for_backward(w,q,k,v,a,b,s,sa) 24 | return y,sT 25 | @staticmethod 26 | def backward(ctx, dy, dsT): 27 | w,q,k,v,a,b,s,sa = ctx.saved_tensors 28 | B,T,H,C = w.shape 29 | if not th.compiler.is_compiling(): 30 | assert all(i.dtype==th.bfloat16 for i in [dy,dsT]) 31 | assert all(i.is_contiguous() for i in [dy,dsT]) 32 | 33 | dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] 34 | th.ops.wind_backstepping_smallhead.backward(w,q,k,v,a,b, dy,s,sa,dsT, dw,dq,dk,dv,da,db,ds0) 35 | return dq,dw,dk,dv,da,db,ds0 36 | 37 | def attn_backstepping_smallhead(r,w,k,v,a,b, s0 = None): 38 | B,T,H,C = w.shape 39 | if s0 is None: s0 = th.zeros(B,H,C,C, 
dtype=th.bfloat16,device=w.device) 40 | return RWKV7_smallhead.apply(r,w,k,v,a,b, s0) 41 | 42 | def load_backstepping_smallhead(head_size): 43 | if hasattr(th.ops.wind_backstepping_smallhead, 'forward'): return 44 | device_props = th.cuda.get_device_properties(th.cuda.current_device()) 45 | if 'AMD' in device_props.name: 46 | CUDA_FLAGS = [f'-D_C_={head_size}', f'-D_CHUNK_LEN_={CHUNK_LEN}', '-O3', '-ffast-math', '-DAMD'] 47 | else: 48 | CUDA_FLAGS = ['-res-usage', f'-D_C_={head_size}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 49 | path = os.path.dirname(__file__) 50 | load(name="wind_backstepping_smallhead", sources=[os.path.join(path,'backstepping_smallhead.cu'), os.path.join(path,'backstepping_smallhead.cpp')], is_python_module=False, verbose=False, extra_cuda_cflags=CUDA_FLAGS) 51 | assert hasattr(th.ops.wind_backstepping_smallhead, 'forward') 52 | 53 | def attn_backstepping_smallhead_wrap(r,w,k,v,a,b, head_size): 54 | B,T,HC = w.shape 55 | C = head_size 56 | H = HC//C 57 | r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] 58 | s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device) 59 | return attn_backstepping_smallhead(r,w,k,v,a,b,s0)[0].view(B,T,HC) 60 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johanwind/wind_rwkv/acc0488e8c86ee5e7f3184ae4a9c1d97e1e14fff/wind_rwkv/rwkv7/chunked_cuda/__init__.py -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda/chunked_cuda.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include 4 | 5 | struct __nv_bfloat16; 6 | using bf = __nv_bfloat16; 7 | using torch::Tensor; 8 | 9 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, bf*s, bf*sT); 10 | 11 | void forward(Tensor &w, Tensor &q, Tensor &k, Tensor &v, Tensor &a, Tensor &b, Tensor &s0, Tensor &y, c10::optional s, Tensor &sT) { 12 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 13 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)s0.data_ptr(), (bf*)y.data_ptr(), s.has_value() ? (bf*)s.value().data_ptr() : NULL, (bf*)sT.data_ptr()); 14 | } 15 | 16 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*da, bf*db, bf*ds0); 17 | 18 | void backward(Tensor &w, Tensor &q, Tensor &k, Tensor &v, Tensor &a, Tensor &b, Tensor &dy, 19 | Tensor &s, Tensor &dsT, Tensor &dw, Tensor &dq, Tensor &dk, Tensor &dv, Tensor &da, Tensor &db, Tensor &ds0) { 20 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 21 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), 22 | (bf*)s.data_ptr(), (bf*)dsT.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr(), (bf*)ds0.data_ptr()); 23 | } 24 | 25 | TORCH_LIBRARY(wind_chunked_cuda, m) { 26 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor s0, Tensor(a!) y, Tensor? s, Tensor(c!) 
sT) -> ()"); 27 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor dy, Tensor s, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) da, Tensor(f!) db, Tensor(g!) ds0) -> ()"); 28 | } 29 | 30 | TORCH_LIBRARY_IMPL(wind_chunked_cuda, CUDA, m) { 31 | m.impl("forward", &forward); 32 | m.impl("backward", &backward); 33 | } 34 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda/chunked_cuda.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include "tile.cuh" 4 | #include 5 | typedef bf * __restrict__ F_; 6 | typedef float * __restrict__ F32_; 7 | 8 | constexpr int WARPS = _C_/16; 9 | constexpr int fw_stages = 1, bw_stages = 1; 10 | 11 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ s0_, bf* y_, bf* s_, bf* sT_) { 12 | constexpr int C = _C_, K = 16; 13 | int bi = blockIdx.y, hi = blockIdx.x; 14 | extern __shared__ char smem_[]; 15 | char*smem = smem_; 16 | 17 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 18 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 19 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 20 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 21 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 22 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 23 | char*share = (char*)smem; 24 | 25 | int stride = H*C; 26 | int warpi = threadIdx.x/32; 27 | 28 | auto push = [&](int t) { 29 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 30 | int si = t%fw_stages; 31 | sw_[si*WARPS+warpi] = GTile(w_+off, stride); 32 | sq_[si*WARPS+warpi] = GTile(q_+off, stride); 33 | sk_[si*WARPS+warpi] = GTile(k_+off, stride); 34 | sv_[si*WARPS+warpi] = GTile(v_+off, stride); 35 | sa_[si*WARPS+warpi] = GTile(a_+off, stride); 36 | sb_[si*WARPS+warpi] = GTile(b_+off, stride); 37 | }; 38 | for (int t = 0; t < fw_stages-1 && t < T/K; t++) push(t), __commit_group(); 39 | 40 | FTile state[WARPS]; 41 | for (int i = 0; i < WARPS; i++) { 42 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 43 | RTile tmp; 44 | tmp = GTile(s0_+off, C); 45 | state[i] = tmp; 46 | } 47 | 48 | for (int t = 0; t < T/K; t++) { 49 | __syncthreads(); 50 | if (t+fw_stages-1 < T/K) 51 | push(t+fw_stages-1); 52 | __commit_group(); 53 | __wait_groups(); 54 | __syncthreads(); 55 | int si = t%fw_stages; 56 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi]; 57 | 58 | FTile w = (RTile)sw; 59 | apply_(w, [](float x) { return __expf(-__expf(x)); }); 60 | FTile fw = w; 61 | FTile non_incl_pref = cumprodv<0,0>(fw); 62 | FTile incl_pref = non_incl_pref * w; 63 | FTile inv_incl_pref = incl_pref; 64 | apply_(inv_incl_pref, [](float x) { return 1.f/x; }); 65 | 66 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref; 67 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref; 68 | FTile ab = sum_warp<1,WARPS>((float*)share, tril<1>(wa % bwi)); 69 | RTile ak = sum_warp<1,WARPS>((float*)share, tril<1>(wa % kwi)); 70 | 71 | RTile ab_inv; 72 | __syncthreads(); 73 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share); 74 | __syncthreads(); 75 | ab_inv = from_warp(ab_inv, 0, (float4*)share); 76 | 77 | RTile vt = sv.t(); 78 | FTile ab_ut = vt % 
ak; 79 | for (int i = 0; i < WARPS; i++) 80 | ab_ut += state[i] % from_warp(wa, i, (float4*)share); 81 | RTile ut = FTile(ab_ut % ab_inv); 82 | 83 | FTile y = sum_warp<1,WARPS>((float*)share, tril<0>(wq % kwi)) % vt; 84 | y += sum_warp<1,WARPS>((float*)share, tril<0>(wq % bwi)) % ut; 85 | for (int i = 0; i < WARPS; i++) 86 | y += from_warp(wq, i, (float4*)share) % state[i]; 87 | 88 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 89 | GTile(y_+off, stride) = RTile(y); 90 | 91 | RTile kwt = transpose(kwi*fw), bwt = transpose(bwi*fw); 92 | for (int i = 0; i < WARPS; i++) { 93 | if (s_ != NULL) { 94 | int off = bi*H*(T/K)*C*C + hi*(T/K)*C*C + t*C*C + warpi*16*C + i*16; 95 | GTile(s_+off, C) = (RTile)state[i]; 96 | } 97 | 98 | FTile fstate = state[i] * from_warp(fw, i, (float4*)share); 99 | fstate += vt % from_warp(kwt, i, (float4*)share); 100 | fstate += ut % from_warp(bwt, i, (float4*)share); 101 | state[i] = fstate; 102 | } 103 | } 104 | for (int i = 0; i < WARPS; i++) { 105 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 106 | GTile(sT_+off, C) = state[i]; 107 | } 108 | } 109 | 110 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, bf*s, bf*sT) { 111 | assert(T%16 == 0); 112 | constexpr int tmp_size1 = sizeof(float)*32*8*WARPS, tmp_size2 = sizeof(float)*16*16*2; 113 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*fw_stages*WARPS*6 + (tmp_size1 > tmp_size2 ? tmp_size1 : tmp_size2); 114 | static int reported = 0; 115 | if (!reported++) { 116 | #if defined VERBOSE 117 | printf("forward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem); 118 | #endif 119 | cudaFuncAttributes attr; 120 | cudaFuncGetAttributes(&attr, forward_kernel); 121 | int cur_mem = attr.maxDynamicSharedSizeBytes; 122 | if (shared_mem > cur_mem) { 123 | #if defined VERBOSE 124 | printf("Increasing forward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem); 125 | #endif 126 | assert(!cudaFuncSetAttribute(forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 127 | } 128 | } 129 | forward_kernel<<<dim3(H,B), threads, shared_mem>>>(T,H,w,q,k,v,a,b,s0,y,s,sT); 130 | } 131 | 132 | 133 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, F_ s_, F_ dsT_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_, bf* ds0_) { 134 | constexpr int C = _C_, K = 16; 135 | int bi = blockIdx.y, hi = blockIdx.x; 136 | extern __shared__ char smem_[]; 137 | char*smem = smem_; 138 | 139 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 140 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 141 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 142 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 143 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 144 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 145 | STile *sdy_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 146 | STile *state_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS*WARPS; 147 | char*share = (char*)smem; 148 | 149 | int stride = H*C; 150 | int warpi = threadIdx.x/32; 151 | 152 | auto push = [&](int t) { 153 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 154 | int si = t%bw_stages; 155 | sw_[si*WARPS+warpi] = GTile(w_+off, stride); 156 | sq_[si*WARPS+warpi] = GTile(q_+off, stride); 157 | sk_[si*WARPS+warpi] = GTile(k_+off, stride); 158 | sv_[si*WARPS+warpi] = GTile(v_+off, stride); 159 | sa_[si*WARPS+warpi] = GTile(a_+off, stride); 
160 | sb_[si*WARPS+warpi] = GTile(b_+off, stride); 161 | sdy_[si*WARPS+warpi] = GTile(dy_+off, stride); 162 | for (int i = 0; i < WARPS; i++) { 163 | int off2 = bi*H*(T/K)*C*C + hi*(T/K)*C*C + t*C*C + warpi*16*C + i*16; 164 | state_[si*WARPS*WARPS+warpi*WARPS+i] = GTile(s_+off2, C); 165 | } 166 | }; 167 | 168 | FTile dstate[WARPS]; 169 | for (int i = 0; i < WARPS; i++) { 170 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 171 | RTile tmp; 172 | tmp = GTile(dsT_+off, C); 173 | dstate[i] = tmp; 174 | __commit_group(); 175 | } 176 | 177 | for (int t = 0; t < bw_stages-1 && t < T/K; t++) push(T/K-1-t), __commit_group(); 178 | 179 | for (int t = T/K-1; t >= 0; t--) { 180 | __syncthreads(); 181 | if (t-bw_stages+1 >= 0) 182 | push(t-bw_stages+1); 183 | __commit_group(); 184 | __wait_groups(); 185 | __syncthreads(); 186 | int si = t%bw_stages; 187 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi], &sdy = sdy_[si*WARPS+warpi]; 188 | STile*state = state_+si*WARPS*WARPS; 189 | 190 | FTile w = (RTile)sw; 191 | apply_(w, [](float x) { return __expf(-__expf(x)); }); 192 | FTile fw = w; 193 | FTile non_incl_pref = cumprodv<0,0>(fw); 194 | FTile incl_pref = non_incl_pref * w; 195 | FTile inv_incl_pref = incl_pref; 196 | apply_(inv_incl_pref, [](float x) { return 1.f/x; }); 197 | 198 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref; 199 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref; 200 | FTile ab = sum_warp<1,WARPS>((float*)share, tril<1>(wa % bwi)); 201 | RTile ak = sum_warp<1,WARPS>((float*)share, tril<1>(wa % kwi)); 202 | 203 | RTile ab_inv; 204 | __syncthreads(); 205 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share); 206 | __syncthreads(); 207 | ab_inv = from_warp(ab_inv, 0, (float4*)share); 208 | 209 | RTile vt = sv.t(); 210 | FTile ab_ut = vt % ak; 211 | for (int i = 0; i < WARPS; i++) 212 | ab_ut += state[warpi*WARPS+i] % from_warp(wa, i, (float4*)share); 213 | RTile ut = FTile(ab_ut % ab_inv); 214 | 215 | FTile y = sum_warp<1,WARPS>((float*)share, tril<0>(wq % kwi)) % vt; 216 | y += sum_warp<1,WARPS>((float*)share, tril<0>(wq % bwi)) % ut; 217 | for (int i = 0; i < WARPS; i++) 218 | y += from_warp(wq, i, (float4*)share) % state[warpi*WARPS+i]; 219 | 220 | RTile qb = sum_warp<1,WARPS>((float*)share, tril<0>(wq % bwi)); 221 | RTile qk = sum_warp<1,WARPS>((float*)share, tril<0>(wq % kwi)); 222 | 223 | RTile dyt = sdy.t(); 224 | FTile dut = FTile(dyt % transpose(qb)); 225 | FTile dv = transpose(qk) % dyt; 226 | for (int i = 0; i < WARPS; i++) { 227 | RTile dstatei = dstate[i]; 228 | dut += dstatei % from_warp(bwi*fw, i, (float4*)share); 229 | dv += from_warp(kwi*fw, i, (float4*)share) % dstatei; 230 | } 231 | RTile dab_ut = FTile(dut % transpose(ab_inv)); 232 | dv += transpose(ak) % dab_ut; 233 | 234 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 235 | GTile(dv_+off, stride) = RTile(dv); 236 | 237 | FTile dab = sum_warp<1,WARPS>((float*)share, tril<1>(transpose(dab_ut) % transpose(ut))); 238 | FTile dak = sum_warp<1,WARPS>((float*)share, tril<1>(transpose(dab_ut) % transpose(vt))); 239 | FTile dab_u_state0; 240 | dab_u_state0.zero_(); 241 | for (int i = 0; i < WARPS; i++) 242 | dab_u_state0 += from_warp(transpose(dab_ut), i, (float4*)share) % state[i*WARPS+warpi].t(); 243 | 244 | FTile da = dab_u_state0; 245 | da += dab % transpose(bwi); 246 | da += dak % transpose(kwi); 247 | da = non_incl_pref * da; 248 | GTile(da_+off, 
stride) = RTile(da); 249 | 250 | FTile dqb = sum_warp<1,WARPS>((float*)share, tril<0>(transpose(dyt) % transpose(ut))); 251 | FTile dqk = sum_warp<1,WARPS>((float*)share, tril<0>(transpose(dyt) % transpose(vt))); 252 | FTile dy_state0; 253 | dy_state0.zero_(); 254 | for (int i = 0; i < WARPS; i++) 255 | dy_state0 += from_warp(transpose(dyt), i, (float4*)share) % state[i*WARPS+warpi].t(); 256 | 257 | FTile dq = dy_state0; 258 | dq += dqb % transpose(bwi); 259 | dq += dqk % transpose(kwi); 260 | dq = incl_pref * dq; 261 | GTile(dq_+off, stride) = RTile(dq); 262 | 263 | RTile wqt = transpose(wq), wat = transpose(wa); 264 | 265 | FTile u_dstate, v_dstate, dw; 266 | u_dstate.zero_(); 267 | v_dstate.zero_(); 268 | dw.zero_(); 269 | RTile ones; 270 | for (int i = 0; i < 4; i++) ones.data[i] = to_bf2({1.f,1.f}); 271 | for (int i = 0; i < WARPS; i++) { 272 | int tid = threadIdx.x%32; 273 | if (warpi == i) { 274 | for (int j = 0; j < WARPS; j++) { 275 | RTile ra = dstate[j]; 276 | ((float4*)share)[j*32+tid] = *((float4*)ra.data); 277 | } 278 | } 279 | RTile dstatei;// = dstate[i*WARPS+warpi]; 280 | __syncthreads(); 281 | *((float4*)dstatei.data) = ((float4*)share)[warpi*32+tid]; 282 | __syncthreads(); 283 | RTile dstatei_t = transpose(dstatei); 284 | v_dstate += from_warp(transpose(vt), i, (float4*)share) % dstatei_t; 285 | u_dstate += from_warp(transpose(ut), i, (float4*)share) % dstatei_t; 286 | dw += ones % ((RTile)state[i*WARPS+warpi].t()*dstatei_t); 287 | } 288 | 289 | FTile db = fw * u_dstate; 290 | db += transpose(dab) % wat; 291 | db += transpose(dqb) % wqt; 292 | db = inv_incl_pref * db; 293 | GTile(db_+off, stride) = RTile(db); 294 | 295 | FTile dk = fw * v_dstate; 296 | dk += transpose(dak) % wat; 297 | dk += transpose(dqk) % wqt; 298 | dk = inv_incl_pref * dk; 299 | GTile(dk_+off, stride) = RTile(dk); 300 | 301 | dw = fw * dw; 302 | dw += fast_dw<1>(dab,wa,bwi); 303 | dw += fast_dw<1>(dak,wa,kwi); 304 | dw += fast_dw<0>(dqb,wq,bwi); 305 | dw += fast_dw<0>(dqk,wq,kwi); 306 | FTile tmp; 307 | dw += cumsumv<0,0>(tmp = v_dstate*(fw*kwi)); 308 | dw += cumsumv<0,0>(tmp = u_dstate*(fw*bwi)); 309 | dw += cumsumv<0,1>(tmp = dab_u_state0*wa); 310 | dw += cumsumv<1,1>(tmp = dy_state0*wq); 311 | 312 | FTile dw_fac = (RTile)sw; 313 | apply_(dw_fac, [](float x) { return -__expf(x); }); 314 | dw = dw * dw_fac; 315 | GTile(dw_+off, stride) = RTile(dw); 316 | 317 | for (int i = 0; i < WARPS; i++) { 318 | FTile ndstate = dstate[i] * from_warp(fw, i, (float4*)share); 319 | ndstate += dyt % from_warp(wqt, i, (float4*)share); 320 | ndstate += dab_ut % from_warp(wat, i, (float4*)share); 321 | dstate[i] = ndstate; 322 | } 323 | } 324 | for (int i = 0; i < WARPS; i++) { 325 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 326 | GTile(ds0_+off, C) = dstate[i]; 327 | } 328 | } 329 | 330 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*da, bf*db, bf*ds0) { 331 | assert(T%16 == 0); 332 | constexpr int tmp_size1 = sizeof(float)*32*8*WARPS, tmp_size2 = sizeof(float)*16*16*2; 333 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*WARPS*bw_stages*(7+WARPS) + (tmp_size1 > tmp_size2 ? 
tmp_size1 : tmp_size2); 334 | static int reported = 0; 335 | if (!reported++) { 336 | #if defined VERBOSE 337 | printf("backward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem); 338 | #endif 339 | cudaFuncAttributes attr; 340 | cudaFuncGetAttributes(&attr, backward_kernel); 341 | int cur_mem = attr.maxDynamicSharedSizeBytes; 342 | if (shared_mem > cur_mem) { 343 | #if defined VERBOSE 344 | printf("Increasing backward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem); 345 | #endif 346 | assert(!cudaFuncSetAttribute(backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 347 | } 348 | } 349 | backward_kernel<<<dim3(H,B), threads, shared_mem>>>(T,H,w,q,k,v,a,b,dy,s,dsT,dw,dq,dk,dv,da,db,ds0); 350 | } 351 | 352 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda/chunked_cuda.py: -------------------------------------------------------------------------------- 1 | import os, torch as th 2 | from torch.utils.cpp_extension import load 3 | 4 | class RWKV7_chunked(th.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, q,w,k,v,a,b,s0): 7 | B,T,H,C = w.shape 8 | assert T%16 == 0 9 | if not th.compiler.is_compiling(): 10 | assert hasattr(th.ops.wind_chunked_cuda, 'forward'), 'Requires a loaded kernel from load_chunked_cuda(head_size)' 11 | assert all(i.dtype==th.bfloat16 for i in [w,q,k,v,a,b,s0]) 12 | assert all(i.is_contiguous() for i in [w,q,k,v,a,b,s0]) 13 | assert all(i.shape == w.shape for i in [w,q,k,v,a,b]) 14 | assert list(s0.shape) == [B,H,C,C] 15 | y = th.empty_like(v) 16 | sT = th.empty_like(s0) 17 | if any(i.requires_grad for i in [w,q,k,v,a,b,s0]): 18 | s = th.empty(B,H,T//16,C,C, dtype=th.bfloat16,device=w.device) 19 | else: 20 | s = None 21 | th.ops.wind_chunked_cuda.forward(w,q,k,v,a,b, s0,y,s,sT) 22 | ctx.save_for_backward(w,q,k,v,a,b,s) 23 | return y, sT 24 | @staticmethod 25 | def backward(ctx, dy, dsT): 26 | w,q,k,v,a,b,s = ctx.saved_tensors 27 | B,T,H,C = w.shape 28 | if not th.compiler.is_compiling(): 29 | assert all(i.dtype==th.bfloat16 for i in [dy,dsT]) 30 | assert all(i.is_contiguous() for i in [dy,dsT]) 31 | dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] 32 | th.ops.wind_chunked_cuda.backward(w,q,k,v,a,b, dy,s,dsT, dw,dq,dk,dv,da,db,ds0) 33 | return dq,dw,dk,dv,da,db,ds0 34 | 35 | def attn_chunked_cuda(r,w,k,v,a,b, s0 = None): 36 | B,T,H,C = w.shape 37 | if s0 is None: s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device) 38 | return RWKV7_chunked.apply(r,w,k,v,a,b, s0) 39 | 40 | def load_chunked_cuda(head_size): 41 | if hasattr(th.ops.wind_chunked_cuda, 'forward'): return 42 | CUDA_FLAGS = ["-res-usage", f'-D_C_={head_size}', "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 43 | if head_size == 256: CUDA_FLAGS.append('-maxrregcount=128') 44 | path = os.path.dirname(__file__) 45 | load(name="wind_chunked_cuda", sources=[os.path.join(path,'chunked_cuda.cu'), os.path.join(path,'chunked_cuda.cpp')], is_python_module=False, verbose=False, extra_cuda_cflags=CUDA_FLAGS) 46 | assert hasattr(th.ops.wind_chunked_cuda, 'forward') 47 | 48 | def attn_chunked_cuda_wrap(r,w,k,v,a,b, head_size): 49 | B,T,HC = w.shape 50 | C = head_size 51 | H = HC//C 52 | r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] 53 | s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device) 54 | return attn_chunked_cuda(r,w,k,v,a,b,s0)[0].view(B,T,HC) 55 | -------------------------------------------------------------------------------- 
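The chunked kernel is driven entirely from Python: load_chunked_cuda(head_size) JIT-compiles chunked_cuda.cu/.cpp and registers th.ops.wind_chunked_cuda, and attn_chunked_cuda(r,w,k,v,a,b,s0) wraps the autograd.Function above. The snippet below is an illustrative smoke test added here for clarity, not a file from the repository; it assumes an SM80+ CUDA GPU and imports straight from the module path, since the package __init__ re-exports are not shown in this listing.
```python
# Illustrative sketch (not part of the repo); assumes an SM80+ CUDA GPU and the direct module path below.
import torch as th
from wind_rwkv.rwkv7.chunked_cuda.chunked_cuda import load_chunked_cuda, attn_chunked_cuda

B, T, H, C = 2, 64, 4, 64                   # sequence length must be a multiple of 16
load_chunked_cuda(C)                        # JIT-compiles and registers th.ops.wind_chunked_cuda on first use

def rand_bf16():                            # inputs are contiguous bfloat16 tensors of shape [B,T,H,C]
    return th.randn(B, T, H, C, dtype=th.bfloat16, device='cuda', requires_grad=True)

r, w, k, v, a, b = (rand_bf16() for _ in range(6))
y, sT = attn_chunked_cuda(r, w, k, v, a, b) # s0 defaults to zeros; the kernel applies exp(-exp(w)) as the decay
y.sum().backward()                          # chunk states were saved, so this runs the backward kernel
print(y.shape, sT.shape)                    # [B,T,H,C] and [B,H,C,C]
```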
/wind_rwkv/rwkv7/chunked_cuda/tile.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include <cuda_bf16.h> 4 | #include <cuda_runtime.h> 5 | 6 | //TODO: static? inline? __align__(16)? 7 | 8 | using bf = __nv_bfloat16; 9 | using bf2 = __nv_bfloat162; 10 | using uint = unsigned int; 11 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); } 12 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); } 13 | __device__ inline float2 to_float2(const bf2 & u) { return __bfloat1622float2(u); } 14 | __device__ inline float2 to_float2(const float2 & u) { return u; } 15 | __device__ inline bf2 to_bf2(const float2 & u) { return __float22bfloat162_rn(u); } 16 | __device__ inline uint& as_uint(const bf2&x) { return *((uint*)(&x)); } 17 | __device__ inline uint __smem(const void*x) { return __cvta_generic_to_shared(x); } 18 | 19 | __device__ void __commit_group() { asm volatile("cp.async.commit_group;\n" ::); } 20 | __device__ void __wait_group() { asm volatile("cp.async.wait_all;\n" ::); } 21 | template<int N> __device__ void __wait_groups() { asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); } 22 | 23 | __device__ void __copy_wait() { __commit_group(); __wait_group(); } 24 | 25 | __device__ void operator*=(float2&a, const float2&b) { a.x *= b.x; a.y *= b.y; } 26 | __device__ void operator+=(float2&a, const float2&b) { a.x += b.x; a.y += b.y; } 27 | __device__ float2 operator+(const float2&a, const float2&b) { return {a.x+b.x,a.y+b.y}; } 28 | __device__ float2 operator*(const float2&a, const float2&b) { return {a.x*b.x,a.y*b.y}; } 29 | 30 | struct STile; 31 | struct RTile; 32 | struct FTile; 33 | 34 | struct GTile { 35 | bf*ga; 36 | int stride; 37 | __device__ GTile(bf*ga_, int stride_) : ga(ga_), stride(stride_) {} 38 | __device__ GTile& operator=(const RTile&); 39 | }; 40 | struct GFTile { 41 | float*ga; 42 | int stride; 43 | __device__ GFTile(float*ga_, int stride_) : ga(ga_), stride(stride_) {} 44 | __device__ GFTile& operator=(const FTile&); 45 | }; 46 | struct STileT { STile*st; }; 47 | 48 | struct __align__(16) STile { 49 | bf data[16*16]; 50 | __device__ STile() {} 51 | __device__ STile(const RTile&o) { *this=o; } 52 | __device__ STile& operator=(const GTile&); 53 | __device__ STile& operator=(const RTile&); 54 | __device__ STileT t() { return STileT{this}; } 55 | }; 56 | struct Product { const RTile*a, *b; }; 57 | struct ProductPlus { const RTile*a, *b; const FTile* c; }; 58 | struct RTile { 59 | bf2 data[4]; 60 | __device__ RTile() {} 61 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = to_bf2({0.f,0.f}); } 62 | __device__ RTile(const STile&o) { *this=o; } 63 | __device__ RTile(const STileT&o) { *this=o; } 64 | __device__ RTile(const FTile&o) { *this=o; } 65 | __device__ RTile& operator=(const STile&); 66 | __device__ RTile& operator=(const STileT&); 67 | __device__ RTile& operator=(const FTile&fa); 68 | __device__ RTile& operator=(const GTile&); 69 | }; 70 | struct FTile { 71 | union { 72 | float2 data[4]; 73 | float fdata[8]; 74 | }; 75 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = {0.f,0.f}; } 76 | __device__ FTile() {} 77 | __device__ FTile(const FTile&o) { for (int i = 0; i < 4; i++) data[i] = o.data[i]; } 78 | __device__ FTile(const RTile&r) { *this=r; } 79 | __device__ FTile(const Product&p) { *this=p; } 80 | __device__ FTile(const ProductPlus&p) { *this=p; } 81 | __device__ FTile& operator=(const Product&); 82 | __device__ FTile& 
operator=(const RTile&); 83 | __device__ FTile& operator=(const ProductPlus&); 84 | __device__ FTile& operator+=(const Product&); 85 | __device__ FTile& operator+=(const FTile&o) { for (int i = 0; i < 4; i++) data[i] += o.data[i]; return *this; } 86 | }; 87 | 88 | __device__ void print(STile t) { 89 | if (threadIdx.x == 0) { 90 | for (int i = 0; i < 16; i++) { 91 | for (int j = 0; j < 16; j++) { 92 | printf("%f ", to_float(t.data[i*16+j])); 93 | } 94 | printf("\n"); 95 | } 96 | printf("\n"); 97 | } 98 | } 99 | 100 | template 101 | __device__ void print(T t, int warpi = 0) { 102 | int tid = threadIdx.x - warpi*32; 103 | for (int i = 0; i < 16; i++) { 104 | for (int j = 0; j < 16; j += 2) { 105 | if (tid == i%8*4+j%8/2) { 106 | float2 xy = to_float2(t.data[i/8+j/8*2]); 107 | printf("%f %f ", xy.x, xy.y); 108 | //printf("T%d:{a%d,a%d} ", threadIdx.x, (i/8+j/8*2)*2, (i/8+j/8*2)*2+1); 109 | } 110 | __syncthreads(); 111 | } 112 | if (tid == 0) printf("\n"); 113 | __syncthreads(); 114 | } 115 | if (tid == 0) printf("\n"); 116 | __syncthreads(); 117 | } 118 | 119 | template 120 | __device__ void print8(T mat) { 121 | for (int i = 0; i < 8; i++) { 122 | for (int j = 0; j < 8; j += 2) { 123 | if (threadIdx.x == i%8*4+j%8/2) { 124 | float2 xy = to_float2(mat); 125 | printf("%f %f ", xy.x, xy.y); 126 | } 127 | __syncthreads(); 128 | } 129 | if (threadIdx.x == 0) printf("\n"); 130 | __syncthreads(); 131 | } 132 | if (threadIdx.x == 0) printf("\n"); 133 | __syncthreads(); 134 | } 135 | 136 | 137 | 138 | __device__ void load(STile&sa, bf*ga, int stride) { 139 | int i = threadIdx.x%32/2, j = threadIdx.x%2; 140 | asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" :: "r"(__smem(&sa.data[i*16+j*8])), "l"(ga+stride*i+j*8), "n"(16)); 141 | } 142 | 143 | __device__ void load(RTile&ra, const STile&sa) { 144 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2; 145 | asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" 146 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3])) 147 | : "r"(__smem(&sa.data[i*16+j*8+k*8*16]))); 148 | } 149 | __device__ void loadT(RTile&ra, const STile&sa) { 150 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2; 151 | asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" 152 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3])) 153 | : "r"(__smem(&sa.data[i*16+j*8*16+k*8]))); 154 | } 155 | 156 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1, const float2 &c0, const float2 &c1) { 157 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" 158 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) 159 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)), 160 | "r"(as_uint(b0)), "r"(as_uint(b1)), 161 | "f"(c0.x), "f"(c0.y), "f"(c1.x), "f"(c1.y)); 162 | } 163 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb, const FTile&rc) { // d = a*b^T + c 164 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], rc.data[0],rc.data[1]); 165 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], rc.data[2],rc.data[3]); 166 | } 167 | __device__ static inline void 
__m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1) { 168 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" 169 | : "+f"(d0.x), "+f"(d0.y), "+f"(d1.x), "+f"(d1.y) 170 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)), 171 | "r"(as_uint(b0)), "r"(as_uint(b1)), 172 | "f"(d0.x), "f"(d0.y), "f"(d1.x), "f"(d1.y)); 173 | } 174 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb) { // d += a*b^T 175 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2]); 176 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3]); 177 | } 178 | __device__ void mm(FTile&rd, const RTile&ra, const RTile&rb) { // d = a*b^T 179 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], {0.f,0.f}, {0.f,0.f}); 180 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], {0.f,0.f}, {0.f,0.f}); 181 | } 182 | 183 | __device__ void store(const FTile&ra, float*ga, int stride) { 184 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 185 | *((float2*)&ga[ i *stride+j ]) = ra.data[0]; 186 | *((float2*)&ga[(i+8)*stride+j ]) = ra.data[1]; 187 | *((float2*)&ga[ i *stride+j+8]) = ra.data[2]; 188 | *((float2*)&ga[(i+8)*stride+j+8]) = ra.data[3]; 189 | } 190 | 191 | __device__ void store(const RTile&ra, bf*ga, int stride) { 192 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 193 | *((bf2*)&ga[ i *stride+j ]) = ra.data[0]; 194 | *((bf2*)&ga[(i+8)*stride+j ]) = ra.data[1]; 195 | *((bf2*)&ga[ i *stride+j+8]) = ra.data[2]; 196 | *((bf2*)&ga[(i+8)*stride+j+8]) = ra.data[3]; 197 | } 198 | __device__ void load(RTile&ra, bf*ga, int stride) { 199 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 200 | ra.data[0] = *((bf2*)&ga[ i *stride+j ]); 201 | ra.data[1] = *((bf2*)&ga[(i+8)*stride+j ]); 202 | ra.data[2] = *((bf2*)&ga[ i *stride+j+8]); 203 | ra.data[3] = *((bf2*)&ga[(i+8)*stride+j+8]); 204 | } 205 | __device__ void store(const RTile&ra, STile&sa) { //TODO: reduce bank conflicts? 
206 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 207 | *((bf2*)&sa.data[ i *16+j ]) = ra.data[0]; 208 | *((bf2*)&sa.data[(i+8)*16+j ]) = ra.data[1]; 209 | *((bf2*)&sa.data[ i *16+j+8]) = ra.data[2]; 210 | *((bf2*)&sa.data[(i+8)*16+j+8]) = ra.data[3]; 211 | } 212 | 213 | __device__ void convert(RTile&ra, const FTile&fa) { 214 | ra.data[0] = to_bf2(fa.data[0]); 215 | ra.data[1] = to_bf2(fa.data[1]); 216 | ra.data[2] = to_bf2(fa.data[2]); 217 | ra.data[3] = to_bf2(fa.data[3]); 218 | } 219 | __device__ void convert(FTile&fa, const RTile&ra) { 220 | fa.data[0] = to_float2(ra.data[0]); 221 | fa.data[1] = to_float2(ra.data[1]); 222 | fa.data[2] = to_float2(ra.data[2]); 223 | fa.data[3] = to_float2(ra.data[3]); 224 | } 225 | 226 | __device__ STile& STile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; } 227 | __device__ RTile& RTile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; } 228 | __device__ RTile& RTile::operator=(const STile& sa) { load(*this, sa); return *this; } 229 | __device__ STile& STile::operator=(const RTile& ra) { store(ra, *this); return *this; } 230 | __device__ RTile& RTile::operator=(const STileT& sa) { loadT(*this, *sa.st); return *this; } 231 | __device__ Product operator%(const RTile&ra, const RTile&rb) { return Product{&ra,&rb}; } 232 | __device__ ProductPlus operator+(const Product&prod, const FTile&rc) { return ProductPlus{prod.a,prod.b,&rc}; } 233 | __device__ FTile& FTile::operator=(const Product& prod) { mm(*this, *prod.a, *prod.b); return *this; } 234 | __device__ FTile& FTile::operator=(const ProductPlus& prod) { mma(*this, *prod.a, *prod.b, *prod.c); return *this; } 235 | __device__ FTile& FTile::operator+=(const Product& prod) { mma(*this, *prod.a, *prod.b); return *this; } 236 | __device__ RTile& RTile::operator=(const FTile&fa) { convert(*this,fa); return *this; } 237 | __device__ FTile& FTile::operator=(const RTile&ra) { convert(*this,ra); return *this; } 238 | __device__ GTile& GTile::operator=(const RTile&ra) { store(ra, this->ga, this->stride); return *this; } 239 | __device__ GFTile& GFTile::operator=(const FTile&fa) { store(fa, this->ga, this->stride); return *this; } 240 | 241 | // Is this kind of cumsum better than multiplying with a triangular matrix of ones? 
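// As far as I can tell, the two scans below operate over the 16 rows of a tile, i.e. the time steps of a chunk:
// row t of the result holds the sum/product of the rows before it (after it when rev=1), including row t itself
// when inclusive=1, computed with warp shuffles rather than a matmul against a triangular matrix of ones.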
242 | template<int inclusive, int rev> 243 | __device__ FTile cumsumv(FTile&w) { 244 | int tid = threadIdx.x%32, t = tid/4; 245 | 246 | FTile ret; 247 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i]; 248 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{0.f,0.f}; 249 | 250 | for (int b = 0; b < 3; b++) { 251 | for (int i = 0; i < 8; i++) { 252 | float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<<b); 253 | if ((t>>b)%2 == !rev) ret.fdata[i] += other_w; 254 | w.fdata[i] += other_w; 255 | } 256 | } 257 | for (int i : {0,1,4,5}) { 258 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)]; 259 | ret.fdata[i^(2*!rev)] += w1; 260 | w0 += w1; 261 | w1 = w0; 262 | } 263 | return ret; 264 | } 265 | 266 | template<int inclusive, int rev> 267 | __device__ FTile cumprodv(FTile&w) { 268 | int tid = threadIdx.x%32, t = tid/4; 269 | 270 | FTile ret; 271 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i]; 272 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{1.f,1.f}; 273 | 274 | for (int b = 0; b < 3; b++) { 275 | for (int i = 0; i < 8; i++) { 276 | float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<<b); 277 | if ((t>>b)%2 == !rev) ret.fdata[i] *= other_w; 278 | w.fdata[i] *= other_w; 279 | } 280 | } 281 | for (int i : {0,1,4,5}) { 282 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)]; 283 | ret.fdata[i^(2*!rev)] *= w1; 284 | w0 *= w1; 285 | w1 = w0; 286 | } 287 | return ret; 288 | } 289 | 290 | __device__ FTile operator*(const FTile&a, const FTile&b) { 291 | FTile ret; 292 | for (int i = 0; i < 8; i++) ret.fdata[i] = a.fdata[i]*b.fdata[i]; 293 | return ret; 294 | } 295 | 296 | template<int triangular, int WARPS> // Lower triangular 297 | __device__ FTile sum_warp(float*share, const FTile&f) { // Requires share of size sizeof(float)*32*8*WARPS 298 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 299 | __syncthreads(); 300 | for (int i = 0; i < 8; i++) { 301 | if (triangular && (i == 4 || i == 5)) continue; 302 | share[32*(warpi*8+i)+tid] = f.fdata[i]; 303 | } 304 | __syncthreads(); 305 | #pragma unroll 306 | for (int k = warpi; k < 8; k += WARPS) { 307 | if (triangular && (k == 4 || k == 5)) continue; 308 | float sum = 0; 309 | for (int i = 1; i < WARPS; i++) { 310 | sum += share[32*(i*8+k)+tid]; 311 | } 312 | share[32*k+tid] += sum; 313 | } 314 | __syncthreads(); 315 | FTile ret; 316 | for (int i = 0; i < 8; i++) { 317 | if (triangular && (i == 4 || i == 5)) 318 | ret.fdata[i] = 0; 319 | else 320 | ret.fdata[i] = share[32*i+tid]; 321 | } 322 | __syncthreads(); 323 | return ret; 324 | } 325 | 326 | __device__ RTile from_warp(const RTile&ra, int src, float4*share) { 327 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 328 | if (warpi == src) share[tid] = *((float4*)ra.data); 329 | __syncthreads(); 330 | RTile ret; 331 | *((float4*)ret.data) = share[tid]; 332 | __syncthreads(); 333 | return ret; 334 | } 335 | 336 | // inv(I-f) where f is strictly lower triangular 337 | __device__ FTile tri_minv(const FTile&f, float*share) { 338 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2; 339 | float inv[16] = {}; 340 | for (int k = 0; k < 8; k++) { 341 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 342 | share[i*16+j] = f.fdata[k]; 343 | } 344 | int tid = threadIdx.x%32; 345 | inv[tid%16] = 1; 346 | for (int i = 1; i < 16; i++) { 347 | for (int j = 0; j < i; j++) { 348 | float fac = share[i*16+j]; 349 | inv[i] += fac*inv[j]; 350 | } 351 | } 352 | for (int i = 0; i < 16; i++) 353 | share[tid*16+i] = inv[i]; 354 | FTile ret; 355 | for (int k = 0; k < 8; k++) { 356 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 357 | ret.fdata[k] = share[j*16+i]; 358 | } 
359 | return ret; 360 | } 361 | 362 | template<int strict> 363 | __device__ FTile tril(const FTile&f) { 364 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2; 365 | FTile ret; 366 | for (int k = 0; k < 8; k++) { 367 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 368 | if (strict) ret.fdata[k] = (i>j ? f.fdata[k] : 0.f); 369 | else ret.fdata[k] = (i>=j ? f.fdata[k] : 0.f); 370 | } 371 | return ret; 372 | } 373 | 374 | template<class F> 375 | __device__ void apply_(FTile&tile, F f) { 376 | for (int i = 0; i < 8; i++) tile.fdata[i] = f(tile.fdata[i]); 377 | } 378 | 379 | __device__ bf2 transpose(bf2 a) { 380 | bf2 ret; 381 | asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(as_uint(ret)) : "r"(as_uint(a))); 382 | return ret; 383 | } 384 | 385 | __device__ RTile transpose(const RTile&ra) { 386 | RTile rb; 387 | rb.data[0] = transpose(ra.data[0]); 388 | rb.data[1] = transpose(ra.data[2]); 389 | rb.data[2] = transpose(ra.data[1]); 390 | rb.data[3] = transpose(ra.data[3]); 391 | return rb; 392 | } 393 | 394 | template<int strict> 395 | __device__ FTile slow_dw(const RTile&A, const RTile&q, const RTile&k, STile*share) { 396 | share[0] = A; 397 | share[1] = q; 398 | share[2] = k; 399 | __syncthreads(); 400 | if (threadIdx.x%32 == 0) { 401 | for (int k = 0; k < 16; k++) { 402 | for (int j = 0; j < 16; j++) { 403 | float sum = 0; 404 | for (int l = 0; l < k; l++) { 405 | for (int r = k+strict; r < 16; r++) { 406 | sum += to_float(share[0].data[r*16+l]) * to_float(share[1].data[r*16+j]) * to_float(share[2].data[l*16+j]); 407 | } 408 | } 409 | share[3].data[k*16+j] = to_bf(sum); 410 | } 411 | } 412 | } 413 | __syncthreads(); 414 | RTile ret = (RTile)share[3]; 415 | __syncthreads(); 416 | return ret; 417 | } 418 | 419 | 420 | __device__ static inline void __m16n8k8(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &b0) { 421 | asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" 422 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(b0)), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); 423 | } 424 | 425 | template<int strict> 426 | __device__ RTile fast_dw(const RTile&A, const RTile&q, const RTile&k) { 427 | float2 qkA8[4]; 428 | RTile kt = transpose(k), qt = transpose(q); 429 | __m16n8k8(qkA8[0],qkA8[1], qt.data[2], qt.data[3], transpose(A.data[1])); 430 | __m16n8k8(qkA8[2],qkA8[3], kt.data[0], kt.data[1], A.data[1]); 431 | for (int x : {0,1}) { 432 | qkA8[x] *= to_float2(kt.data[x]); 433 | qkA8[2+x] *= to_float2(qt.data[2+x]); 434 | } 435 | 436 | int tid = threadIdx.x%32, j = threadIdx.x%4; 437 | // Non-inclusive cumsum 438 | for (int i = 0; i < 4; i++) { 439 | float sum = qkA8[i].x+qkA8[i].y; 440 | float psum = __shfl_xor_sync(0xffffffff, sum, 1); 441 | float ppsum = __shfl_xor_sync(0xffffffff, sum+psum, 2); 442 | if (i < 2) { 443 | psum = ppsum*(j>=2)+psum*(j%2); 444 | qkA8[i].y = psum + qkA8[i].x; 445 | qkA8[i].x = psum; 446 | } else { 447 | psum = ppsum*(j<2)+psum*(j%2==0); 448 | qkA8[i].x = psum + qkA8[i].y; 449 | qkA8[i].y = psum; 450 | } 451 | } 452 | 453 | float2 qkA4[4]; 454 | { 455 | RTile k_q; 456 | for (int i = 0; i < 8; i++) ((bf*)k_q.data)[i] = (j<2?((bf*)kt.data)[i]:((bf*)qt.data)[i]); 457 | float lower_left = (tid >= 16 && j < 2); 458 | bf2 A0 = to_bf2(to_float2(A.data[0])*float2{lower_left,lower_left}); 459 | bf2 A3 = to_bf2(to_float2(A.data[3])*float2{lower_left,lower_left}); 460 | __m16n8k8(qkA4[0],qkA4[1], k_q.data[0], k_q.data[1], A0 + transpose(A0)); 461 | __m16n8k8(qkA4[2],qkA4[3], 
k_q.data[2], k_q.data[3], A3 + transpose(A3)); 462 | for (int i = 0; i < 4; i++) 463 | qkA4[i] *= to_float2(k_q.data[i]); 464 | } 465 | 466 | // Non-inclusive cumsum 467 | for (int i = 0; i < 4; i++) { 468 | float sum = qkA4[i].x+qkA4[i].y; 469 | float psum = __shfl_xor_sync(0xffffffff, sum, 1); 470 | psum *= (j%2 == j<2); 471 | qkA4[i] = {psum + qkA4[i].y*(j>=2), psum + qkA4[i].x*(j<2)}; 472 | } 473 | 474 | FTile ret; 475 | ret.data[0] = qkA8[0]+qkA4[0]; 476 | ret.data[1] = qkA8[1]+qkA4[1]; 477 | ret.data[2] = qkA8[2]+qkA4[2]; 478 | ret.data[3] = qkA8[3]+qkA4[3]; 479 | 480 | for (int ci : {0,1}) { 481 | for (int ti : {0,1}) { 482 | int Ai = ti*3, di = ti*2+ci; 483 | bf A8x = __shfl_sync(0xffffffff, A.data[Ai].x, 8+(j>=2)*18); 484 | bf A12x = __shfl_sync(0xffffffff, A.data[Ai].x, 12+(j>=2)*18); 485 | bf A12y = __shfl_sync(0xffffffff, A.data[Ai].y, 12+(j>=2)*18); 486 | bf2 nq = __shfl_xor_sync(0xffffffff, qt.data[di], 1); 487 | bf2 pk = __shfl_xor_sync(0xffffffff, kt.data[di], 1); 488 | 489 | bool even = (j%2==0); 490 | float ax = to_float(even?A8x:A12x), ay = to_float(even?A12x:A12y), c = to_float(even?kt.data[di].x:qt.data[di].y); 491 | float2 b = to_float2(j%2?pk:nq); 492 | float d = (ax*b.x+ay*b.y)*c; 493 | ret.data[di].y += even*d; 494 | ret.data[di].x +=!even*d; 495 | } 496 | } 497 | 498 | if (!strict) { 499 | // Do we really need tril<1>()? 500 | ret += (kt % tril<1>(A)) * qt; 501 | } 502 | return transpose(ret); 503 | } 504 | 505 | __device__ void debug_set(RTile&ra, int i, int j, float v) { 506 | if (threadIdx.x%32 == i%8*4+j%8/2) ((bf*)ra.data)[i/8*2+j/8*4+j%2] = to_bf(v); 507 | } 508 | 509 | template 510 | __device__ float2 sumh(const FTile&f, float*share) { // Requires shared of size sizeof(float)*16*WARPS 511 | float2 warpsum = {f.fdata[0]+f.fdata[1]+f.fdata[4]+f.fdata[5], 512 | f.fdata[2]+f.fdata[3]+f.fdata[6]+f.fdata[7]}; 513 | warpsum.x += __shfl_xor_sync(0xffffffff, warpsum.x, 1); 514 | warpsum.y += __shfl_xor_sync(0xffffffff, warpsum.y, 1); 515 | warpsum.x += __shfl_xor_sync(0xffffffff, warpsum.x, 2); 516 | warpsum.y += __shfl_xor_sync(0xffffffff, warpsum.y, 2); 517 | 518 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 519 | __syncthreads(); 520 | if (tid%4 < 2) 521 | share[warpi*16+tid/4+tid%4*8] = (tid%4?warpsum.y:warpsum.x); 522 | __syncthreads(); 523 | if (warpi == 0 && tid < 16) { 524 | float sum = 0; 525 | for (int i = 1; i < WARPS; i++) { 526 | sum += share[i*16+tid]; 527 | } 528 | share[tid] += sum; 529 | } 530 | __syncthreads(); 531 | float2 ret = {share[tid/4], share[tid/4+8]}; 532 | __syncthreads(); 533 | return ret; 534 | } 535 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda_varlen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johanwind/wind_rwkv/acc0488e8c86ee5e7f3184ae4a9c1d97e1e14fff/wind_rwkv/rwkv7/chunked_cuda_varlen/__init__.py -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda_varlen/chunked_cuda_varlen.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include 4 | 5 | struct __nv_bfloat16; 6 | using bf = __nv_bfloat16; 7 | using torch::Tensor; 8 | 9 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, bf*s, bf*sT, long long*cu_seqlens); 10 | 11 | void forward(Tensor &w, Tensor &q, Tensor &k, Tensor &v, Tensor &a, Tensor 
&b, Tensor &s0, Tensor &y, c10::optional<Tensor> s, Tensor &sT, Tensor &cu_seqlens) { 12 | int B = cu_seqlens.sizes()[0]-1, T = w.sizes()[0], H = w.sizes()[1]; 13 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)s0.data_ptr(), (bf*)y.data_ptr(), s.has_value() ? (bf*)s.value().data_ptr() : NULL, (bf*)sT.data_ptr(), (long long*)cu_seqlens.data_ptr()); 14 | } 15 | 16 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*da, bf*db, bf*ds0, long long*cu_seqlens); 17 | 18 | void backward(Tensor &w, Tensor &q, Tensor &k, Tensor &v, Tensor &a, Tensor &b, Tensor &dy, 19 | Tensor &s, Tensor &dsT, Tensor &dw, Tensor &dq, Tensor &dk, Tensor &dv, Tensor &da, Tensor &db, Tensor &ds0, Tensor &cu_seqlens) { 20 | int B = cu_seqlens.sizes()[0]-1, T = w.sizes()[0], H = w.sizes()[1]; 21 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), 22 | (bf*)s.data_ptr(), (bf*)dsT.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr(), (bf*)ds0.data_ptr(), (long long*)cu_seqlens.data_ptr()); 23 | } 24 | 25 | TORCH_LIBRARY(wind_chunked_cuda_varlen, m) { 26 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor s0, Tensor(a!) y, Tensor? s, Tensor(c!) sT, Tensor cu_seqlens) -> ()"); 27 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor dy, Tensor s, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) da, Tensor(f!) db, Tensor(g!) ds0, Tensor cu_seqlens) -> ()"); 28 | } 29 | 30 | TORCH_LIBRARY_IMPL(wind_chunked_cuda_varlen, CUDA, m) { 31 | m.impl("forward", &forward); 32 | m.impl("backward", &backward); 33 | } 34 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda_varlen/chunked_cuda_varlen.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include "tile.cuh" 4 | #include <assert.h> 5 | typedef bf * __restrict__ F_; 6 | typedef float * __restrict__ F32_; 7 | 8 | constexpr int WARPS = _C_/16; 9 | constexpr int fw_stages = 1, bw_stages = 1; 10 | 11 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ s0_, bf* y_, bf* s_, bf* sT_, long long*cu_seqlens) { 12 | constexpr int C = _C_, K = 16; 13 | int bi = blockIdx.y, hi = blockIdx.x; 14 | int t_off = cu_seqlens[bi]/K, nT = cu_seqlens[bi+1]/K - t_off; 15 | extern __shared__ char smem_[]; 16 | char*smem = smem_; 17 | 18 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 19 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 20 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 21 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 22 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 23 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 24 | char*share = (char*)smem; 25 | 26 | int stride = H*C; 27 | int warpi = threadIdx.x/32; 28 | 29 | auto push = [&](int t) { 30 | int off = (t_off + t)*K*H*C + hi*C + warpi*16; 31 | int si = t%fw_stages; 32 | sw_[si*WARPS+warpi] = GTile(w_+off, stride); 33 | sq_[si*WARPS+warpi] = GTile(q_+off, stride); 34 | sk_[si*WARPS+warpi] = GTile(k_+off, 
stride); 35 | sv_[si*WARPS+warpi] = GTile(v_+off, stride); 36 | sa_[si*WARPS+warpi] = GTile(a_+off, stride); 37 | sb_[si*WARPS+warpi] = GTile(b_+off, stride); 38 | }; 39 | for (int t = 0; t < fw_stages-1 && t < nT; t++) push(t), __commit_group(); 40 | 41 | FTile state[WARPS]; 42 | for (int i = 0; i < WARPS; i++) { 43 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 44 | RTile tmp; 45 | tmp = GTile(s0_+off, C); 46 | state[i] = tmp; 47 | } 48 | 49 | for (int t = 0; t < nT; t++) { 50 | __syncthreads(); 51 | if (t+fw_stages-1 < nT) 52 | push(t+fw_stages-1); 53 | __commit_group(); 54 | __wait_groups(); 55 | __syncthreads(); 56 | int si = t%fw_stages; 57 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi]; 58 | 59 | FTile w = (RTile)sw; 60 | apply_(w, [](float x) { return __expf(-__expf(x)); }); 61 | FTile fw = w; 62 | FTile non_incl_pref = cumprodv<0,0>(fw); 63 | FTile incl_pref = non_incl_pref * w; 64 | FTile inv_incl_pref = incl_pref; 65 | apply_(inv_incl_pref, [](float x) { return 1.f/x; }); 66 | 67 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref; 68 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref; 69 | FTile ab = sum_warp<1,WARPS>((float*)share, tril<1>(wa % bwi)); 70 | RTile ak = sum_warp<1,WARPS>((float*)share, tril<1>(wa % kwi)); 71 | 72 | RTile ab_inv; 73 | __syncthreads(); 74 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share); 75 | __syncthreads(); 76 | ab_inv = from_warp(ab_inv, 0, (float4*)share); 77 | 78 | RTile vt = sv.t(); 79 | FTile ab_ut = vt % ak; 80 | for (int i = 0; i < WARPS; i++) 81 | ab_ut += state[i] % from_warp(wa, i, (float4*)share); 82 | RTile ut = FTile(ab_ut % ab_inv); 83 | 84 | FTile y = sum_warp<1,WARPS>((float*)share, tril<0>(wq % kwi)) % vt; 85 | y += sum_warp<1,WARPS>((float*)share, tril<0>(wq % bwi)) % ut; 86 | for (int i = 0; i < WARPS; i++) 87 | y += from_warp(wq, i, (float4*)share) % state[i]; 88 | 89 | int off = (t_off+t)*K*H*C + hi*C + warpi*16; 90 | GTile(y_+off, stride) = RTile(y); 91 | 92 | RTile kwt = transpose(kwi*fw), bwt = transpose(bwi*fw); 93 | for (int i = 0; i < WARPS; i++) { 94 | if (s_ != NULL) { 95 | int off = hi*(T/K)*C*C + (t_off+t)*C*C + warpi*16*C + i*16; 96 | GTile(s_+off, C) = (RTile)state[i]; 97 | } 98 | 99 | FTile fstate = state[i] * from_warp(fw, i, (float4*)share); 100 | fstate += vt % from_warp(kwt, i, (float4*)share); 101 | fstate += ut % from_warp(bwt, i, (float4*)share); 102 | state[i] = fstate; 103 | } 104 | } 105 | for (int i = 0; i < WARPS; i++) { 106 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 107 | GTile(sT_+off, C) = state[i]; 108 | } 109 | } 110 | 111 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*s0, bf*y, bf*s, bf*sT, long long*cu_seqlens) { 112 | assert(T%16 == 0); 113 | constexpr int tmp_size1 = sizeof(float)*32*8*WARPS, tmp_size2 = sizeof(float)*16*16*2; 114 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*fw_stages*WARPS*6 + (tmp_size1 > tmp_size2 ? 
tmp_size1 : tmp_size2); 115 | static int reported = 0; 116 | if (!reported++) { 117 | #if defined VERBOSE 118 | printf("forward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem); 119 | #endif 120 | cudaFuncAttributes attr; 121 | cudaFuncGetAttributes(&attr, forward_kernel); 122 | int cur_mem = attr.maxDynamicSharedSizeBytes; 123 | if (shared_mem > cur_mem) { 124 | #if defined VERBOSE 125 | printf("Increasing forward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem); 126 | #endif 127 | assert(!cudaFuncSetAttribute(forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 128 | } 129 | } 130 | forward_kernel<<<dim3(H,B), threads, shared_mem>>>(T,H,w,q,k,v,a,b,s0,y,s,sT,cu_seqlens); 131 | } 132 | 133 | 134 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, F_ s_, F_ dsT_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_, bf* ds0_, long long*cu_seqlens) { 135 | constexpr int C = _C_, K = 16; 136 | int bi = blockIdx.y, hi = blockIdx.x; 137 | int t_off = cu_seqlens[bi]/K, nT = cu_seqlens[bi+1]/K - t_off; 138 | extern __shared__ char smem_[]; 139 | char*smem = smem_; 140 | 141 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 142 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 143 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 144 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 145 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 146 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 147 | STile *sdy_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 148 | STile *state_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS*WARPS; 149 | char*share = (char*)smem; 150 | 151 | int stride = H*C; 152 | int warpi = threadIdx.x/32; 153 | 154 | auto push = [&](int t) { 155 | int off = (t_off+t)*K*H*C + hi*C + warpi*16; 156 | int si = t%bw_stages; 157 | sw_[si*WARPS+warpi] = GTile(w_+off, stride); 158 | sq_[si*WARPS+warpi] = GTile(q_+off, stride); 159 | sk_[si*WARPS+warpi] = GTile(k_+off, stride); 160 | sv_[si*WARPS+warpi] = GTile(v_+off, stride); 161 | sa_[si*WARPS+warpi] = GTile(a_+off, stride); 162 | sb_[si*WARPS+warpi] = GTile(b_+off, stride); 163 | sdy_[si*WARPS+warpi] = GTile(dy_+off, stride); 164 | for (int i = 0; i < WARPS; i++) { 165 | int off2 = hi*(T/K)*C*C + (t_off+t)*C*C + warpi*16*C + i*16; 166 | state_[si*WARPS*WARPS+warpi*WARPS+i] = GTile(s_+off2, C); 167 | } 168 | }; 169 | 170 | FTile dstate[WARPS]; 171 | for (int i = 0; i < WARPS; i++) { 172 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 173 | RTile tmp; 174 | tmp = GTile(dsT_+off, C); 175 | dstate[i] = tmp; 176 | __commit_group(); 177 | } 178 | 179 | for (int t = 0; t < bw_stages-1 && t < nT; t++) push(nT-1-t), __commit_group(); 180 | 181 | for (int t = nT-1; t >= 0; t--) { 182 | __syncthreads(); 183 | if (t-bw_stages+1 >= 0) 184 | push(t-bw_stages+1); 185 | __commit_group(); 186 | __wait_groups<bw_stages-1>(); 187 | __syncthreads(); 188 | int si = t%bw_stages; 189 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi], &sdy = sdy_[si*WARPS+warpi]; 190 | STile*state = state_+si*WARPS*WARPS; 191 | 192 | FTile w = (RTile)sw; 193 | apply_(w, [](float x) { return __expf(-__expf(x)); }); 194 | FTile fw = w; 195 | FTile non_incl_pref = cumprodv<0,0>(fw); 196 | FTile incl_pref = non_incl_pref * w; 197 | FTile inv_incl_pref = incl_pref; 198 | 
apply_(inv_incl_pref, [](float x) { return 1.f/x; }); 199 | 200 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref; 201 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref; 202 | FTile ab = sum_warp<1,WARPS>((float*)share, tril<1>(wa % bwi)); 203 | RTile ak = sum_warp<1,WARPS>((float*)share, tril<1>(wa % kwi)); 204 | 205 | RTile ab_inv; 206 | __syncthreads(); 207 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share); 208 | __syncthreads(); 209 | ab_inv = from_warp(ab_inv, 0, (float4*)share); 210 | 211 | RTile vt = sv.t(); 212 | FTile ab_ut = vt % ak; 213 | for (int i = 0; i < WARPS; i++) 214 | ab_ut += state[warpi*WARPS+i] % from_warp(wa, i, (float4*)share); 215 | RTile ut = FTile(ab_ut % ab_inv); 216 | 217 | FTile y = sum_warp<1,WARPS>((float*)share, tril<0>(wq % kwi)) % vt; 218 | y += sum_warp<1,WARPS>((float*)share, tril<0>(wq % bwi)) % ut; 219 | for (int i = 0; i < WARPS; i++) 220 | y += from_warp(wq, i, (float4*)share) % state[warpi*WARPS+i]; 221 | 222 | RTile qb = sum_warp<1,WARPS>((float*)share, tril<0>(wq % bwi)); 223 | RTile qk = sum_warp<1,WARPS>((float*)share, tril<0>(wq % kwi)); 224 | 225 | RTile dyt = sdy.t(); 226 | FTile dut = FTile(dyt % transpose(qb)); 227 | FTile dv = transpose(qk) % dyt; 228 | for (int i = 0; i < WARPS; i++) { 229 | RTile dstatei = dstate[i]; 230 | dut += dstatei % from_warp(bwi*fw, i, (float4*)share); 231 | dv += from_warp(kwi*fw, i, (float4*)share) % dstatei; 232 | } 233 | RTile dab_ut = FTile(dut % transpose(ab_inv)); 234 | dv += transpose(ak) % dab_ut; 235 | 236 | int off = (t_off+t)*K*H*C + hi*C + warpi*16; 237 | GTile(dv_+off, stride) = RTile(dv); 238 | 239 | FTile dab = sum_warp<1,WARPS>((float*)share, tril<1>(transpose(dab_ut) % transpose(ut))); 240 | FTile dak = sum_warp<1,WARPS>((float*)share, tril<1>(transpose(dab_ut) % transpose(vt))); 241 | FTile dab_u_state0; 242 | dab_u_state0.zero_(); 243 | for (int i = 0; i < WARPS; i++) 244 | dab_u_state0 += from_warp(transpose(dab_ut), i, (float4*)share) % state[i*WARPS+warpi].t(); 245 | 246 | FTile da = dab_u_state0; 247 | da += dab % transpose(bwi); 248 | da += dak % transpose(kwi); 249 | da = non_incl_pref * da; 250 | GTile(da_+off, stride) = RTile(da); 251 | 252 | FTile dqb = sum_warp<1,WARPS>((float*)share, tril<0>(transpose(dyt) % transpose(ut))); 253 | FTile dqk = sum_warp<1,WARPS>((float*)share, tril<0>(transpose(dyt) % transpose(vt))); 254 | FTile dy_state0; 255 | dy_state0.zero_(); 256 | for (int i = 0; i < WARPS; i++) 257 | dy_state0 += from_warp(transpose(dyt), i, (float4*)share) % state[i*WARPS+warpi].t(); 258 | 259 | FTile dq = dy_state0; 260 | dq += dqb % transpose(bwi); 261 | dq += dqk % transpose(kwi); 262 | dq = incl_pref * dq; 263 | GTile(dq_+off, stride) = RTile(dq); 264 | 265 | RTile wqt = transpose(wq), wat = transpose(wa); 266 | 267 | FTile u_dstate, v_dstate, dw; 268 | u_dstate.zero_(); 269 | v_dstate.zero_(); 270 | dw.zero_(); 271 | RTile ones; 272 | for (int i = 0; i < 4; i++) ones.data[i] = to_bf2({1.f,1.f}); 273 | for (int i = 0; i < WARPS; i++) { 274 | int tid = threadIdx.x%32; 275 | if (warpi == i) { 276 | for (int j = 0; j < WARPS; j++) { 277 | RTile ra = dstate[j]; 278 | ((float4*)share)[j*32+tid] = *((float4*)ra.data); 279 | } 280 | } 281 | RTile dstatei;// = dstate[i*WARPS+warpi]; 282 | __syncthreads(); 283 | *((float4*)dstatei.data) = ((float4*)share)[warpi*32+tid]; 284 | __syncthreads(); 285 | RTile dstatei_t = transpose(dstatei); 286 | v_dstate += from_warp(transpose(vt), i, (float4*)share) % dstatei_t; 287 | u_dstate += 
from_warp(transpose(ut), i, (float4*)share) % dstatei_t; 288 | dw += ones % ((RTile)state[i*WARPS+warpi].t()*dstatei_t); 289 | } 290 | 291 | FTile db = fw * u_dstate; 292 | db += transpose(dab) % wat; 293 | db += transpose(dqb) % wqt; 294 | db = inv_incl_pref * db; 295 | GTile(db_+off, stride) = RTile(db); 296 | 297 | FTile dk = fw * v_dstate; 298 | dk += transpose(dak) % wat; 299 | dk += transpose(dqk) % wqt; 300 | dk = inv_incl_pref * dk; 301 | GTile(dk_+off, stride) = RTile(dk); 302 | 303 | dw = fw * dw; 304 | dw += fast_dw<1>(dab,wa,bwi); 305 | dw += fast_dw<1>(dak,wa,kwi); 306 | dw += fast_dw<0>(dqb,wq,bwi); 307 | dw += fast_dw<0>(dqk,wq,kwi); 308 | FTile tmp; 309 | dw += cumsumv<0,0>(tmp = v_dstate*(fw*kwi)); 310 | dw += cumsumv<0,0>(tmp = u_dstate*(fw*bwi)); 311 | dw += cumsumv<0,1>(tmp = dab_u_state0*wa); 312 | dw += cumsumv<1,1>(tmp = dy_state0*wq); 313 | 314 | FTile dw_fac = (RTile)sw; 315 | apply_(dw_fac, [](float x) { return -__expf(x); }); 316 | dw = dw * dw_fac; 317 | GTile(dw_+off, stride) = RTile(dw); 318 | 319 | for (int i = 0; i < WARPS; i++) { 320 | FTile ndstate = dstate[i] * from_warp(fw, i, (float4*)share); 321 | ndstate += dyt % from_warp(wqt, i, (float4*)share); 322 | ndstate += dab_ut % from_warp(wat, i, (float4*)share); 323 | dstate[i] = ndstate; 324 | } 325 | } 326 | for (int i = 0; i < WARPS; i++) { 327 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 328 | GTile(ds0_+off, C) = dstate[i]; 329 | } 330 | } 331 | 332 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*da, bf*db, bf*ds0, long long*cu_seqlens) { 333 | assert(T%16 == 0); 334 | constexpr int tmp_size1 = sizeof(float)*32*8*WARPS, tmp_size2 = sizeof(float)*16*16*2; 335 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*WARPS*bw_stages*(7+WARPS) + (tmp_size1 > tmp_size2 ? 
tmp_size1 : tmp_size2); 336 | static int reported = 0; 337 | if (!reported++) { 338 | #if defined VERBOSE 339 | printf("backward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem); 340 | #endif 341 | cudaFuncAttributes attr; 342 | cudaFuncGetAttributes(&attr, backward_kernel); 343 | int cur_mem = attr.maxDynamicSharedSizeBytes; 344 | if (shared_mem > cur_mem) { 345 | #if defined VERBOSE 346 | printf("Increasing backward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem); 347 | #endif 348 | assert(!cudaFuncSetAttribute(backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 349 | } 350 | } 351 | backward_kernel<<<dim3(H,B), threads, shared_mem>>>(T,H,w,q,k,v,a,b,dy,s,dsT,dw,dq,dk,dv,da,db,ds0,cu_seqlens); 352 | } 353 | 354 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda_varlen/chunked_cuda_varlen.py: -------------------------------------------------------------------------------- 1 | import os, torch as th 2 | from torch.utils.cpp_extension import load 3 | 4 | class RWKV7_chunked_varlen(th.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, q,w,k,v,a,b,s0, cu_seqlens): 7 | T,H,C = w.shape 8 | assert T%16 == 0 9 | if not th.compiler.is_compiling(): 10 | assert hasattr(th.ops.wind_chunked_cuda_varlen, 'forward'), 'Requires a loaded kernel from load_chunked_cuda_varlen(head_size)' 11 | assert all(i.dtype==th.bfloat16 for i in [w,q,k,v,a,b,s0]) 12 | assert all(i.is_contiguous() for i in [w,q,k,v,a,b,s0]) 13 | assert all(i.shape == w.shape for i in [w,q,k,v,a,b]) 14 | assert list(s0.shape) == [len(cu_seqlens)-1,H,C,C] 15 | assert cu_seqlens.dtype == th.long 16 | assert cu_seqlens.device == w.device 17 | assert (cu_seqlens%16 == 0).all() 18 | assert cu_seqlens[-1].item() == T 19 | y = th.empty_like(v) 20 | sT = th.empty_like(s0) 21 | if any(i.requires_grad for i in [w,q,k,v,a,b,s0]): 22 | s = th.empty(H,T//16,C,C, dtype=th.bfloat16,device=w.device) 23 | else: 24 | s = None 25 | th.ops.wind_chunked_cuda_varlen.forward(w,q,k,v,a,b, s0,y,s,sT, cu_seqlens) 26 | ctx.save_for_backward(w,q,k,v,a,b,s,cu_seqlens) 27 | return y, sT 28 | @staticmethod 29 | def backward(ctx, dy, dsT): 30 | w,q,k,v,a,b,s,cu_seqlens = ctx.saved_tensors 31 | if not th.compiler.is_compiling(): 32 | assert all(i.dtype==th.bfloat16 for i in [dy,dsT]) 33 | assert all(i.is_contiguous() for i in [dy,dsT]) 34 | dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] 35 | th.ops.wind_chunked_cuda_varlen.backward(w,q,k,v,a,b, dy,s,dsT, dw,dq,dk,dv,da,db,ds0, cu_seqlens) 36 | return dq,dw,dk,dv,da,db,ds0,None 37 | 38 | def attn_chunked_cuda_varlen(r,w,k,v,a,b,s0,cu_seqlens): 39 | T,H,C = w.shape 40 | B = len(cu_seqlens)-1 41 | if s0 is None: s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device) 42 | return RWKV7_chunked_varlen.apply(r,w,k,v,a,b,s0,cu_seqlens) 43 | 44 | def load_chunked_cuda_varlen(head_size): 45 | if hasattr(th.ops.wind_chunked_cuda_varlen, 'forward'): return 46 | CUDA_FLAGS = ["-res-usage", f'-D_C_={head_size}', "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 47 | if head_size == 256: CUDA_FLAGS.append('-maxrregcount=128') 48 | path = os.path.dirname(__file__) 49 | load(name="wind_chunked_cuda_varlen", sources=[os.path.join(path,'chunked_cuda_varlen.cu'), os.path.join(path,'chunked_cuda_varlen.cpp')], is_python_module=False, verbose=False, extra_cuda_cflags=CUDA_FLAGS) 50 | assert hasattr(th.ops.wind_chunked_cuda_varlen, 'forward') 51 | 
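The variable-length API drops the batch dimension: r, w, k, v, a, b are packed [T,H,C] tensors with the individual sequences concatenated along T, and cu_seqlens holds the cumulative sequence boundaries (int64, on the same device, every entry a multiple of 16, last entry equal to T). The sketch below is added for illustration and is not a file from the repository; it imports directly from the module since the package __init__ re-exports are not shown in this listing.
```python
# Illustrative sketch (not part of the repo); assumes an SM80+ CUDA GPU.
import torch as th
from wind_rwkv.rwkv7.chunked_cuda_varlen.chunked_cuda_varlen import (
    load_chunked_cuda_varlen, attn_chunked_cuda_varlen)

H, C = 4, 64
cu_seqlens = th.tensor([0, 32, 96, 160], dtype=th.long, device='cuda')  # three sequences of length 32, 64, 64
T = int(cu_seqlens[-1])                                                  # total packed length, a multiple of 16
load_chunked_cuda_varlen(C)

def rand_bf16():                                                         # packed [T,H,C] bfloat16 inputs
    return th.randn(T, H, C, dtype=th.bfloat16, device='cuda', requires_grad=True)

r, w, k, v, a, b = (rand_bf16() for _ in range(6))
y, sT = attn_chunked_cuda_varlen(r, w, k, v, a, b, None, cu_seqlens)     # None -> zero initial state per sequence
print(y.shape, sT.shape)                                                 # [T,H,C] and [num_seqs,H,C,C]
```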
-------------------------------------------------------------------------------- /wind_rwkv/rwkv7/chunked_cuda_varlen/tile.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | #include <cuda_bf16.h> 4 | #include <cuda_runtime.h> 5 | 6 | //TODO: static? inline? __align__(16)? 7 | 8 | using bf = __nv_bfloat16; 9 | using bf2 = __nv_bfloat162; 10 | using uint = unsigned int; 11 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); } 12 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); } 13 | __device__ inline float2 to_float2(const bf2 & u) { return __bfloat1622float2(u); } 14 | __device__ inline float2 to_float2(const float2 & u) { return u; } 15 | __device__ inline bf2 to_bf2(const float2 & u) { return __float22bfloat162_rn(u); } 16 | __device__ inline uint& as_uint(const bf2&x) { return *((uint*)(&x)); } 17 | __device__ inline uint __smem(const void*x) { return __cvta_generic_to_shared(x); } 18 | 19 | __device__ void __commit_group() { asm volatile("cp.async.commit_group;\n" ::); } 20 | __device__ void __wait_group() { asm volatile("cp.async.wait_all;\n" ::); } 21 | template<int N> __device__ void __wait_groups() { asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); } 22 | 23 | __device__ void __copy_wait() { __commit_group(); __wait_group(); } 24 | 25 | __device__ void operator*=(float2&a, const float2&b) { a.x *= b.x; a.y *= b.y; } 26 | __device__ void operator+=(float2&a, const float2&b) { a.x += b.x; a.y += b.y; } 27 | __device__ float2 operator+(const float2&a, const float2&b) { return {a.x+b.x,a.y+b.y}; } 28 | __device__ float2 operator*(const float2&a, const float2&b) { return {a.x*b.x,a.y*b.y}; } 29 | 30 | struct STile; 31 | struct RTile; 32 | struct FTile; 33 | 34 | struct GTile { 35 | bf*ga; 36 | int stride; 37 | __device__ GTile(bf*ga_, int stride_) : ga(ga_), stride(stride_) {} 38 | __device__ GTile& operator=(const RTile&); 39 | }; 40 | struct GFTile { 41 | float*ga; 42 | int stride; 43 | __device__ GFTile(float*ga_, int stride_) : ga(ga_), stride(stride_) {} 44 | __device__ GFTile& operator=(const FTile&); 45 | }; 46 | struct STileT { STile*st; }; 47 | 48 | struct __align__(16) STile { 49 | bf data[16*16]; 50 | __device__ STile() {} 51 | __device__ STile(const RTile&o) { *this=o; } 52 | __device__ STile& operator=(const GTile&); 53 | __device__ STile& operator=(const RTile&); 54 | __device__ STileT t() { return STileT{this}; } 55 | }; 56 | struct Product { const RTile*a, *b; }; 57 | struct ProductPlus { const RTile*a, *b; const FTile* c; }; 58 | struct RTile { 59 | bf2 data[4]; 60 | __device__ RTile() {} 61 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = to_bf2({0.f,0.f}); } 62 | __device__ RTile(const STile&o) { *this=o; } 63 | __device__ RTile(const STileT&o) { *this=o; } 64 | __device__ RTile(const FTile&o) { *this=o; } 65 | __device__ RTile& operator=(const STile&); 66 | __device__ RTile& operator=(const STileT&); 67 | __device__ RTile& operator=(const FTile&fa); 68 | __device__ RTile& operator=(const GTile&); 69 | }; 70 | struct FTile { 71 | union { 72 | float2 data[4]; 73 | float fdata[8]; 74 | }; 75 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = {0.f,0.f}; } 76 | __device__ FTile() {} 77 | __device__ FTile(const FTile&o) { for (int i = 0; i < 4; i++) data[i] = o.data[i]; } 78 | __device__ FTile(const RTile&r) { *this=r; } 79 | __device__ FTile(const Product&p) { *this=p; } 80 | __device__ FTile(const ProductPlus&p) { *this=p; 
} 81 | __device__ FTile& operator=(const Product&); 82 | __device__ FTile& operator=(const RTile&); 83 | __device__ FTile& operator=(const ProductPlus&); 84 | __device__ FTile& operator+=(const Product&); 85 | __device__ FTile& operator+=(const FTile&o) { for (int i = 0; i < 4; i++) data[i] += o.data[i]; return *this; } 86 | }; 87 | 88 | __device__ void print(STile t) { 89 | if (threadIdx.x == 0) { 90 | for (int i = 0; i < 16; i++) { 91 | for (int j = 0; j < 16; j++) { 92 | printf("%f ", to_float(t.data[i*16+j])); 93 | } 94 | printf("\n"); 95 | } 96 | printf("\n"); 97 | } 98 | } 99 | 100 | template<class T> 101 | __device__ void print(T t, int warpi = 0) { 102 | int tid = threadIdx.x - warpi*32; 103 | for (int i = 0; i < 16; i++) { 104 | for (int j = 0; j < 16; j += 2) { 105 | if (tid == i%8*4+j%8/2) { 106 | float2 xy = to_float2(t.data[i/8+j/8*2]); 107 | printf("%f %f ", xy.x, xy.y); 108 | //printf("T%d:{a%d,a%d} ", threadIdx.x, (i/8+j/8*2)*2, (i/8+j/8*2)*2+1); 109 | } 110 | __syncthreads(); 111 | } 112 | if (tid == 0) printf("\n"); 113 | __syncthreads(); 114 | } 115 | if (tid == 0) printf("\n"); 116 | __syncthreads(); 117 | } 118 | 119 | template<class T> 120 | __device__ void print8(T mat) { 121 | for (int i = 0; i < 8; i++) { 122 | for (int j = 0; j < 8; j += 2) { 123 | if (threadIdx.x == i%8*4+j%8/2) { 124 | float2 xy = to_float2(mat); 125 | printf("%f %f ", xy.x, xy.y); 126 | } 127 | __syncthreads(); 128 | } 129 | if (threadIdx.x == 0) printf("\n"); 130 | __syncthreads(); 131 | } 132 | if (threadIdx.x == 0) printf("\n"); 133 | __syncthreads(); 134 | } 135 | 136 | 137 | 138 | __device__ void load(STile&sa, bf*ga, int stride) { 139 | int i = threadIdx.x%32/2, j = threadIdx.x%2; 140 | asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" :: "r"(__smem(&sa.data[i*16+j*8])), "l"(ga+stride*i+j*8), "n"(16)); 141 | } 142 | 143 | __device__ void load(RTile&ra, const STile&sa) { 144 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2; 145 | asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" 146 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3])) 147 | : "r"(__smem(&sa.data[i*16+j*8+k*8*16]))); 148 | } 149 | __device__ void loadT(RTile&ra, const STile&sa) { 150 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2; 151 | asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" 152 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3])) 153 | : "r"(__smem(&sa.data[i*16+j*8*16+k*8]))); 154 | } 155 | 156 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1, const float2 &c0, const float2 &c1) { 157 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" 158 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) 159 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)), 160 | "r"(as_uint(b0)), "r"(as_uint(b1)), 161 | "f"(c0.x), "f"(c0.y), "f"(c1.x), "f"(c1.y)); 162 | } 163 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb, const FTile&rc) { // d = a*b^T + c 164 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], rc.data[0],rc.data[1]); 165 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], 
rc.data[2],rc.data[3]); 166 | } 167 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1) { 168 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" 169 | : "+f"(d0.x), "+f"(d0.y), "+f"(d1.x), "+f"(d1.y) 170 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)), 171 | "r"(as_uint(b0)), "r"(as_uint(b1)), 172 | "f"(d0.x), "f"(d0.y), "f"(d1.x), "f"(d1.y)); 173 | } 174 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb) { // d += a*b^T 175 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2]); 176 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3]); 177 | } 178 | __device__ void mm(FTile&rd, const RTile&ra, const RTile&rb) { // d = a*b^T 179 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], {0.f,0.f}, {0.f,0.f}); 180 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], {0.f,0.f}, {0.f,0.f}); 181 | } 182 | 183 | __device__ void store(const FTile&ra, float*ga, int stride) { 184 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 185 | *((float2*)&ga[ i *stride+j ]) = ra.data[0]; 186 | *((float2*)&ga[(i+8)*stride+j ]) = ra.data[1]; 187 | *((float2*)&ga[ i *stride+j+8]) = ra.data[2]; 188 | *((float2*)&ga[(i+8)*stride+j+8]) = ra.data[3]; 189 | } 190 | 191 | __device__ void store(const RTile&ra, bf*ga, int stride) { 192 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 193 | *((bf2*)&ga[ i *stride+j ]) = ra.data[0]; 194 | *((bf2*)&ga[(i+8)*stride+j ]) = ra.data[1]; 195 | *((bf2*)&ga[ i *stride+j+8]) = ra.data[2]; 196 | *((bf2*)&ga[(i+8)*stride+j+8]) = ra.data[3]; 197 | } 198 | __device__ void load(RTile&ra, bf*ga, int stride) { 199 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 200 | ra.data[0] = *((bf2*)&ga[ i *stride+j ]); 201 | ra.data[1] = *((bf2*)&ga[(i+8)*stride+j ]); 202 | ra.data[2] = *((bf2*)&ga[ i *stride+j+8]); 203 | ra.data[3] = *((bf2*)&ga[(i+8)*stride+j+8]); 204 | } 205 | __device__ void store(const RTile&ra, STile&sa) { //TODO: reduce bank conflicts? 
206 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 207 | *((bf2*)&sa.data[ i *16+j ]) = ra.data[0]; 208 | *((bf2*)&sa.data[(i+8)*16+j ]) = ra.data[1]; 209 | *((bf2*)&sa.data[ i *16+j+8]) = ra.data[2]; 210 | *((bf2*)&sa.data[(i+8)*16+j+8]) = ra.data[3]; 211 | } 212 | 213 | __device__ void convert(RTile&ra, const FTile&fa) { 214 | ra.data[0] = to_bf2(fa.data[0]); 215 | ra.data[1] = to_bf2(fa.data[1]); 216 | ra.data[2] = to_bf2(fa.data[2]); 217 | ra.data[3] = to_bf2(fa.data[3]); 218 | } 219 | __device__ void convert(FTile&fa, const RTile&ra) { 220 | fa.data[0] = to_float2(ra.data[0]); 221 | fa.data[1] = to_float2(ra.data[1]); 222 | fa.data[2] = to_float2(ra.data[2]); 223 | fa.data[3] = to_float2(ra.data[3]); 224 | } 225 | 226 | __device__ STile& STile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; } 227 | __device__ RTile& RTile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; } 228 | __device__ RTile& RTile::operator=(const STile& sa) { load(*this, sa); return *this; } 229 | __device__ STile& STile::operator=(const RTile& ra) { store(ra, *this); return *this; } 230 | __device__ RTile& RTile::operator=(const STileT& sa) { loadT(*this, *sa.st); return *this; } 231 | __device__ Product operator%(const RTile&ra, const RTile&rb) { return Product{&ra,&rb}; } 232 | __device__ ProductPlus operator+(const Product&prod, const FTile&rc) { return ProductPlus{prod.a,prod.b,&rc}; } 233 | __device__ FTile& FTile::operator=(const Product& prod) { mm(*this, *prod.a, *prod.b); return *this; } 234 | __device__ FTile& FTile::operator=(const ProductPlus& prod) { mma(*this, *prod.a, *prod.b, *prod.c); return *this; } 235 | __device__ FTile& FTile::operator+=(const Product& prod) { mma(*this, *prod.a, *prod.b); return *this; } 236 | __device__ RTile& RTile::operator=(const FTile&fa) { convert(*this,fa); return *this; } 237 | __device__ FTile& FTile::operator=(const RTile&ra) { convert(*this,ra); return *this; } 238 | __device__ GTile& GTile::operator=(const RTile&ra) { store(ra, this->ga, this->stride); return *this; } 239 | __device__ GFTile& GFTile::operator=(const FTile&fa) { store(fa, this->ga, this->stride); return *this; } 240 | 241 | // Is this kind of cumsum better than multiplying with a triangular matrix of ones? 
242 | template<int inclusive, int rev> 243 | __device__ FTile cumsumv(FTile&w) { 244 | int tid = threadIdx.x%32, t = tid/4; 245 | 246 | FTile ret; 247 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i]; 248 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{0.f,0.f}; 249 | 250 | for (int b = 0; b < 3; b++) { 251 | for (int i = 0; i < 8; i++) { 252 | float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<<b); 253 | if ((t>>b)%2 == !rev) ret.fdata[i] += other_w; 254 | w.fdata[i] += other_w; 255 | } 256 | } 257 | for (int i : {0,1,4,5}) { 258 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)]; 259 | ret.fdata[i^(2*!rev)] += w1; 260 | w0 += w1; 261 | w1 = w0; 262 | } 263 | return ret; 264 | } 265 | 266 | template<int inclusive, int rev> 267 | __device__ FTile cumprodv(FTile&w) { 268 | int tid = threadIdx.x%32, t = tid/4; 269 | 270 | FTile ret; 271 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i]; 272 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{1.f,1.f}; 273 | 274 | for (int b = 0; b < 3; b++) { 275 | for (int i = 0; i < 8; i++) { 276 | float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<<b); 277 | if ((t>>b)%2 == !rev) ret.fdata[i] *= other_w; 278 | w.fdata[i] *= other_w; 279 | } 280 | } 281 | for (int i : {0,1,4,5}) { 282 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)]; 283 | ret.fdata[i^(2*!rev)] *= w1; 284 | w0 *= w1; 285 | w1 = w0; 286 | } 287 | return ret; 288 | } 289 | 290 | __device__ FTile operator*(const FTile&a, const FTile&b) { 291 | FTile ret; 292 | for (int i = 0; i < 8; i++) ret.fdata[i] = a.fdata[i]*b.fdata[i]; 293 | return ret; 294 | } 295 | 296 | template<int WARPS, bool triangular> // Lower triangular 297 | __device__ FTile sum_warp(float*share, const FTile&f) { // Requires share of size sizeof(float)*32*8*WARPS 298 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 299 | __syncthreads(); 300 | for (int i = 0; i < 8; i++) { 301 | if (triangular && (i == 4 || i == 5)) continue; 302 | share[32*(warpi*8+i)+tid] = f.fdata[i]; 303 | } 304 | __syncthreads(); 305 | #pragma unroll 306 | for (int k = warpi; k < 8; k += WARPS) { 307 | if (triangular && (k == 4 || k == 5)) continue; 308 | float sum = 0; 309 | for (int i = 1; i < WARPS; i++) { 310 | sum += share[32*(i*8+k)+tid]; 311 | } 312 | share[32*k+tid] += sum; 313 | } 314 | __syncthreads(); 315 | FTile ret; 316 | for (int i = 0; i < 8; i++) { 317 | if (triangular && (i == 4 || i == 5)) 318 | ret.fdata[i] = 0; 319 | else 320 | ret.fdata[i] = share[32*i+tid]; 321 | } 322 | __syncthreads(); 323 | return ret; 324 | } 325 | 326 | __device__ RTile from_warp(const RTile&ra, int src, float4*share) { 327 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 328 | if (warpi == src) share[tid] = *((float4*)ra.data); 329 | __syncthreads(); 330 | RTile ret; 331 | *((float4*)ret.data) = share[tid]; 332 | __syncthreads(); 333 | return ret; 334 | } 335 | 336 | // inv(I-f) where f is strictly lower triangular 337 | __device__ FTile tri_minv(const FTile&f, float*share) { 338 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2; 339 | float inv[16] = {}; 340 | for (int k = 0; k < 8; k++) { 341 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 342 | share[i*16+j] = f.fdata[k]; 343 | } 344 | int tid = threadIdx.x%32; 345 | inv[tid%16] = 1; 346 | for (int i = 1; i < 16; i++) { 347 | for (int j = 0; j < i; j++) { 348 | float fac = share[i*16+j]; 349 | inv[i] += fac*inv[j]; 350 | } 351 | } 352 | for (int i = 0; i < 16; i++) 353 | share[tid*16+i] = inv[i]; 354 | FTile ret; 355 | for (int k = 0; k < 8; k++) { 356 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 357 | ret.fdata[k] = share[j*16+i]; 358 | } 
359 | return ret; 360 | } 361 | 362 | template<int strict> 363 | __device__ FTile tril(const FTile&f) { 364 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2; 365 | FTile ret; 366 | for (int k = 0; k < 8; k++) { 367 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 368 | if (strict) ret.fdata[k] = (i>j ? f.fdata[k] : 0.f); 369 | else ret.fdata[k] = (i>=j ? f.fdata[k] : 0.f); 370 | } 371 | return ret; 372 | } 373 | 374 | template<class F> 375 | __device__ void apply_(FTile&tile, F f) { 376 | for (int i = 0; i < 8; i++) tile.fdata[i] = f(tile.fdata[i]); 377 | } 378 | 379 | __device__ bf2 transpose(bf2 a) { 380 | bf2 ret; 381 | asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(as_uint(ret)) : "r"(as_uint(a))); 382 | return ret; 383 | } 384 | 385 | __device__ RTile transpose(const RTile&ra) { 386 | RTile rb; 387 | rb.data[0] = transpose(ra.data[0]); 388 | rb.data[1] = transpose(ra.data[2]); 389 | rb.data[2] = transpose(ra.data[1]); 390 | rb.data[3] = transpose(ra.data[3]); 391 | return rb; 392 | } 393 | 394 | template<int strict> 395 | __device__ FTile slow_dw(const RTile&A, const RTile&q, const RTile&k, STile*share) { 396 | share[0] = A; 397 | share[1] = q; 398 | share[2] = k; 399 | __syncthreads(); 400 | if (threadIdx.x%32 == 0) { 401 | for (int k = 0; k < 16; k++) { 402 | for (int j = 0; j < 16; j++) { 403 | float sum = 0; 404 | for (int l = 0; l < k; l++) { 405 | for (int r = k+strict; r < 16; r++) { 406 | sum += to_float(share[0].data[r*16+l]) * to_float(share[1].data[r*16+j]) * to_float(share[2].data[l*16+j]); 407 | } 408 | } 409 | share[3].data[k*16+j] = to_bf(sum); 410 | } 411 | } 412 | } 413 | __syncthreads(); 414 | RTile ret = (RTile)share[3]; 415 | __syncthreads(); 416 | return ret; 417 | } 418 | 419 | 420 | __device__ static inline void __m16n8k8(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &b0) { 421 | asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" 422 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(b0)), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); 423 | } 424 | 425 | template<int strict> 426 | __device__ RTile fast_dw(const RTile&A, const RTile&q, const RTile&k) { 427 | float2 qkA8[4]; 428 | RTile kt = transpose(k), qt = transpose(q); 429 | __m16n8k8(qkA8[0],qkA8[1], qt.data[2], qt.data[3], transpose(A.data[1])); 430 | __m16n8k8(qkA8[2],qkA8[3], kt.data[0], kt.data[1], A.data[1]); 431 | for (int x : {0,1}) { 432 | qkA8[x] *= to_float2(kt.data[x]); 433 | qkA8[2+x] *= to_float2(qt.data[2+x]); 434 | } 435 | 436 | int tid = threadIdx.x%32, j = threadIdx.x%4; 437 | // Non-inclusive cumsum 438 | for (int i = 0; i < 4; i++) { 439 | float sum = qkA8[i].x+qkA8[i].y; 440 | float psum = __shfl_xor_sync(0xffffffff, sum, 1); 441 | float ppsum = __shfl_xor_sync(0xffffffff, sum+psum, 2); 442 | if (i < 2) { 443 | psum = ppsum*(j>=2)+psum*(j%2); 444 | qkA8[i].y = psum + qkA8[i].x; 445 | qkA8[i].x = psum; 446 | } else { 447 | psum = ppsum*(j<2)+psum*(j%2==0); 448 | qkA8[i].x = psum + qkA8[i].y; 449 | qkA8[i].y = psum; 450 | } 451 | } 452 | 453 | float2 qkA4[4]; 454 | { 455 | RTile k_q; 456 | for (int i = 0; i < 8; i++) ((bf*)k_q.data)[i] = (j<2?((bf*)kt.data)[i]:((bf*)qt.data)[i]); 457 | float lower_left = (tid >= 16 && j < 2); 458 | bf2 A0 = to_bf2(to_float2(A.data[0])*float2{lower_left,lower_left}); 459 | bf2 A3 = to_bf2(to_float2(A.data[3])*float2{lower_left,lower_left}); 460 | __m16n8k8(qkA4[0],qkA4[1], k_q.data[0], k_q.data[1], A0 + transpose(A0)); 461 | __m16n8k8(qkA4[2],qkA4[3], 
k_q.data[2], k_q.data[3], A3 + transpose(A3)); 462 | for (int i = 0; i < 4; i++) 463 | qkA4[i] *= to_float2(k_q.data[i]); 464 | } 465 | 466 | // Non-inclusive cumsum 467 | for (int i = 0; i < 4; i++) { 468 | float sum = qkA4[i].x+qkA4[i].y; 469 | float psum = __shfl_xor_sync(0xffffffff, sum, 1); 470 | psum *= (j%2 == j<2); 471 | qkA4[i] = {psum + qkA4[i].y*(j>=2), psum + qkA4[i].x*(j<2)}; 472 | } 473 | 474 | FTile ret; 475 | ret.data[0] = qkA8[0]+qkA4[0]; 476 | ret.data[1] = qkA8[1]+qkA4[1]; 477 | ret.data[2] = qkA8[2]+qkA4[2]; 478 | ret.data[3] = qkA8[3]+qkA4[3]; 479 | 480 | for (int ci : {0,1}) { 481 | for (int ti : {0,1}) { 482 | int Ai = ti*3, di = ti*2+ci; 483 | bf A8x = __shfl_sync(0xffffffff, A.data[Ai].x, 8+(j>=2)*18); 484 | bf A12x = __shfl_sync(0xffffffff, A.data[Ai].x, 12+(j>=2)*18); 485 | bf A12y = __shfl_sync(0xffffffff, A.data[Ai].y, 12+(j>=2)*18); 486 | bf2 nq = __shfl_xor_sync(0xffffffff, qt.data[di], 1); 487 | bf2 pk = __shfl_xor_sync(0xffffffff, kt.data[di], 1); 488 | 489 | bool even = (j%2==0); 490 | float ax = to_float(even?A8x:A12x), ay = to_float(even?A12x:A12y), c = to_float(even?kt.data[di].x:qt.data[di].y); 491 | float2 b = to_float2(j%2?pk:nq); 492 | float d = (ax*b.x+ay*b.y)*c; 493 | ret.data[di].y += even*d; 494 | ret.data[di].x +=!even*d; 495 | } 496 | } 497 | 498 | if (!strict) { 499 | // Do we really need tril<1>()? 500 | ret += (kt % tril<1>(A)) * qt; 501 | } 502 | return transpose(ret); 503 | } 504 | 505 | __device__ void debug_set(RTile&ra, int i, int j, float v) { 506 | if (threadIdx.x%32 == i%8*4+j%8/2) ((bf*)ra.data)[i/8*2+j/8*4+j%2] = to_bf(v); 507 | } 508 | 509 | template<int WARPS> 510 | __device__ float2 sumh(const FTile&f, float*share) { // Requires shared of size sizeof(float)*16*WARPS 511 | float2 warpsum = {f.fdata[0]+f.fdata[1]+f.fdata[4]+f.fdata[5], 512 | f.fdata[2]+f.fdata[3]+f.fdata[6]+f.fdata[7]}; 513 | warpsum.x += __shfl_xor_sync(0xffffffff, warpsum.x, 1); 514 | warpsum.y += __shfl_xor_sync(0xffffffff, warpsum.y, 1); 515 | warpsum.x += __shfl_xor_sync(0xffffffff, warpsum.x, 2); 516 | warpsum.y += __shfl_xor_sync(0xffffffff, warpsum.y, 2); 517 | 518 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 519 | __syncthreads(); 520 | if (tid%4 < 2) 521 | share[warpi*16+tid/4+tid%4*8] = (tid%4?warpsum.y:warpsum.x); 522 | __syncthreads(); 523 | if (warpi == 0 && tid < 16) { 524 | float sum = 0; 525 | for (int i = 1; i < WARPS; i++) { 526 | sum += share[i*16+tid]; 527 | } 528 | share[tid] += sum; 529 | } 530 | __syncthreads(); 531 | float2 ret = {share[tid/4], share[tid/4+8]}; 532 | __syncthreads(); 533 | return ret; 534 | } 535 | -------------------------------------------------------------------------------- /wind_rwkv/rwkv7/triton_bighead.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Johan Sokrates Wind 2 | 3 | import torch as th 4 | import triton 5 | import triton.language as tl 6 | 7 | @triton.jit 8 | def IND3(a,b,c,nb,nc): 9 | return (a*nb+b)*nc+c 10 | @triton.jit 11 | def IND4(a,b,c,d,nb,nc,nd): 12 | return ((a*nb+b)*nc+c)*nd+d 13 | @triton.jit 14 | def IND5(a,b,c,d,e,nb,nc,nd,ne): 15 | return (((a*nb+b)*nc+c)*nd+d)*ne+e 16 | 17 | @triton.jit 18 | def _prod(a,b): return a*b 19 | 20 | # inv(I-A) where A is a strictly lower triangular nxn matrix 21 | @triton.jit 22 | def tri_minv(A, n:tl.constexpr, prec:tl.constexpr): 23 | i = tl.arange(0,n) 24 | prod = (i[None,:]==i[:,None]).to(tl.float32) 25 | for j in range(n-1): 26 | prod += tl_dot(prec, prod, 
(A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans()) 27 | return prod.trans() 28 | 29 | @triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) 30 | @triton.jit 31 | def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, wq_,wa_,kwi_,bwi_,fw_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): 32 | tl.static_assert(C%dC == 0) 33 | bi = tl.program_id(1) 34 | hi = tl.program_id(0) 35 | 36 | for i0 in range(0,C,dC): 37 | i = i0+tl.arange(0,dC)[None,:] 38 | for j0 in range(0,C,dC): 39 | j = j0+tl.arange(0,dC)[None,:] 40 | state = tl.load(s0_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) 41 | tl.store(s_+IND5(bi,hi,0,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) 42 | 43 | for t0 in range(T//dT): 44 | dt = tl.arange(0,dT)[:,None] 45 | t = t0*dT+dt 46 | tl.debug_barrier() 47 | for j0 in range(0,C,dC): 48 | j = j0+tl.arange(0,dC)[None,:] 49 | sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 50 | sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 51 | sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 52 | sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 53 | sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 54 | 55 | w = (-sw.exp()).exp() 56 | fw = tl.reduce(w, 0, _prod, keep_dims=True) 57 | incl_pref = tl.cumprod(w,axis=0) 58 | non_incl_pref = incl_pref / w 59 | inv_incl_pref = 1 / incl_pref 60 | 61 | wq = sq * incl_pref 62 | wa = sa * non_incl_pref 63 | kwi = sk * inv_incl_pref 64 | bwi = sb * inv_incl_pref 65 | 66 | tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) 67 | tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) 68 | tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) 69 | tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) 70 | tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) 71 | tl.debug_barrier() 72 | 73 | ab = tl.zeros((dT,dT), tl.float32) 74 | ak = tl.zeros((dT,dT), tl.float32) 75 | qb = tl.zeros((dT,dT), tl.float32) 76 | qk = tl.zeros((dT,dT), tl.float32) 77 | for j0 in range(0,C,dC): 78 | j = j0+tl.arange(0,dC)[None,:] 79 | 80 | wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 81 | wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 82 | bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 83 | kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 84 | 85 | sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 86 | sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 87 | 88 | ab += tl_dot(prec, wa, bwi.trans()) 89 | ak += tl_dot(prec, wa, kwi.trans()) 90 | qb += tl_dot(prec, wq, bwi.trans()) 91 | qk += tl_dot(prec, wq, kwi.trans()) 92 | 93 | mask1 = (t > t.trans()) 94 | mask2 = (t >= t.trans()) 95 | ab *= mask1 96 | ak *= mask1 97 | qb *= mask2 98 | qk *= mask2 99 | 100 | ab_inv = tri_minv(ab, dT, prec) 101 | 102 | for i0 in range(0,C,dC): 103 | i = i0+tl.arange(0,dC)[None,:] 104 | sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) 105 | 106 | wa_state = tl.zeros((dT,dC), tl.float32) 107 | wq_state = tl.zeros((dT,dC), tl.float32) 108 | for j0 in range(0,C,dC): 109 | j = j0+tl.arange(0,dC)[None,:] 110 | state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) 111 | wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 112 | wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 113 | wa_state += tl_dot(prec, wa, state.trans()) 114 | wq_state += tl_dot(prec, wq, state.trans()) 115 | 116 | ab_u = tl_dot(prec, ak, sv) + wa_state 117 | u = 
tl_dot(prec, ab_inv, ab_u) 118 | yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + wq_state 119 | tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy) 120 | 121 | for j0 in range(0,C,dC): 122 | j = j0+tl.arange(0,dC)[None,:] 123 | state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) 124 | kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 125 | bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 126 | fw = tl.load(fw_+IND3(bi,hi,j, H,C)).to(tl.float32) 127 | 128 | state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) 129 | 130 | if t0+1 < T//dT: 131 | tl.store(s_+IND5(bi,hi,t0+1,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) 132 | else: 133 | tl.store(sT_+IND4(bi,hi,i.trans(),j, H,C,C), state) 134 | 135 | 136 | @triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) 137 | @triton.jit 138 | def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_,ds_, dw_,dq_,dk_,dv_,da_,db_,ds0_, wq_,wa_,kwi_,bwi_,fw_,u_,dab_u_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): 139 | tl.static_assert(C%dC == 0) 140 | bi = tl.program_id(1) 141 | hi = tl.program_id(0) 142 | for i0 in range(0,C,dC): 143 | i = i0+tl.arange(0,dC)[None,:] 144 | for j0 in range(0,C,dC): 145 | j = j0+tl.arange(0,dC)[None,:] 146 | dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) 147 | tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) 148 | 149 | for t0 in range(T//dT-1,-1,-1): 150 | dt = tl.arange(0,dT)[:,None] 151 | t = t0*dT+dt 152 | tl.debug_barrier() 153 | for j0 in range(0,C,dC): 154 | j = j0+tl.arange(0,dC)[None,:] 155 | sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 156 | sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 157 | sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 158 | sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 159 | sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 160 | 161 | w = (-sw.exp()).exp() 162 | fw = tl.reduce(w, 0, _prod, keep_dims=True) 163 | incl_pref = tl.cumprod(w,axis=0) 164 | non_incl_pref = incl_pref / w 165 | inv_incl_pref = 1 / incl_pref 166 | 167 | wq = sq * incl_pref 168 | wa = sa * non_incl_pref 169 | kwi = sk * inv_incl_pref 170 | bwi = sb * inv_incl_pref 171 | 172 | tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) 173 | tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) 174 | tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) 175 | tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) 176 | tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) 177 | tl.debug_barrier() 178 | 179 | ab = tl.zeros((dT,dT), tl.float32) 180 | ak = tl.zeros((dT,dT), tl.float32) 181 | qb = tl.zeros((dT,dT), tl.float32) 182 | qk = tl.zeros((dT,dT), tl.float32) 183 | for j0 in range(0,C,dC): 184 | j = j0+tl.arange(0,dC)[None,:] 185 | 186 | wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 187 | wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 188 | bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 189 | kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 190 | 191 | sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 192 | sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 193 | 194 | ab += tl_dot(prec, wa, bwi.trans()) 195 | ak += tl_dot(prec, wa, kwi.trans()) 196 | qb += tl_dot(prec, wq, bwi.trans()) 197 | qk += tl_dot(prec, wq, kwi.trans()) 198 | 199 | mask1 = (t > t.trans()) 200 | mask2 = 
(t >= t.trans()) 201 | ab *= mask1 202 | ak *= mask1 203 | qb *= mask2 204 | qk *= mask2 205 | 206 | ab_inv = tri_minv(ab, dT, prec) 207 | 208 | dab = tl.zeros((dT,dT), tl.float32) 209 | dak = tl.zeros((dT,dT), tl.float32) 210 | dqb = tl.zeros((dT,dT), tl.float32) 211 | dqk = tl.zeros((dT,dT), tl.float32) 212 | 213 | tl.debug_barrier() 214 | for i0 in range(0,C,dC): 215 | i = i0+tl.arange(0,dC)[None,:] 216 | wa_state = tl.zeros((dT,dC), tl.float32) 217 | bwi_dw_dstate = tl.zeros((dT,dC), tl.float32) 218 | kwi_dw_dstate = tl.zeros((dT,dC), tl.float32) 219 | for j0 in range(0,C,dC): 220 | j = j0+tl.arange(0,dC)[None,:] 221 | state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) 222 | dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) 223 | wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 224 | bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 225 | kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 226 | fw = tl.load(fw_+IND3(bi,hi,j, H,C)).to(tl.float32) 227 | 228 | wa_state += tl_dot(prec, wa, state.trans()) 229 | bwi_dw_dstate += tl_dot(prec, bwi*fw, dstate.trans()) 230 | kwi_dw_dstate += tl_dot(prec, kwi*fw, dstate.trans()) 231 | 232 | sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) 233 | sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) 234 | 235 | ab_u = tl_dot(prec, ak, sv) + wa_state 236 | u = tl_dot(prec, ab_inv, ab_u) 237 | du = tl_dot(prec, qb.trans(), sdy) + bwi_dw_dstate 238 | dab_u = tl_dot(prec, ab_inv.trans(), du) 239 | 240 | tl.store(u_+IND4(bi,hi,dt,i, H,dT,C), u.to(tl.float32)) 241 | tl.store(dab_u_+IND4(bi,hi,dt,i, H,dT,C), dab_u.to(tl.float32)) 242 | 243 | dv = tl_dot(prec, qk.trans(), sdy) + kwi_dw_dstate + tl_dot(prec, ak.trans(), dab_u) 244 | tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv) 245 | 246 | dab += tl_dot(prec, dab_u, u.trans()) * mask1 247 | dak += tl_dot(prec, dab_u, sv.trans()) * mask1 248 | dqb += tl_dot(prec, sdy, u.trans()) * mask2 249 | dqk += tl_dot(prec, sdy, sv.trans()) * mask2 250 | tl.debug_barrier() 251 | 252 | for j0 in range(0,C,dC): 253 | j = j0+tl.arange(0,dC)[None,:] 254 | 255 | dy_state = tl.zeros((dT,dC), tl.float32) 256 | dab_u_state = tl.zeros((dT,dC), tl.float32) 257 | fw_u_dstate = tl.zeros((dT,dC), tl.float32) 258 | fw_v_dstate = tl.zeros((dT,dC), tl.float32) 259 | state_dstate = tl.zeros((1,dC), tl.float32) 260 | 261 | fw = tl.load(fw_+IND3(bi,hi,j, H,C)).to(tl.float32) 262 | wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 263 | wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 264 | for i0 in range(0,C,dC): 265 | i = i0+tl.arange(0,dC)[None,:] 266 | 267 | u = tl.load(u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) 268 | dab_u = tl.load(dab_u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) 269 | sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) 270 | sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) 271 | 272 | state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) 273 | tl.debug_barrier() 274 | dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) 275 | tl.debug_barrier() 276 | 277 | dab_u_state += tl_dot(prec, dab_u, state) 278 | fw_u_dstate += fw * tl_dot(prec, u, dstate) 279 | fw_v_dstate += fw * tl_dot(prec, sv, dstate) 280 | dy_state += tl_dot(prec, sdy, state) 281 | 282 | state_dstate += tl.sum(state*dstate, axis=0,keep_dims=True) 283 | 284 | dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) 285 | if t0 > 0: 286 | tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), 
dstate.to(tl.float32)) 287 | else: 288 | tl.store(ds0_+IND4(bi,hi,i.trans(),j, H,C,C), dstate) 289 | 290 | sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) 291 | w = (-sw.exp()).exp() 292 | incl_pref = tl.cumprod(w,axis=0) 293 | non_incl_pref = incl_pref / w 294 | inv_incl_pref = 1 / incl_pref 295 | 296 | bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 297 | kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) 298 | 299 | da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) 300 | tl.store(da_+IND4(bi,t,hi,j, T,H,C), da) 301 | 302 | dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) 303 | tl.store(dq_+IND4(bi,t,hi,j, T,H,C), dq) 304 | 305 | db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) 306 | tl.store(db_+IND4(bi,t,hi,j, T,H,C), db) 307 | 308 | dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) 309 | tl.store(dk_+IND4(bi,t,hi,j, T,H,C), dk) 310 | 311 | dw0 = fw * state_dstate 312 | for k in range(t0*dT,t0*dT+dT): 313 | lmask = (t<k).trans() 314 | A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k) 315 | A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) 316 | A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k) 317 | A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k) 318 | dw = tl.sum(A, axis=0,keep_dims=True) + dw0 319 | 320 | wk = tl.load(w_+IND4(bi,k,hi,j, T,H,C)).to(tl.float32) 321 | dw *= -wk.exp() 322 | tl.store(dw_+IND4(bi,k,hi,j, T,H,C), dw) 323 | 324 | 325 | @triton.jit 326 | def tl_dot(prec:tl.constexpr, a, b): 327 | if prec == 'fp32': 328 | return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False) 329 | #elif prec == 'tf32': # This sometimes runs into a bug in the triton language 330 | #return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True) 331 | elif prec == 'bf16': 332 | return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True) 333 | elif prec == 'fp16': 334 | return tl.dot(a.to(tl.float16),b.trans().to(tl.float16).trans(), allow_tf32=True) 335 | else: 336 | tl.static_assert(False) 337 | 338 | 339 | class RWKV7_bighead(th.autograd.Function): 340 | @staticmethod 341 | def forward(ctx, q,w,k,v,a,b,s0, dot_prec): 342 | K = 16 343 | B,T,H,C = w.shape 344 | assert T%K == 0 345 | assert C%16 == 0 346 | 347 | assert all(i.dtype==th.bfloat16 or (dot_prec == 'fp16' and i.dtype == th.float16) for i in [w,q,k,v,a,b,s0]) 348 | assert all(i.is_contiguous() for i in [w,q,k,v,a,b,s0]) 349 | assert all(i.shape == w.shape for i in [w,q,k,v,a,b]) 350 | assert list(s0.shape) == [B,H,C,C] 351 | 352 | y = th.empty_like(v) 353 | sT = th.empty_like(s0) 354 | s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device) 355 | wq,wa,kwi,bwi = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(4)] 356 | fw = th.empty(B,H,C, dtype=th.float32,device=w.device) 357 | fw_attn_triton[(H,B)](w,q,k,v,a,b, s0,y,s,sT, wq,wa,kwi,bwi,fw, B,T,H,C,K, dot_prec) 358 | ctx.dot_prec = dot_prec 359 | ctx.save_for_backward(w,q,k,v,a,b,s) 360 | return y, sT 361 | @staticmethod 362 | def backward(ctx, dy, dsT): 363 | K = 16 364 | w,q,k,v,a,b,s = ctx.saved_tensors 365 | B,T,H,C = w.shape 366 | dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] 367 | fw = th.empty(B,H,C, dtype=th.float32,device=w.device) 368 | ds = th.empty(B,H,C,C, dtype=th.float32,device=w.device) 369 | wq,wa,kwi,bwi,u,dab_u = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(6)] 370 | bw_attn_triton[(H,B)](w,q,k,v,a,b, 
dy,s,dsT,ds, dw,dq,dk,dv,da,db,ds0, wq,wa,kwi,bwi,fw,u,dab_u, B,T,H,C,K, ctx.dot_prec) 371 | return dq,dw,dk,dv,da,db,ds0,None 372 | 373 | def attn_triton_bighead(r,w,k,v,a,b, s0 = None, dot_prec='fp32'): 374 | B,T,H,C = w.shape 375 | if s0 is None: s0 = th.zeros(B,H,C,C, dtype=r.dtype,device=w.device) 376 | return RWKV7_bighead.apply(r,w,k,v,a,b,s0,dot_prec) 377 | 378 | def attn_triton_bighead_bf16(*args): return attn_triton_bighead(*args,dot_prec='bf16') 379 | def attn_triton_bighead_fp16(*args): return attn_triton_bighead(*args,dot_prec='fp16') 380 | #def attn_triton_bighead_tf32(*args): return attn_triton_bighead(*args,dot_prec='tf32') 381 | def attn_triton_bighead_fp32(*args): return attn_triton_bighead(*args,dot_prec='fp32') 382 | 383 | def attn_triton_bighead_wrap(r,w,k,v,a,b, s0 = None, return_state = False, head_size = 64, dot_prec = 'fp32'): 384 | B,T,HC = w.shape 385 | C = head_size 386 | H = HC//C 387 | r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] 388 | s0 = th.zeros(B,H,C,C, dtype=r.dtype,device=w.device) 389 | return RWKV7_bighead.apply(r,w,k,v,a,b,s0,dot_prec)[0].view(B,T,HC) 390 | --------------------------------------------------------------------------------
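A minimal usage sketch of the Triton bighead entry point defined in `triton_bighead.py`, assuming a direct module import and dummy bfloat16 inputs (not part of the repository):

```python
# Hypothetical call into attn_triton_bighead with batched [B,T,H,C] inputs.
import torch as th
from wind_rwkv.rwkv7.triton_bighead import attn_triton_bighead

B, T, H, C = 2, 64, 4, 256          # T and C must be multiples of 16
r, w, k, v, a, b = [th.randn(B, T, H, C, dtype=th.bfloat16, device='cuda') for _ in range(6)]

y, sT = attn_triton_bighead(r, w, k, v, a, b)                             # fp32 matmuls inside the kernel
y16, sT16 = attn_triton_bighead(r, w, k, v, a, b, None, dot_prec='bf16')  # bf16 matmuls, faster but less exact
print(y.shape, sT.shape)            # [B, T, H, C] and [B, H, C, C]
```

Per the asserts in `RWKV7_bighead.forward`, the inputs must be contiguous bfloat16 tensors unless `dot_prec='fp16'`, in which case float16 inputs are also accepted; `dot_prec` only selects the precision used by `tl_dot` inside the kernels.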