├── README.md ├── depthwise_conv1d ├── cuda │ ├── timex_op.cpp │ ├── timex_cuda_v0.cu │ ├── timex_cuda_v1.cu │ ├── timex_cuda_v2.cu │ └── timex_cuda_v3.cu └── run.py ├── wkv └── cuda │ ├── wkv_op.cpp │ ├── wkv_cuda_v1.cu │ ├── wkv_cuda_v2.cu │ └── wkv_cuda_v0.cu ├── wkv6 └── cuda │ ├── wkv6_op.cpp │ ├── wkv5_op.cpp │ ├── wkv5_cuda_v1b2.cu │ ├── wkv6_cuda_v1.cu │ └── wkv6_cuda_v1a.cu ├── wkv5 └── cuda │ ├── wkv5_op.cpp │ ├── wkv5_ref.cpp │ ├── wkv5_cuda_v1c.cu │ ├── wkv5_cuda_v1a.cu │ ├── wkv5_cuda_v1d.cu │ ├── wkv5_cuda_v3.cu │ ├── wkv5_cuda_v1b.cu │ ├── wkv5_cuda_ref.cu │ ├── wkv5_cuda_v1.cu │ ├── wkv5_cuda_v2.cu │ └── wkv5_cuda_v1e.cu ├── wkv5_bf16 └── cuda │ ├── wkv5_op.cpp │ ├── wkv5_cuda_v1a.cu │ ├── wkv5_cuda_v1.cu │ ├── wkv5_cuda_v1b.cu │ ├── wkv5_cuda_v1b2.cu │ ├── wkv5_cuda_v2.cu │ └── wkv5_cuda_v3.cu ├── wkv6_state └── cuda │ └── wkv6state_op.cpp ├── rwkv7_fast_fused ├── cuda │ ├── rwkv7_clampw.cpp │ ├── rwkv7_state_clampw.cpp │ ├── rwkv7_statepassing_clampw.cpp │ ├── rwkv7_clampw.cu │ ├── rwkv7_state_clampw.cu │ └── rwkv7_statepassing_clampw.cu ├── rwkv7_cuda_benchmark.py └── rwkv7_cuda_benchmark_state.py ├── wkv5a └── cuda │ ├── wkv5a_op.cpp │ ├── wkv5a_cuda_v1.cu │ ├── wkv5a_cuda_v1a.cu │ └── wkv5a_cuda_v1a2.cu └── .gitignore /README.md: -------------------------------------------------------------------------------- 1 | # RWKV-CUDA 2 | The CUDA version of the RWKV language model ( https://github.com/BlinkDL/RWKV-LM ) 3 | 4 | ## Towards RWKV-4 (see the wkv folder) 5 | 6 | I have a basic RWKV-4 kernel in the wkv folder. Let's optimize it. 
7 | 8 | 9 | 10 | ## Experiment 1 - depthwise_conv1d - 20x faster than pytorch 11 | 12 | The formula: 13 | ``` 14 | w.shape = (C, T) 15 | k.shape = (B, C, T) 16 | out.shape = (B, C, T) 17 | out[b][c][t] = sum_u{ w[c][(T-1)-(t-u)] * k[b][c][u] } 18 | ``` 19 | 20 | pytorch = fwd 94ms bwd 529ms 21 | 22 | CUDA kernel v0 = fwd 45ms bwd 84ms (simple) 23 | 24 | CUDA kernel v1 = fwd 17ms bwd 43ms (shared memory) 25 | 26 | CUDA kernel v2 = fwd 13ms bwd 31ms (float4) 27 | 28 | CUDA kernel v3 = fwd 3.4ms bwd 23ms (B-group) 29 | 30 | More test on RTX3090: 31 | 32 | pytorch = fwd 14ms bwd 65ms 33 | 34 | CUDA kernel v3 = fwd 0.8ms bwd 5.5ms 35 | 36 | How to use: ```python run.py``` and it will compile everything for you (```pip install Ninja``` if you don't have it). 37 | -------------------------------------------------------------------------------- /depthwise_conv1d/cuda/timex_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | /* NOTE(review): the header name after #include, and template arguments such as the element type of data_ptr(), appear to have been stripped when this listing was generated; restore them from the repository before building. */ 3 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T); 4 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T); 5 | /* timex bindings: forward/backward cast each tensor's data_ptr() to a float pointer and call cuda_forward/cuda_backward, narrowing eps to float and B, C, T to int. No dtype, device or contiguity checks are done, so the tensors are assumed to be contiguous float32 CUDA tensors (TODO confirm at the Python call site). The same two functions are exposed through both PYBIND11_MODULE and TORCH_LIBRARY(timex). */ 6 | void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) { 7 | cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T); 8 | } 9 | void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) { 10 | cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "timex forward"); 15 | m.def("backward", &backward, "timex backward"); 16 | } 17 | 18 | TORCH_LIBRARY(timex, m) { 19 | m.def("forward",
forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- /wkv/cuda/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | /* RWKV-4 wkv bindings: forward/backward pass B, T, C (narrowed from int64_t to int) and the data pointers of w, u, k, v and the output/gradient tensors straight to cuda_forward/cuda_backward, which take float buffers. No shape, dtype or contiguity validation is done here. */ 3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 7 | cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 10 | cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv forward"); 15 | m.def("backward", &backward, "wkv backward"); 16 | } 17 | 18 | TORCH_LIBRARY(wkv, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- /wkv6/cuda/wkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | /* wkv6 bindings: r, k, v, u, y and all gradients (including gw) are bf16, while w is passed as a float buffer. B, T, C, H are narrowed from int64_t to int. */ 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H,
torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6 forward"); 16 | m.def("backward", &backward, "wkv6 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | /* wkv5 (float32) bindings. cuda_backward takes an extra float buffer ww between w and u, so every .cu file linked against this binding must define the 16-argument cuda_backward declared below. B, T, C, H are narrowed from int64_t to int. */ 3 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y); 4 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 7 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor
&gv, torch::Tensor &gw, torch::Tensor &gu) { 10 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 11 | } 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("forward", &forward, "wkv5 forward"); 14 | m.def("backward", &backward, "wkv5 backward"); 15 | } 16 | 17 | TORCH_LIBRARY(wkv5, m) { 18 | m.def("forward", forward); 19 | m.def("backward", backward); 20 | } 21 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_ref.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | /* Same declarations and wrappers as wkv5/cuda/wkv5_op.cpp (including the extra ww argument of cuda_backward), registered under the name wkv5_ref instead of wkv5. */ 3 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y); 4 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 7 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 10 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv5_ref forward"); 15 | m.def("backward", &backward, "wkv5_ref backward"); 16 | } 17 | 18 |
TORCH_LIBRARY(wkv5_ref, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- /wkv6/cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | /* bf16 wkv5 bindings: r, k, v, u, y and all gradients are bf16, while w and ww are float buffers. B, T, C, H are narrowed from int64_t to int. */ 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv5 forward"); 16 | m.def("backward", &backward, "wkv5 backward"); 17 | } 18 | /* NOTE(review): wkv5/cuda/wkv5_op.cpp and wkv5_bf16/cuda/wkv5_op.cpp also register TORCH_LIBRARY(wkv5, ...). PyTorch allows only one TORCH_LIBRARY block per namespace in a process (TORCH_LIBRARY_FRAGMENT exists for extending one), so presumably only one of these extensions is loaded at a time — confirm. */ 19 | TORCH_LIBRARY(wkv5, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | /* Same content as wkv6/cuda/wkv5_op.cpp: bf16 wkv5 bindings with float w/ww, registered as TORCH_LIBRARY(wkv5). */ 5 | void
cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv5 forward"); 16 | m.def("backward", &backward, "wkv5 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv5, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /wkv6_state/cuda/wkv6state_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | /* wkv6 bindings with an explicit state tensor s (gradient gs). Unlike wkv6/cuda/wkv6_op.cpp, every buffer here, including w, is bf16. Registered as wkv6b. */ 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k,
torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6b forward"); 16 | m.def("backward", &backward, "wkv6b backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6b, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/cuda/rwkv7_clampw.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef _FP32_ 4 | using bf = float; 5 | #else 6 | #include 7 | using bf = __nv_bfloat16; 8 | #endif 9 | /* bf is float when _FP32_ is defined, otherwise __nv_bfloat16; s and sa are always float. B, T, H are read from r.sizes()[0..2] (narrowed to int), so r is assumed to be at least 3-D, presumably (B, T, H, head_size) — TODO confirm. Only a TORCH_LIBRARY registration is provided (no PYBIND11_MODULE), and no dtype or contiguity checks are done. */ 10 | void cuda_forward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*y, float*s, float*sa); 11 | 12 | void forward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) { 13 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2]; 14 | cuda_forward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr()); 15 | } 16 | 17 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k,
bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db); 18 | 19 | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy, 20 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dr, torch::Tensor &dw, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db) { 21 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2]; 22 | cuda_backward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), 23 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dr.data_ptr(), (bf*)dw.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr()); 24 | } 25 | 26 | TORCH_LIBRARY(rwkv7_clampw, m) { 27 | m.def("forward", forward); 28 | m.def("backward", backward); 29 | } 30 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/cuda/rwkv7_state_clampw.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef _FP32_ 4 | using bf = float; 5 | #else 6 | #include 7 | using bf = __nv_bfloat16; 8 | #endif 9 | /* Same as rwkv7_clampw.cpp (B, T, H taken from r.sizes()), plus a float initial state s0 for forward and its float gradient ds0 for backward. */ 10 | void cuda_forward(int B, int T, int H, float*s0, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*y, float*s, float*sa); 11 | 12 | void forward(torch::Tensor &s0, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) { 13 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2]; 14 | cuda_forward(B, T, H, (float*)s0.data_ptr(), (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr()); 15 | } 16 | 17 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, float*ds0,
bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db); 18 | 19 | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy, 20 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &ds0, torch::Tensor &dr, torch::Tensor &dw, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db) { 21 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2]; 22 | cuda_backward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), 23 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)ds0.data_ptr(), (bf*)dr.data_ptr(), (bf*)dw.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr()); 24 | } 25 | 26 | TORCH_LIBRARY(rwkv7_state_clampw, m) { 27 | m.def("forward", forward); 28 | m.def("backward", backward); 29 | } 30 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/cuda/rwkv7_statepassing_clampw.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef _FP32_ 4 | using bf = float; 5 | #else 6 | #include 7 | using bf = __nv_bfloat16; 8 | #endif 9 | /* Same as rwkv7_state_clampw.cpp, plus a float buffer sT in forward (presumably the final state — confirm) and its incoming float gradient dsT in backward. */ 10 | void cuda_forward(int B, int T, int H, float*s0, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*y, float*sT, float*s, float*sa); 11 | 12 | void forward(torch::Tensor &s0, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &sT, torch::Tensor &s, torch::Tensor &sa) { 13 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2]; 14 | cuda_forward(B, T, H, (float*)s0.data_ptr(), (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)y.data_ptr(), (float*)sT.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr()); 15 | } 16 | 17 | void cuda_backward(int B, int T, int H, bf*r,
bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*dsT, float*s, float*sa, float*ds0, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db); 18 | 19 | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy, torch::Tensor &dsT, 20 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &ds0, torch::Tensor &dr, torch::Tensor &dw, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db) { 21 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2]; 22 | cuda_backward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), (float*)dsT.data_ptr(), 23 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)ds0.data_ptr(), (bf*)dr.data_ptr(), (bf*)dw.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr()); 24 | } 25 | 26 | TORCH_LIBRARY(rwkv7_statepassing_clampw, m) { 27 | m.def("forward", forward); 28 | m.def("backward", backward); 29 | } 30 | -------------------------------------------------------------------------------- /wkv5a/cuda/wkv5a_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | typedef float DTYPE; 5 | /* wkv5a bindings with two (w, u) parameter sets. DTYPE (float here) is the element type of r, k, v, u1, u2, y and all gradients; w1, ww1, w2, ww2 are always float. The bf16 typedef is not used in this file. */ 6 | void cuda_forward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, DTYPE *u1, float *w2, DTYPE *u2, DTYPE *y); 7 | void cuda_backward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, float *ww1, DTYPE *u1, float *w2, float *ww2, DTYPE *u2, DTYPE *gy, DTYPE *gr, DTYPE *gk, DTYPE *gv, DTYPE *gw1, DTYPE *gu1, DTYPE *gw2, DTYPE *gu2); 8 | 9 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w1, torch::Tensor &u1, torch::Tensor &w2, torch::Tensor &u2, torch::Tensor &y) { 10 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(),
v.data_ptr(), w1.data_ptr(), u1.data_ptr(), w2.data_ptr(), u2.data_ptr(), y.data_ptr()); 11 | } 12 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w1, torch::Tensor &ww1, torch::Tensor &u1, torch::Tensor &w2, torch::Tensor &ww2, torch::Tensor &u2, 13 | torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw1, torch::Tensor &gu1, torch::Tensor &gw2, torch::Tensor &gu2) { 14 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w1.data_ptr(), ww1.data_ptr(), u1.data_ptr(), w2.data_ptr(), ww2.data_ptr(), u2.data_ptr(), 15 | gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw1.data_ptr(), gu1.data_ptr(), gw2.data_ptr(), gu2.data_ptr()); 16 | } 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("forward", &forward, "wkv5a forward"); 19 | m.def("backward", &backward, "wkv5a backward"); 20 | } 21 | 22 | TORCH_LIBRARY(wkv5a, m) { 23 | m.def("forward", forward); 24 | m.def("backward", backward); 25 | } 26 | -------------------------------------------------------------------------------- /depthwise_conv1d/cuda/timex_cuda_v0.cu: -------------------------------------------------------------------------------- 1 | #include 2 | /* Naive depthwise causal conv1d. Expected launch: grid (1, B*C), block (T). Block row i = blockIdx.y covers one (b, c) pair with c = i % C; thread t = threadIdx.x covers one time step. kernel_forward computes x[i*T + t] = eps + sum_{u=0..t} w[c*T + (T-1) - t + u] * k[i*T + u], i.e. the README formula plus eps. There are no bounds checks: blockDim.x must equal T. */ 3 | template 4 | __global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x, 5 | const F eps, const int B, const int C, const int T) 6 | { 7 | const int i = blockIdx.y; 8 | const int t = threadIdx.x; 9 | 10 | F s = eps; 11 | const F *__restrict__ const www = w + (i % C) * T + (T - 1) - t; 12 | const F *__restrict__ const kk = k + i * T; 13 | for (int u = 0; u <= t; u++) 14 | { 15 | s += www[u] * kk[u]; 16 | } 17 | x[i * T + t] = s; 18 | } 19 | 20 | template 21 | __global__ void kernel_backward(const F *__restrict__ const w, const F *__restrict__ const k, const F *__restrict__ const gwk, 22 | F *__restrict__ const gw, F *__restrict__ const gk, 23 |
const int B, const int C, const int T) 24 | { 25 | const int i = blockIdx.y; 26 | const int t = threadIdx.x; 27 | /* gw[i*T + t] = sum_{u=0..t} gwk[i*T + (T-1) - t + u] * k[i*T + u]; gk[i*T + t] = sum_{u=t..T-1} gwk[i*T + (T-1) + t - u] * w[c*T + u]. gw gets one row per (b, c), so presumably the caller sums it over b to obtain the (C, T) weight gradient — confirm. */ 28 | F s = 0; 29 | const F *__restrict__ const ggk = gwk + i * T + (T - 1) - t; 30 | const F *__restrict__ const kk = k + i * T; 31 | for (int u = 0; u <= t; u++) 32 | { 33 | s += ggk[u] * kk[u]; 34 | } 35 | gw[i * T + t] = s; 36 | 37 | s = 0; 38 | const F *__restrict__ const ggw = gwk + i * T + (T - 1) + t; 39 | const F *__restrict__ const ww = w + (i % C) * T; 40 | for (int u = t; u < T; u++) 41 | { 42 | s += ggw[-u] * ww[u]; 43 | } 44 | gk[i * T + t] = s; 45 | } 46 | /* Host launchers: one block of T threads per (b, c) row, so T must not exceed the device's max threads per block. The dim3 locals named gridDim/blockDim are ordinary host variables here. Launch errors are not checked. */ 47 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) 48 | { 49 | dim3 gridDim(1, B * C); 50 | dim3 blockDim(T); 51 | kernel_forward<<>>(w, k, x, eps, B, C, T); 52 | } 53 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) 54 | { 55 | dim3 gridDim(1, B * C); 56 | dim3 blockDim(T); 57 | kernel_backward<<>>(w, k, gwk, gw, gk, B, C, T); 58 | } 59 | -------------------------------------------------------------------------------- /depthwise_conv1d/cuda/timex_cuda_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // require T <= Tmax 4 | /* Same math and launch layout as timex_cuda_v0.cu, but each block first copies its row of w and k (plus gwk in backward) into __shared__ arrays of Tmax elements, then synchronizes. Tmax is not defined in this file — presumably supplied as a compile-time define; T <= Tmax is not checked at runtime. */ 5 | template 6 | __global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x, 7 | const F eps, const int B, const int C, const int T) 8 | { 9 | const int i = blockIdx.y; 10 | const int t = threadIdx.x; 11 | 12 | __shared__ F ww[Tmax]; 13 | __shared__ F kk[Tmax]; 14 | ww[t] = w[(i % C) * T + t]; 15 | kk[t] = k[i * T + t]; 16 | 17 | __syncthreads(); 18 | 19 | F s = eps; 20 | const F *__restrict__ const www = ww + (T - 1) - t; 21 | for (int u = 0; u <= t; u++) 22 | { 23 | s += www[u] * kk[u]; 24 | } 25 | x[i * T + t] = s; 26 | } 27 | 28 | template 29 | __global__ void kernel_backward(const F *__restrict__ const w, const F *__restrict__ const k,
const F *__restrict__ const gwk, 30 | F *__restrict__ const gw, F *__restrict__ const gk, 31 | const int B, const int C, const int T) 32 | { 33 | const int i = blockIdx.y; 34 | const int t = threadIdx.x; 35 | 36 | __shared__ F gg[Tmax]; 37 | __shared__ F kk[Tmax]; 38 | __shared__ F ww[Tmax]; 39 | gg[t] = gwk[i * T + t]; 40 | kk[t] = k[i * T + t]; 41 | ww[t] = w[(i % C) * T + t]; 42 | 43 | __syncthreads(); 44 | 45 | F s = 0; 46 | const F *__restrict__ const ggk = gg + (T - 1) - t; 47 | for (int u = 0; u <= t; u++) 48 | { 49 | s += ggk[u] * kk[u]; 50 | } 51 | gw[i * T + t] = s; 52 | 53 | s = 0; 54 | const F *__restrict__ const ggw = gg + (T - 1) + t; 55 | for (int u = t; u < T; u++) 56 | { 57 | s += ggw[-u] * ww[u]; 58 | } 59 | gk[i * T + t] = s; 60 | } 61 | /* Host launchers: grid (1, B*C), block (T). T must be <= Tmax for the shared arrays and within the per-block thread limit. Launch errors are not checked. */ 62 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) 63 | { 64 | dim3 gridDim(1, B * C); 65 | dim3 blockDim(T); 66 | kernel_forward<<>>(w, k, x, eps, B, C, T); 67 | } 68 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) 69 | { 70 | dim3 gridDim(1, B * C); 71 | dim3 blockDim(T); 72 | kernel_backward<<>>(w, k, gwk, gw, gk, B, C, T); 73 | } 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other
infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 128 | .pyre/ 130 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v1c.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | /* wkv5 forward, shared-memory version. N (the head size) is not defined in this file — presumably a compile-time define. kernel_forward expects one block per (b, h) = (blockIdx.x / H, blockIdx.x % H) with N threads; thread i owns column i of an N x N __shared__ state and walks t = 0..T-1 over elements [b*T*C + t*C + h*N + i]. For each t: y[t,i] = sum_j r[t,j] * (u[j] * k[t,j] * v[t,i] + state[j,i]), then state[j,i] = state[j,i] * w[j] + k[t,j] * v[t,i], with u and w offset to head h. The launch configuration itself is not visible in this listing (its triple-angle-bracket arguments appear stripped). */ 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int b = blockIdx.x / H; 10 | const int h = blockIdx.x % H; 11 | const int i = threadIdx.x; 12 | _w += h*N; 13 | _u += h*N; 14 | 15 | __shared__ float state[N * N], rr[N], kk[N]; 16 | 17 | for (int j = 0; j < N; ++j) 18 | state[j * N + i] = 0; 19 | 20 | for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C) 21 | { 22 | const F vv = _v[_t]; 23 | F yy = 0; 24 | 25 | rr[i] = _r[_t]; 26 | kk[i] = _k[_t]; 27 | 28 | __syncthreads(); 29 | 30 | for (int j = 0; j < N; j++) 31 | { 32 | const float ww = _w[j]; 33 | const float uu = _u[j]; 34 | 35 | float x = kk[j] * vv; 36 | 37 | float s = state[j * N + i]; 38 | yy += rr[j] * (uu * x + s); 39 | state[j * N + i] = s * ww + x; 40 | } 41 | _y[_t] = yy; 42 | 43 | __syncthreads(); 44 | } 45 | } 46 | /* Backward is not implemented in this version: kernel_backward is an empty stub. */ 47 | template 48 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 49 | const F
*__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy, 50 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 51 | { 52 | 53 | } 54 | 55 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 56 | { 57 | assert(H*N == C); 58 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 59 | } 60 | 61 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 62 | { 63 | assert(H*N == C); 64 | const int SIZE = B*C; 65 | dim3 threadsPerBlock(min(SIZE, 32)); 66 | assert(SIZE % threadsPerBlock.x == 0); 67 | dim3 numBlocks(SIZE / threadsPerBlock.x); 68 | kernel_backward<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); 69 | } /* Overload with the 16-argument signature declared and called by wkv5/cuda/wkv5_op.cpp and wkv5_ref.cpp (extra float buffer ww between w and u). Without it, linking this file against either binding leaves that cuda_backward symbol undefined. ww is unused; since kernel_backward is still an empty stub, the gradient buffers are not written. */ void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) { (void)ww; cuda_backward(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); } 70 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v1a.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | /* wkv5 forward, per-thread local-state version: one thread per (b, h, i) with idx = blockIdx.x * blockDim.x + threadIdx.x, b = idx / C, h = (idx / N) % H, i = idx % N. Each thread keeps its N-element state column in a local array and reads r, k, u, w directly from global memory; the recurrence is the same as in wkv5_cuda_v1c.cu. There is no idx < B*C guard — the launcher asserts that B*C is a multiple of the block size instead. */ 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 10 | const int _b = idx / C; 11 | const int _h = (idx / N) % H; 12 | const int _i = idx % N; 13 | 14 | const int _o0 = _b*T*C + _h*N; 15 | const int _o1 = _h*N; 16 | const F *__restrict__ const k = _k + _o0; 17 | const F *__restrict__ const v = _v + _o0 + _i; 18 | const F *__restrict__ const r = _r + _o0; 19 | F *__restrict__ const y = _y + _o0 + _i; 20 | 21 | float state[N] = {0}; 22 | 23 | for (int _t = 0; _t
< T; _t++) 24 | { 25 | const int tt = _t*C; 26 | const F vv = v[tt]; 27 | F yy = 0; 28 | 29 | #pragma unroll 30 | for (int _j = 0; _j < N; _j++) 31 | { 32 | const int j = tt + _j; 33 | const int m = _o1 + _j; 34 | 35 | const float x = k[j] * vv; 36 | const float s = state[_j]; 37 | 38 | yy += r[j] * (_u[m] * x + s); 39 | state[_j] = s * _w[m] + x; 40 | } 41 | y[tt] = yy; 42 | } 43 | } 44 | /* Backward is not implemented in this version: kernel_backward is an empty stub. */ 45 | template 46 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 47 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy, 48 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 49 | { 50 | 51 | } 52 | 53 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 54 | { 55 | assert(H*N == C); 56 | const int SIZE = B*C; 57 | dim3 threadsPerBlock(min(SIZE, 32)); 58 | assert(SIZE % threadsPerBlock.x == 0); 59 | dim3 numBlocks(SIZE / threadsPerBlock.x); 60 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 61 | } 62 | 63 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 64 | { 65 | assert(H*N == C); 66 | const int SIZE = B*C; 67 | dim3 threadsPerBlock(min(SIZE, 32)); 68 | assert(SIZE % threadsPerBlock.x == 0); 69 | dim3 numBlocks(SIZE / threadsPerBlock.x); 70 | kernel_backward<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); 71 | } /* Overload with the 16-argument signature declared and called by wkv5/cuda/wkv5_op.cpp and wkv5_ref.cpp (extra float buffer ww between w and u). Without it, linking this file against either binding leaves that cuda_backward symbol undefined. ww is unused; since kernel_backward is still an empty stub, the gradient buffers are not written. */ void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) { (void)ww; cuda_backward(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); } 72 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v1d.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | /* Same block layout and recurrence as wkv5_cuda_v1c.cu, but w, u and the staged r/k rows are read as float4. N must therefore be a multiple of 4 (asserted in cuda_forward), and _w + h*N and _u + h*N must be 16-byte aligned. This assumes the w/u base pointers are 16-byte aligned; confirm for tensor views with a storage offset. */ 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 |
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int b = blockIdx.x / H; 10 | const int h = blockIdx.x % H; 11 | const int i = threadIdx.x; 12 | const float4 *__restrict__ const ww = (float4 *)(_w + h*N); 13 | const float4 *__restrict__ const uu = (float4 *)(_u + h*N); 14 | 15 | __shared__ float state[N*N], rr[N], kk[N]; 16 | 17 | #pragma unroll 18 | for (int j = 0; j < N; ++j) 19 | state[j*N + i] = 0; // will __syncthreads soon 20 | 21 | for (int bthi = b*T*H*N + 0*H*N + h*N + i; bthi < b*T*H*N + T*H*N + h*N + i; bthi += C) 22 | { 23 | __syncthreads(); 24 | rr[i] = _r[bthi]; // rr[0:N] = _r[b,t,h,0:N] 25 | kk[i] = _k[bthi]; // kk[0:N] = _k[b,t,h,0:N] 26 | __syncthreads(); 27 | 28 | const float v = _v[bthi]; 29 | float y = 0; 30 | 31 | const float4 *__restrict__ const rrr = (float4 *)(rr); 32 | const float4 *__restrict__ const kkk = (float4 *)(kk); 33 | float4 x, s; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < N/4; ++j) 37 | { 38 | const float4 r = rrr[j]; 39 | const float4 k = kkk[j]; 40 | const float4 w = ww[j]; 41 | const float4 u = uu[j]; 42 | 43 | x.x = k.x * v; 44 | x.y = k.y * v; 45 | x.z = k.z * v; 46 | x.w = k.w * v; 47 | 48 | const int jj = (j<<2)*N + i; 49 | s.x = state[jj + 0*N]; 50 | s.y = state[jj + 1*N]; 51 | s.z = state[jj + 2*N]; 52 | s.w = state[jj + 3*N]; 53 | 54 | y += r.x * (u.x * x.x + s.x); 55 | y += r.y * (u.y * x.y + s.y); 56 | y += r.z * (u.z * x.z + s.z); 57 | y += r.w * (u.w * x.w + s.w); 58 | 59 | state[jj + 0*N] = s.x * w.x + x.x; 60 | state[jj + 1*N] = s.y * w.y + x.y; 61 | state[jj + 2*N] = s.z * w.z + x.z; 62 | state[jj + 3*N] = s.w * w.w + x.w; 63 | } 64 | _y[bthi] = y; 65 | } 66 | } 67 | /* Backward is not implemented in this version: kernel_backward is an empty stub and cuda_backward launches nothing. */ 68 | template 69 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 70 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F
*__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy, 71 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 72 | { 73 | 74 | } 75 | 76 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 77 | { 78 | assert(H*N == C); 79 | assert(N%4 == 0); 80 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 81 | } 82 | 83 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 84 | { 85 | 86 | } /* Overload with the 16-argument signature declared and called by wkv5/cuda/wkv5_op.cpp and wkv5_ref.cpp (extra float buffer ww between w and u), so the binding links. Like the 15-argument version above, it performs no work. */ void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) { (void)ww; cuda_backward(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); } 87 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v3.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | 10 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int _b = idx / C; 12 | const int _h = (idx / N) - ((idx / N) / H) * H; 13 | const int _i = idx - (idx / N) * N; 14 | 15 | const int _o0 = _b * T * C + _h * N; 16 | const int _o1 = _h * N; 17 | 18 | const float4 *__restrict__ const k = (float4 *)(_k + _o0); 19 | const float4 *__restrict__ const r = (float4 *)(_r + _o0); 20 | const float4 *__restrict__ const w = (float4 *)(_w + _o1); 21 | const float4 *__restrict__ const u = (float4 *)(_u + _o1); 22 | const F *__restrict__ const v = _v + _o0 + _i; 23 | F *__restrict__ const y = _y + _o0 + _i; 24 | 25 | __align__(16) float4 state[N >> 2] = { make_float4(0.0f, 0.0f, 0.0f, 0.0f) }; 26 | 27 | for (int __t = 0; __t < T; __t++) 28 | { 29 | const
int _t = __t * (C >> 2); 30 | const int tt = __t * C; 31 | const F vv = v[tt]; 32 | float yy = 0.0f; 33 | 34 | #pragma unroll 35 | for (int _j = 0; _j < N >> 2; _j++) 36 | { 37 | const int j = _t + _j; 38 | 39 | const float4 k_val = k[j]; 40 | const float4 r_val = r[j]; 41 | const float4 ww = w[_j]; 42 | const float4 uu = u[_j]; 43 | float4 x; 44 | x.x = k_val.x * vv; 45 | x.y = k_val.y * vv; 46 | x.z = k_val.z * vv; 47 | x.w = k_val.w * vv; 48 | 49 | float4 &s = state[_j]; 50 | 51 | yy += r_val.x * (uu.x * x.x + s.x) + r_val.y * (uu.y * x.y + s.y) + r_val.z * (uu.z * x.z + s.z) + r_val.w * (uu.w * x.w + s.w); 52 | 53 | s.x = s.x * ww.x + x.x; 54 | s.y = s.y * ww.y + x.y; 55 | s.z = s.z * ww.z + x.z; 56 | s.w = s.w * ww.w + x.w; 57 | } 58 | 59 | y[tt] = yy; 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy, 66 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 67 | { 68 | 69 | } 70 | 71 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 72 | { 73 | assert(H*N == C); 74 | const int SIZE = B*C; 75 | dim3 threadsPerBlock(min(SIZE, 32)); 76 | assert(SIZE % threadsPerBlock.x == 0); 77 | dim3 numBlocks(SIZE / threadsPerBlock.x); 78 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 79 | } 80 | 81 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 82 | { 83 | assert(H*N == C); 84 | const int SIZE = B*C; 85 | dim3 threadsPerBlock(min(SIZE, 32)); 86 | assert(SIZE % threadsPerBlock.x == 0); 87 | dim3 numBlocks(SIZE / threadsPerBlock.x); 88 | 
kernel_backward<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); 89 | } 90 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v1b.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 10 | const int _b = idx / C; 11 | const int _h = (idx / N) % H; 12 | const int i = idx % N; 13 | 14 | const int _o0 = _b*T*C + _h*N; 15 | const int _o1 = _h*N; 16 | const float4 *__restrict__ const r = (float4 *)(_r + _o0); 17 | const float4 *__restrict__ const k = (float4 *)(_k + _o0); 18 | const float4 *__restrict__ const w = (float4 *)(_w + _o1); 19 | const float4 *__restrict__ const u = (float4 *)(_u + _o1); 20 | 21 | const F *__restrict__ const v = _v + _o0 + i; 22 | F *__restrict__ const y = _y + _o0 + i; 23 | 24 | __align__(16) float4 state[N/4] = { make_float4(0.0f, 0.0f, 0.0f, 0.0f) }; 25 | 26 | for (int _t = 0; _t < T; _t++) 27 | { 28 | const int tt = _t*C; 29 | const int ttt = tt >> 2; 30 | const F vv = v[tt]; 31 | F yy = 0; 32 | 33 | #pragma unroll 34 | for (int j = 0; j < N/4; j++) 35 | { 36 | const float4 rr = r[ttt + j]; 37 | const float4 kk = k[ttt + j]; 38 | const float4 ww = w[j]; 39 | const float4 uu = u[j]; 40 | 41 | float4 x; 42 | x.x = kk.x * vv; 43 | x.y = kk.y * vv; 44 | x.z = kk.z * vv; 45 | x.w = kk.w * vv; 46 | 47 | float4 s = state[j]; 48 | yy += rr.x * (uu.x * x.x + s.x) + rr.y * (uu.y * x.y + s.y) + rr.z * (uu.z * x.z + s.z) + rr.w * (uu.w * x.w + s.w); 49 | 50 | float4* ss = state + j; 51 | ss->x = s.x * ww.x + x.x; 52 | ss->y = s.y * ww.y + x.y; 53 | ss->z = s.z * ww.z + x.z; 54 | 
ss->w = s.w * ww.w + x.w; 55 | } 56 | y[tt] = yy; 57 | } 58 | } 59 | 60 | template 61 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 62 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy, 63 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 64 | { 65 | 66 | } 67 | 68 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 69 | { 70 | assert(H*N == C); 71 | const int SIZE = B*C; 72 | dim3 threadsPerBlock(min(SIZE, 32)); 73 | assert(SIZE % threadsPerBlock.x == 0); 74 | dim3 numBlocks(SIZE / threadsPerBlock.x); 75 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 76 | } 77 | 78 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 79 | { 80 | assert(H*N == C); 81 | const int SIZE = B*C; 82 | dim3 threadsPerBlock(min(SIZE, 32)); 83 | assert(SIZE % threadsPerBlock.x == 0); 84 | dim3 numBlocks(SIZE / threadsPerBlock.x); 85 | kernel_backward<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); 86 | } 87 | -------------------------------------------------------------------------------- /depthwise_conv1d/cuda/timex_cuda_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // require T % 4 == 0 and T <= Tmax (passed by compiler) 4 | 5 | #define F4(A, B) ((float4 *)(A))[(B) >> 2] 6 | 7 | template 8 | __global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x, 9 | const F eps, const int B, const int C, const int T) { 10 | const int i = blockIdx.y; 11 | const int t = threadIdx.x << 2; 12 | 13 | __shared__ F ww[Tmax]; 14 | __shared__ 
F k[Tmax]; 15 | F4(ww, t) = F4(__w, t + T * (i % C)); 16 | F4(k, t) = F4(__k, t + T * i); 17 | __syncthreads(); 18 | 19 | float4 s = {eps, eps, eps, eps}; 20 | 21 | const F *__restrict__ const w = ww + T - t - 4; 22 | for (int u = 0; u <= t; u++) { 23 | F x = k[u]; 24 | s.x += w[u + 3] * x; 25 | s.y += w[u + 2] * x; 26 | s.z += w[u + 1] * x; 27 | s.w += w[u + 0] * x; 28 | } 29 | s.y += w[t + 3] * k[t + 1]; 30 | s.z += w[t + 2] * k[t + 1]; 31 | s.z += w[t + 3] * k[t + 2]; 32 | s.w += w[t + 1] * k[t + 1]; 33 | s.w += w[t + 2] * k[t + 2]; 34 | s.w += w[t + 3] * k[t + 3]; 35 | 36 | F4(x, t + T * i) = s; 37 | } 38 | 39 | template 40 | __global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk, 41 | F *__restrict__ const gw, F *__restrict__ const gk, 42 | const int B, const int C, const int T) { 43 | const int i = blockIdx.y; 44 | const int t = threadIdx.x << 2; 45 | 46 | __shared__ F w[Tmax]; 47 | __shared__ F k[Tmax]; 48 | __shared__ F gg[Tmax]; 49 | F4(w, t) = F4(__w, t + T * (i % C)); 50 | F4(k, t) = F4(__k, t + T * i); 51 | F4(gg, t) = F4(__gwk, t + T * i); 52 | __syncthreads(); 53 | 54 | float4 s = {0, 0, 0, 0}; 55 | const F *__restrict__ const ga = gg + T - t - 4; 56 | 57 | for (int u = 0; u <= t; u++) { 58 | F x = k[u]; 59 | s.x += ga[u + 3] * x; 60 | s.y += ga[u + 2] * x; 61 | s.z += ga[u + 1] * x; 62 | s.w += ga[u] * x; 63 | } 64 | s.y += ga[t + 3] * k[t + 1]; 65 | s.z += ga[t + 2] * k[t + 1]; 66 | s.z += ga[t + 3] * k[t + 2]; 67 | s.w += ga[t + 1] * k[t + 1]; 68 | s.w += ga[t + 2] * k[t + 2]; 69 | s.w += ga[t + 3] * k[t + 3]; 70 | 71 | F4(gw, t + T * i) = s; 72 | 73 | s.x = 0; 74 | s.y = 0; 75 | s.z = 0; 76 | s.w = 0; 77 | const F *__restrict__ const gb = gg + T + t - 3; 78 | 79 | for (int u = t + 3; u < T; u++) { 80 | F x = w[u]; 81 | s.x += gb[2 - u] * x; 82 | s.y += gb[3 - u] * x; 83 | s.z += gb[4 - u] * x; 84 | s.w += gb[5 - u] * x; 85 | } 86 | s.x += gb[2 - t] * w[t + 0]; 87 | s.x += 
gb[1 - t] * w[t + 1]; 88 | s.x += gb[0 - t] * w[t + 2]; 89 | s.y += gb[2 - t] * w[t + 1]; 90 | s.y += gb[1 - t] * w[t + 2]; 91 | s.z += gb[2 - t] * w[t + 2]; 92 | 93 | F4(gk, t + T * i) = s; 94 | } 95 | 96 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) { 97 | dim3 gridDim(1, B * C); 98 | dim3 blockDim(T >> 2); 99 | kernel_forward<<>>(w, k, x, eps, B, C, T); 100 | } 101 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) { 102 | dim3 gridDim(1, B * C); 103 | dim3 blockDim(T >> 2); 104 | kernel_backward<<>>(w, k, gwk, gw, gk, B, C, T); 105 | } 106 | -------------------------------------------------------------------------------- /wkv/cuda/wkv_cuda_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | template 4 | __global__ void kernel_forward(const int B, const int T, const int C, 5 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 6 | F *__restrict__ const _y) 7 | { 8 | const int _b = blockIdx.x; 9 | const int _c = threadIdx.x; 10 | const int _offset = _b * T * C + _c; 11 | 12 | F u = _u[_c]; 13 | F w = _w[_c]; 14 | const F *__restrict__ const k = _k + _offset; 15 | const F *__restrict__ const v = _v + _offset; 16 | F *__restrict__ const y = _y + _offset; 17 | 18 | F p = 0, q = 0, o = -65500; 19 | // p and q are running sums divided by exp(o) (to avoid overflows) 20 | for (int i = 0; i < T; i++) 21 | { 22 | const int ii = i * C; 23 | 24 | F no = max(o, u + k[ii]); 25 | F A = exp(o - no); 26 | F B = exp(u + k[ii] - no); 27 | y[ii] = (A * p + B * v[ii]) / (A * q + B); 28 | 29 | no = max(w + o, k[ii]); 30 | A = exp(w + o - no); 31 | B = exp(k[ii] - no); 32 | p = A * p + B * v[ii]; 33 | q = A * q + B; 34 | o = no; 35 | } 36 | } 37 | 38 | template 39 | __global__ void kernel_backward(const int B, const int T, const int C, 40 | 
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, 41 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) 42 | { 43 | const int _b = blockIdx.x; 44 | const int _c = threadIdx.x; 45 | const int _offset = _b * T * C + _c; 46 | 47 | F u = _u[_c]; 48 | F w = _w[_c]; 49 | const F *__restrict__ const k = _k + _offset; 50 | const F *__restrict__ const v = _v + _offset; 51 | const F *__restrict__ const gy = _gy + _offset; 52 | 53 | F *__restrict__ const gk = _gk + _offset; 54 | F *__restrict__ const gv = _gv + _offset; 55 | 56 | F y[1024], z[1024], zexp[1024]; 57 | 58 | F gw = 0, gu = 0; 59 | F p = 0, q = 0; 60 | F dpdw = 0, dqdw = 0; 61 | F o = -65500; 62 | for (int i = 0; i < T; i++) 63 | { 64 | const int ii = i * C; 65 | F no = max(o, k[ii] + u); 66 | F A = exp(o - no); 67 | F B = exp(k[ii] + u - no); 68 | 69 | F num = A * p + B * v[ii]; 70 | F iden = 1 / (A * q + B); 71 | 72 | y[i] = num * iden; 73 | z[i] = iden; 74 | zexp[i] = k[ii] + u - no; 75 | 76 | gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; 77 | gu += gy[ii] * (v[ii] - y[i]) * B * iden; 78 | 79 | no = max(w + o, k[ii]); 80 | A = exp(w + o - no); 81 | B = exp(k[ii] - no); 82 | dpdw = A * (p + dpdw); 83 | dqdw = A * (q + dqdw); 84 | p = A * p + B * v[ii]; 85 | q = A * q + B; 86 | o = no; 87 | } 88 | 89 | F gp = 0, gq = 0; 90 | o = -65500; 91 | for (int i = T - 1; i >= 0; i--) 92 | { 93 | const int ii = i * C; 94 | F A = gy[ii] * z[i] * exp(zexp[i]); 95 | F B = exp(k[ii] + o); 96 | gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); 97 | gv[ii] = A + B * gp; 98 | 99 | F no = max(w + o, zexp[i] - k[ii] - u); 100 | A = exp(w + o - no); 101 | B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); 102 | gp = A * gp + B; 103 | gq = A * gq - B * y[i]; 104 | o = no; 105 | } 106 | 107 | // Multiply by w because the w -> -exp(w) preprocessing is halfway in 
the backwards pass, even though it's not in the forward pass 108 | const int _offsetBC = _b * C + _c; 109 | _gw[_offsetBC] += gw * _w[_c]; 110 | _gu[_offsetBC] += gu; 111 | } 112 | 113 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) 114 | { 115 | dim3 numBlocks(B); 116 | dim3 threadsPerBlock(C); 117 | kernel_forward<<>>(B, T, C, w, u, k, v, y); 118 | } 119 | 120 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) 121 | { 122 | dim3 numBlocks(B); 123 | dim3 threadsPerBlock(C); 124 | kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); 125 | } 126 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_ref.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 10 | const int _b = idx / C; 11 | const int _h = (idx / N) % H; 12 | const int _i = idx % N; 13 | 14 | const int _o0 = _b*T*C + _h*N; 15 | const int _o1 = _h*N; 16 | const F *__restrict__ const k = _k + _o0; 17 | const F *__restrict__ const v = _v + _o0 + _i; 18 | const F *__restrict__ const r = _r + _o0; 19 | F *__restrict__ const y = _y + _o0 + _i; 20 | 21 | float state[N] = {0}; 22 | 23 | for (int __t = 0; __t < T; __t++) 24 | { 25 | const int _t = __t*C; 26 | const F vv = v[_t]; 27 | 28 | for (int _j = 0; _j < N; _j++) 29 | { 30 | const int j = _t + _j; 31 | const int m = _o1 + _j; 32 | 33 | const float x = k[j] * vv; 34 | const float s = state[_j]; 35 | 36 | atomicAdd(y + _t, r[j] * (_u[m] * x + s)); 37 | 
state[_j] = s * _w[m] + x; 38 | } 39 | } 40 | } 41 | 42 | template 43 | __global__ void kernel_backward (const int B, const int T, const int C, const int H, 44 | const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const F *__restrict__ const w, const F *__restrict__ const wwww, const F *__restrict__ const _u, const F *__restrict__ const gy, 45 | F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ const gv, F *__restrict__ const gw, F *__restrict__ const gu) 46 | { 47 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 48 | const int b = idx / C; 49 | const int h = (idx / N) % H; 50 | const int i = idx % N; 51 | 52 | for (int t = 0; t < T; t++) { 53 | for (int j = 0; j < N; j++) { 54 | for (int tt = 0; tt <= t; tt++) { 55 | F ww = (tt == t) ? _u[h*N + i] : pow(w[h*N + i], t-tt-1); 56 | 57 | gr[b*T*H*N + t*H*N + h*N + i] += ww * k[b*T*H*N + tt*H*N + h*N + i] * 58 | v[b*T*H*N + tt*H*N + h*N + j] * gy[b*T*H*N + t*H*N + h*N + j]; 59 | } 60 | 61 | for (int tt = t; tt < T; tt++) { 62 | F ww = (tt == t) ? _u[h*N + i] : pow(w[h*N + i], tt-t-1); 63 | 64 | gk[b*T*H*N + t*H*N + h*N + i] += r[b*T*H*N + tt*H*N + h*N + i] * ww * 65 | v[b*T*H*N + t*H*N + h*N + j] * gy[b*T*H*N + tt*H*N + h*N + j]; 66 | 67 | ww = (tt == t) ? 
_u[h*N + j] : pow(w[h*N + j], tt-t-1); 68 | 69 | gv[b*T*H*N + t*H*N + h*N + i] += r[b*T*H*N + tt*H*N + h*N + j] * ww * 70 | k[b*T*H*N + t*H*N + h*N + j] * gy[b*T*H*N + tt*H*N + h*N + i]; 71 | } 72 | 73 | atomicAdd(gu + h*N + i, r[b*T*H*N + t*H*N + h*N + i] * k[b*T*H*N + t*H*N + h*N + i] * 74 | v[b*T*H*N + t*H*N + h*N + j] * gy[b*T*H*N + t*H*N + h*N + j]); 75 | 76 | for (int tt = 0; tt < t-1; tt++) { 77 | F ww = (t-tt-1) * wwww[h*N + i] * pow(w[h*N + i], t-tt-1); 78 | 79 | atomicAdd(gw + h*N + i, r[b*T*H*N + t*H*N + h*N + i] * ww * k[b*T*H*N + tt*H*N + h*N + i] * 80 | v[b*T*H*N + tt*H*N + h*N + j] * gy[b*T*H*N + t*H*N + h*N + j]); 81 | } 82 | } 83 | } 84 | } 85 | 86 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 87 | { 88 | dim3 threadsPerBlock( min(B*C, 32) ); 89 | assert(B * C % threadsPerBlock.x == 0); 90 | dim3 numBlocks(B * C / threadsPerBlock.x); 91 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 92 | } 93 | 94 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 95 | { 96 | dim3 threadsPerBlock( min(B*C, 32) ); 97 | assert(B * C % threadsPerBlock.x == 0); 98 | dim3 numBlocks(B * C / threadsPerBlock.x); 99 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 100 | } 101 | -------------------------------------------------------------------------------- /wkv/cuda/wkv_cuda_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, 6 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 7 | F *__restrict__ const _y) 8 | { 9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 10 | const int _b = idx / C; 11 | const int _c = idx % C; 
12 | const int _offset = _b * T * C + _c; 13 | 14 | F u = _u[_c]; 15 | F w = _w[_c]; 16 | const F *__restrict__ const k = _k + _offset; 17 | const F *__restrict__ const v = _v + _offset; 18 | F *__restrict__ const y = _y + _offset; 19 | 20 | F p = 0, q = 0, o = -65500; 21 | // p and q are running sums divided by exp(o) (to avoid overflows) 22 | for (int i = 0; i < T; i++) 23 | { 24 | const int ii = i * C; 25 | 26 | F no = max(o, u + k[ii]); 27 | F A = exp(o - no); 28 | F B = exp(u + k[ii] - no); 29 | y[ii] = (A * p + B * v[ii]) / (A * q + B); 30 | 31 | no = max(w + o, k[ii]); 32 | A = exp(w + o - no); 33 | B = exp(k[ii] - no); 34 | p = A * p + B * v[ii]; 35 | q = A * q + B; 36 | o = no; 37 | } 38 | } 39 | 40 | template 41 | __global__ void kernel_backward(const int B, const int T, const int C, 42 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, 43 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) 44 | { 45 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 46 | const int _b = idx / C; 47 | const int _c = idx % C; 48 | const int _offset = _b * T * C + _c; 49 | 50 | F u = _u[_c]; 51 | F w = _w[_c]; 52 | const F *__restrict__ const k = _k + _offset; 53 | const F *__restrict__ const v = _v + _offset; 54 | const F *__restrict__ const gy = _gy + _offset; 55 | 56 | F *__restrict__ const gk = _gk + _offset; 57 | F *__restrict__ const gv = _gv + _offset; 58 | 59 | F y[4096], z[4096], zexp[4096]; 60 | 61 | F gw = 0, gu = 0; 62 | F p = 0, q = 0; 63 | F dpdw = 0, dqdw = 0; 64 | F o = -65500; 65 | for (int i = 0; i < T; i++) 66 | { 67 | const int ii = i * C; 68 | F no = max(o, k[ii] + u); 69 | F A = exp(o - no); 70 | F B = exp(k[ii] + u - no); 71 | 72 | F num = A * p + B * v[ii]; 73 | F iden = 1 / (A * q + B); 74 | 75 | y[i] = num * iden; 76 | z[i] = iden; 77 | zexp[i] = k[ii] + u - no; 78 | 79 | gw 
+= gy[ii] * (dpdw - dqdw * y[i]) * iden * A; 80 | gu += gy[ii] * (v[ii] - y[i]) * B * iden; 81 | 82 | no = max(w + o, k[ii]); 83 | A = exp(w + o - no); 84 | B = exp(k[ii] - no); 85 | dpdw = A * (p + dpdw); 86 | dqdw = A * (q + dqdw); 87 | p = A * p + B * v[ii]; 88 | q = A * q + B; 89 | o = no; 90 | } 91 | 92 | F gp = 0, gq = 0; 93 | o = -65500; 94 | for (int i = T - 1; i >= 0; i--) 95 | { 96 | const int ii = i * C; 97 | F A = gy[ii] * z[i] * exp(zexp[i]); 98 | F B = exp(k[ii] + o); 99 | gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); 100 | gv[ii] = A + B * gp; 101 | 102 | F no = max(w + o, zexp[i] - k[ii] - u); 103 | A = exp(w + o - no); 104 | B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); 105 | gp = A * gp + B; 106 | gq = A * gq - B * y[i]; 107 | o = no; 108 | } 109 | 110 | // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass 111 | const int _offsetBC = _b * C + _c; 112 | _gw[_offsetBC] += gw * _w[_c]; 113 | _gu[_offsetBC] += gu; 114 | } 115 | 116 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) 117 | { 118 | dim3 threadsPerBlock( min(C, 32) ); 119 | assert(B * C % threadsPerBlock.x == 0); 120 | dim3 numBlocks(B * C / threadsPerBlock.x); 121 | kernel_forward<<>>(B, T, C, w, u, k, v, y); 122 | } 123 | 124 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) 125 | { 126 | dim3 threadsPerBlock( min(C, 32) ); 127 | assert(B * C % threadsPerBlock.x == 0); 128 | dim3 numBlocks(B * C / threadsPerBlock.x); 129 | kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); 130 | } 131 | -------------------------------------------------------------------------------- /wkv/cuda/wkv_cuda_v0.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | template 4 | __global__ void kernel_forward(const int B, const int T, 
const int C, 5 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 6 | F *__restrict__ const _y) { 7 | const int _b = blockIdx.x; 8 | const int _c = threadIdx.x; 9 | const int _offset = _b*T*C + _c; 10 | 11 | F u = _u[_c]; 12 | F w = _w[_c]; 13 | const F *__restrict__ const k = _k + _offset; 14 | const F *__restrict__ const v = _v + _offset; 15 | F *__restrict__ const y = _y + _offset; 16 | 17 | y[0] = v[0]; 18 | F a = v[0]; 19 | F b = 1; 20 | F p = k[0]; 21 | for (int i = 1; i < T; i++) 22 | { 23 | const int ii = i*C; 24 | F kk = k[ii]; 25 | F vv = v[ii]; 26 | 27 | F q = max(p, u+kk); 28 | F e1 = exp(p - q); 29 | F e2 = exp(u+kk - q); 30 | y[ii] = (e1 * a + e2 * vv) / (e1 * b + e2); 31 | 32 | q = max(p+w, kk); 33 | e1 = exp(p+w - q); 34 | e2 = exp(kk - q); 35 | a = e1 * a + e2 * vv; 36 | b = e1 * b + e2; 37 | p = q; 38 | } 39 | } 40 | 41 | template 42 | __global__ void kernel_backward(const int B, const int T, const int C, 43 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, 44 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { 45 | const int _b = blockIdx.x; 46 | const int _c = threadIdx.x; 47 | 48 | F u = _u[_c]; 49 | F w = _w[_c]; 50 | const int _offset = _b*T*C + _c; 51 | const F *__restrict__ const gy = _gy + _offset; 52 | F k[1024]; 53 | F v[1024]; 54 | for (int i = 0; i < T; i++) 55 | { 56 | const int ii = _offset + i*C; 57 | k[i] = _k[ii]; 58 | v[i] = _v[ii]; 59 | } 60 | 61 | F gw = 0; 62 | F gu = 0; 63 | F gk[1024] = {0}; 64 | F gv[1024] = {0}; 65 | 66 | F a = 0; 67 | F b = 0; 68 | F p = -65500; 69 | F qq = 0; 70 | F r = 0; 71 | F rr = 0; 72 | F s = 0; 73 | F ss = 0; 74 | F ee = 0; 75 | for (int i = 0; i < T; i++) 76 | { 77 | F kk = k[i]; 78 | F vv = v[i]; 79 | F gg = gy[i*C]; 80 | 81 | F q = max(p, u+kk); 
82 | F e1 = exp(p - q); 83 | F e2 = exp(u+kk - q); 84 | 85 | F c = e1 * a + e2 * vv; 86 | F d = e1 * b + e2; 87 | 88 | for (int j = 0; j < i; j++) 89 | { 90 | ee = exp((i-j-1)*w + k[j] - q) * gg / d; 91 | gv[j] += ee; 92 | gk[j] += ee * (v[j] - c / d); 93 | } 94 | ee = e2 * gg / d; 95 | gv[i] += ee; 96 | ee *= (vv - c / d); 97 | gk[i] += ee; 98 | gu += ee; 99 | 100 | if (i > 2) 101 | { 102 | e1 = exp(w + qq - q); 103 | e2 = exp(w + k[i-2] - q); 104 | ss = e1 * ss + e2; 105 | s = e1 * s + ss; 106 | rr = e1 * rr + e2 * v[i-2]; 107 | r = e1 * r + rr; 108 | } 109 | if (i == 2) 110 | { 111 | ss = exp(w + k[0] - q); 112 | s = ss; 113 | rr = ss * v[0]; 114 | r = rr; 115 | } 116 | gw += (r / d - c * s / (d * d)) * gg * w; 117 | qq = q; 118 | 119 | q = max(p+w, kk); 120 | e1 = exp(p+w - q); 121 | e2 = exp(kk - q); 122 | a = e1 * a + e2 * vv; 123 | b = e1 * b + e2; 124 | p = q; 125 | } 126 | 127 | const int _offsetBC = _b*C + _c; 128 | _gw[_offsetBC] += gw; 129 | _gu[_offsetBC] += gu; 130 | F *__restrict__ const __gk = _gk + _offset; 131 | F *__restrict__ const __gv = _gv + _offset; 132 | for (int i = 0; i < T; i++) 133 | { 134 | const int ii = i*C; 135 | __gk[ii] = gk[i]; 136 | __gv[ii] = gv[i]; 137 | } 138 | } 139 | 140 | // note: test B,C & 1,BC & BC,1 combinations 141 | 142 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { 143 | dim3 numBlocks(B); 144 | dim3 threadsPerBlock(C); 145 | kernel_forward<<>>(B, T, C, w, u, k, v, y); 146 | } 147 | 148 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) 149 | { 150 | dim3 numBlocks(B); 151 | dim3 threadsPerBlock(C); 152 | kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); 153 | } 154 | -------------------------------------------------------------------------------- /depthwise_conv1d/cuda/timex_cuda_v3.cu: -------------------------------------------------------------------------------- 1 | 
#include 2 | 3 | // require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB === 0 (Tmax and BF and BB are passed by compiler) 4 | 5 | #define F4(A, B) ((float4 *)(A))[(B) >> 2] 6 | 7 | template 8 | __global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x, 9 | const F eps, const int B, const int C, const int T) { 10 | const int i = blockIdx.y; 11 | const int t = threadIdx.x << 2; 12 | const int ti = t + T * i; 13 | const int tj = T * (B * C) / BF; 14 | 15 | __shared__ F ww[Tmax]; 16 | __shared__ F kk[Tmax * BF]; 17 | F4(ww, t) = F4(__w, t + T * (i % C)); 18 | 19 | #pragma unroll 20 | for (int j = 0; j < BF; j++) { 21 | F4(kk, t + Tmax * j) = F4(__k, ti + tj * j); 22 | } 23 | __syncthreads(); 24 | 25 | float4 ss[BF]; 26 | #pragma unroll 27 | for (int j = 0; j < BF; j++) { 28 | ss[j] = {eps, eps, eps, eps}; 29 | } 30 | for (int u = 0; u <= t; u++) { 31 | const F *__restrict__ const w = ww + T - t + u - 4; 32 | #pragma unroll 33 | for (int j = 0; j < BF; j++) { 34 | float4 *__restrict__ const s = ss + j; 35 | const F k = kk[u + Tmax * j]; 36 | s->x += w[3] * k; 37 | s->y += w[2] * k; 38 | s->z += w[1] * k; 39 | s->w += w[0] * k; 40 | } 41 | } 42 | #pragma unroll 43 | for (int j = 0; j < BF; j++) { 44 | float4 *__restrict__ const s = ss + j; 45 | const F *__restrict__ const w = ww + T - 3; 46 | const F *__restrict__ const k = kk + Tmax * j + t + 1; 47 | s->y += w[2] * k[0]; 48 | s->z += w[1] * k[0]; 49 | s->z += w[2] * k[1]; 50 | s->w += w[0] * k[0]; 51 | s->w += w[1] * k[1]; 52 | s->w += w[2] * k[2]; 53 | F4(x, ti + tj * j) = *s; 54 | } 55 | } 56 | 57 | template 58 | __global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk, 59 | F *__restrict__ const gw, F *__restrict__ const gk, 60 | const int B, const int C, const int T) { 61 | const int i = blockIdx.y; 62 | const int t = threadIdx.x << 2; 63 | const int ti = t + T * i; 64 | const int 
tj = T * (B * C) / BB; 65 | 66 | __shared__ F ww[Tmax]; 67 | __shared__ F kk[Tmax * BB]; 68 | __shared__ F gg[Tmax * BB]; 69 | F4(ww, t) = F4(__w, t + T * (i % C)); 70 | 71 | #pragma unroll 72 | for (int j = 0; j < BB; j++) { 73 | F4(kk, t + Tmax * j) = F4(__k, ti + tj * j); 74 | F4(gg, t + Tmax * j) = F4(__gwk, ti + tj * j); 75 | } 76 | __syncthreads(); 77 | 78 | float4 ss[BB]; 79 | #pragma unroll 80 | for (int j = 0; j < BB; j++) { 81 | ss[j] = {0, 0, 0, 0}; 82 | } 83 | for (int u = 0; u <= t; u++) { 84 | #pragma unroll 85 | for (int j = 0; j < BB; j++) { 86 | float4 *__restrict__ const s = ss + j; 87 | const F *__restrict__ const g = gg + Tmax * j + T - t + u - 4; 88 | const F k = kk[u + Tmax * j]; 89 | s->x += g[3] * k; 90 | s->y += g[2] * k; 91 | s->z += g[1] * k; 92 | s->w += g[0] * k; 93 | } 94 | } 95 | #pragma unroll 96 | for (int j = 0; j < BB; j++) { 97 | float4 *__restrict__ const s = ss + j; 98 | const F *__restrict__ const k = kk + Tmax * j + t + 1; 99 | const F *__restrict__ const g = gg + Tmax * j + T - 3; 100 | s->y += g[2] * k[0]; 101 | s->z += g[1] * k[0]; 102 | s->z += g[2] * k[1]; 103 | s->w += g[0] * k[0]; 104 | s->w += g[1] * k[1]; 105 | s->w += g[2] * k[2]; 106 | F4(gw, ti + tj * j) = *s; 107 | } 108 | 109 | #pragma unroll 110 | for (int j = 0; j < BB; j++) { 111 | ss[j] = {0, 0, 0, 0}; 112 | } 113 | for (int u = t + 3; u < T; u++) { 114 | const F w = ww[u]; 115 | #pragma unroll 116 | for (int j = 0; j < BB; j++) { 117 | float4 *__restrict__ const s = ss + j; 118 | const F *__restrict__ const g = gg + Tmax * j + T + t - u - 1; 119 | s->x += g[0] * w; 120 | s->y += g[1] * w; 121 | s->z += g[2] * w; 122 | s->w += g[3] * w; 123 | } 124 | } 125 | #pragma unroll 126 | for (int j = 0; j < BB; j++) { 127 | float4 *__restrict__ const s = ss + j; 128 | const F *__restrict__ const g = gg + Tmax * j + T - 3; 129 | const F *__restrict__ const w = ww + t; 130 | s->x += g[2] * w[0]; 131 | s->x += g[1] * w[1]; 132 | s->x += g[0] * w[2]; 133 | s->y += g[2] * 
w[1]; 134 | s->y += g[1] * w[2]; 135 | s->z += g[2] * w[2]; 136 | F4(gk, ti + tj * j) = *s; 137 | } 138 | } 139 | 140 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) { 141 | dim3 gridDim(1, B * C / BF); 142 | dim3 blockDim(T >> 2); 143 | kernel_forward<<>>(w, k, x, eps, B, C, T); 144 | } 145 | 146 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) { 147 | dim3 gridDim(1, B * C / BB); 148 | dim3 blockDim(T >> 2); 149 | kernel_backward<<>>(w, k, gwk, gw, gk, B, C, T); 150 | } 151 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_cuda_v1a.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_]; 18 | float state[_N_] = {0}; 19 | 20 | __syncthreads(); 21 | u[i] = float(_u[i]); 22 | __syncthreads(); 23 | 24 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 25 | { 26 | __syncthreads(); 27 | r[i] = float(_r[t]); 28 | k[i] = float(_k[t]); 29 | __syncthreads(); 30 | 31 | const float v = float(_v[t]); 32 | float y = 0; 33 | 34 | #pragma unroll 35 | for (int j = 0; j < _N_; j++) 36 | { 37 | float x = k[j] * v; 38 | float& s = state[j]; 39 | 40 | y += r[j] * (u[j] * x + s); 41 | s = s * _w[j] + x; 42 | } 43 | _y[t] = F(y); 44 | } 45 | } 46 | 47 | template 48 | __global__ void kernel_backward(const int B, const 
int T, const int C, const int H, 49 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 50 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 51 | { 52 | const int b = blockIdx.x / H; 53 | const int h = blockIdx.x % H; 54 | const int i = threadIdx.x; 55 | _w += h*_N_; 56 | _u += h*_N_; 57 | __w += h*_N_; 58 | const float w = _w[i]; 59 | const float u = float(_u[i]); 60 | const float ww = __w[i]; 61 | 62 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_], w_[_N_], u_[_N_]; 63 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}; 64 | 65 | float gw = 0, gu = 0; 66 | const int t000 = b*T*C + h*_N_ + i; 67 | const int t111 = (b+1)*T*C + h*_N_ + i; 68 | const int t222 = t111 - 2*C; 69 | 70 | for (int _t = t000; _t < t111; _t += C) 71 | { 72 | __syncthreads(); 73 | v[i] = float(_v[_t]); 74 | gy[i] = float(_gy[_t]); 75 | __syncthreads(); 76 | 77 | const float k = float(_k[_t]); 78 | const float r = float(_r[_t]); 79 | 80 | float gr = 0; 81 | 82 | #pragma unroll 83 | for (int j = 0; j < _N_; j++) 84 | { 85 | float x = v[j] * k; 86 | float& s = state[j]; 87 | 88 | gr += gy[j] * (u * x + s); 89 | gu += r * x * gy[j]; 90 | s = s * w + x; 91 | } 92 | _gr[_t] = F(gr); 93 | } 94 | _gu[b*C + h*_N_ + i] = F(gu); 95 | 96 | for (int _t = t000; _t < t222; _t += C) 97 | { 98 | __syncthreads(); 99 | v[i] = float(_v[_t]); 100 | gy[i] = float(_gy[_t + 2*C]); 101 | __syncthreads(); 102 | 103 | const float k = float(_k[_t]); 104 | const float r = float(_r[_t + 2*C]); 105 | 106 | #pragma unroll 107 | for (int j = 0; j < _N_; j++) 108 | { 109 | float x = v[j] * k; 110 | saaaa[j] = w * (saaaa[j] + sbbbb[j] + x); 111 | sbbbb[j] = w * (sbbbb[j] + x); 112 | 113 | gw += r * ww * saaaa[j] * gy[j]; 114 | } 115 | } 116 | _gw[b*C + h*_N_ + i] = 
F(gw); 117 | 118 | #pragma unroll 119 | for (int j = 0; j < _N_; ++j) { 120 | saaaa[j] = 0; 121 | sbbbb[j] = 0; 122 | } 123 | 124 | __syncthreads(); 125 | w_[i] = _w[i]; 126 | u_[i] = float(_u[i]); 127 | __syncthreads(); 128 | 129 | for (int _t = t111 - C; _t >= t000; _t -= C) 130 | { 131 | __syncthreads(); 132 | r[i] = float(_r[_t]); 133 | k[i] = float(_k[_t]); 134 | v[i] = float(_v[_t]); 135 | gy[i] = float(_gy[_t]); 136 | __syncthreads(); 137 | 138 | float gk = 0, gv = 0; 139 | 140 | #pragma unroll 141 | for (int j = 0; j < _N_; j++) 142 | { 143 | float x = gy[j] * r[i]; 144 | float& s = saaaa[j]; 145 | gk += v[j] * (u * x + s); 146 | s = s * w + x; 147 | 148 | float x2 = gy[i] * r[j]; 149 | float& s2 = sbbbb[j]; 150 | gv += k[j] * (u_[j] * x2 + s2); 151 | s2 = s2 * w_[j] + x2; 152 | } 153 | _gk[_t] = F(gk); 154 | _gv[_t] = F(gv); 155 | } 156 | } 157 | 158 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 159 | { 160 | assert(H*_N_ == C); 161 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 162 | } 163 | 164 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 165 | { 166 | assert(H*_N_ == C); 167 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 168 | } 169 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_cuda_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | 
const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_]; 18 | 19 | float state[_N_] = {0}; 20 | 21 | for (int _t = b*T*C + h*_N_ + i; _t < (b+1)*T*C + h*_N_ + i; _t += C) 22 | { 23 | __syncthreads(); 24 | r[i] = float(_r[_t]); 25 | k[i] = float(_k[_t]); 26 | __syncthreads(); 27 | 28 | const float v = float(_v[_t]); 29 | float y = 0; 30 | 31 | for (int j = 0; j < _N_; j++) 32 | { 33 | float x = k[j] * v; 34 | 35 | float s = state[j]; 36 | state[j] = s * _w[j] + x; 37 | 38 | y += r[j] * (float(_u[j]) * x + s); 39 | } 40 | _y[_t] = F(y); 41 | } 42 | } 43 | 44 | template 45 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 46 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 47 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 48 | { 49 | const int b = blockIdx.x / H; 50 | const int h = blockIdx.x % H; 51 | const int i = threadIdx.x; 52 | _w += h*_N_; 53 | _u += h*_N_; 54 | __w += h*_N_; 55 | 56 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_]; 57 | 58 | const float w = _w[i]; 59 | const float u = float(_u[i]); 60 | const float ww = __w[i]; 61 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}; 62 | float gw = 0, gu = 0; 63 | 64 | for (int _t = b*T*C + h*_N_ + i, _tend = (b+1)*T*C + h*_N_ + i; _t < _tend; _t += C) 65 | { 66 | __syncthreads(); 67 | v[i] = float(_v[_t]); 68 | gy[i] = float(_gy[_t]); 69 | __syncthreads(); 70 | 71 | const float k = float(_k[_t]); 72 | const float r = float(_r[_t]); 73 | float gr = 0; 74 | 75 | #pragma unroll 76 | for (int j = 0; j < _N_; j++) 77 | { 78 | float x = v[j] * k; 79 | float s = state[j]; 80 | state[j] = s * w + x; 81 | 82 | gr += gy[j] * (u * x + 
s); 83 | gu += r * x * gy[j]; 84 | } 85 | 86 | _gr[_t] = F(gr); 87 | 88 | if (_t < _tend - 2*C) 89 | { 90 | __syncthreads(); 91 | gy[i] = float(_gy[_t + 2*C]); 92 | __syncthreads(); 93 | 94 | const float r = float(_r[_t + 2*C]); 95 | 96 | #pragma unroll 97 | for (int j = 0; j < _N_; j++) 98 | { 99 | float x = v[j] * k; 100 | saaaa[j] = w * (saaaa[j] + sbbbb[j] + x); 101 | sbbbb[j] = w * (sbbbb[j] + x); 102 | 103 | gw += r * ww * saaaa[j] * gy[j]; 104 | } 105 | } 106 | } 107 | _gu[b*C + h*_N_ + i] = F(gu); 108 | _gw[b*C + h*_N_ + i] = F(gw); 109 | 110 | #pragma unroll 111 | for (int j = 0; j < _N_; ++j) 112 | state[j] = 0; 113 | 114 | for (int _t = (b+1)*T*C + h*_N_ + i - C, _tend = b*T*C + h*_N_ + i; _t >= _tend; _t -= C) 115 | { 116 | __syncthreads(); 117 | v[i] = float(_v[_t]); 118 | gy[i] = float(_gy[_t]); 119 | __syncthreads(); 120 | 121 | const float r = float(_r[_t]); 122 | float gk = 0; 123 | 124 | #pragma unroll 125 | for (int j = 0; j < _N_; j++) 126 | { 127 | float x = gy[j] * r; 128 | float s = state[j]; 129 | state[j] = s * w + x; 130 | 131 | gk += v[j] * (u * x + s); 132 | } 133 | _gk[_t] = F(gk); 134 | } 135 | 136 | #pragma unroll 137 | for (int j = 0; j < _N_; ++j) 138 | state[j] = 0; 139 | 140 | for (int _t = (b+1)*T*C + h*_N_ + i - C, _tend = b*T*C + h*_N_ + i; _t >= _tend; _t -= C) 141 | { 142 | __syncthreads(); 143 | k[i] = float(_k[_t]); 144 | r[i] = float(_r[_t]); 145 | __syncthreads(); 146 | 147 | const float gy = float(_gy[_t]); 148 | float gv = 0; 149 | 150 | #pragma unroll 151 | for (int j = 0; j < _N_; j++) 152 | { 153 | float x = gy * r[j]; 154 | float s = state[j]; 155 | state[j] = s * float(_w[j]) + x; 156 | 157 | gv += k[j] * (float(_u[j]) * x + s); 158 | } 159 | _gv[_t] = F(gv); 160 | } 161 | } 162 | 163 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 164 | { 165 | assert(H*_N_ == C); 166 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 167 | } 168 | 169 | void 
cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 170 | { 171 | assert(H*_N_ == C); 172 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 173 | } 174 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 10 | const int _b = idx / C; 11 | const int _h = (idx / N) % H; 12 | const int _i = idx % N; 13 | 14 | const int _o0 = _b*T*C + _h*N; 15 | const int _o1 = _h*N; 16 | const F *__restrict__ const k = _k + _o0; 17 | const F *__restrict__ const v = _v + _o0 + _i; 18 | const F *__restrict__ const r = _r + _o0; 19 | F *__restrict__ const y = _y + _o0 + _i; 20 | 21 | float state[N] = {0}; 22 | 23 | for (int __t = 0; __t < T; __t++) 24 | { 25 | const int _t = __t*C; 26 | const F vv = v[_t]; 27 | 28 | for (int _j = 0; _j < N; _j++) 29 | { 30 | const int j = _t + _j; 31 | const int m = _o1 + _j; 32 | 33 | const float x = k[j] * vv; 34 | const float s = state[_j]; 35 | 36 | atomicAdd(y + _t, r[j] * (_u[m] * x + s)); 37 | state[_j] = s * _w[m] + x; 38 | } 39 | } 40 | } 41 | 42 | template 43 | __global__ void kernel_backward (const int B, const int T, const int C, const int H, 44 | const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const F *__restrict__ w, const F *__restrict__ wwww, const F *__restrict__ u, const F *__restrict__ const gy, 45 | F *__restrict__ const gr, F *__restrict__ 
const gk, F *__restrict__ const gv, F *__restrict__ gw, F *__restrict__ gu) 46 | { 47 | const int b = blockIdx.x / H; 48 | const int h = blockIdx.x % H; 49 | const int i = threadIdx.x; 50 | w += h*N; 51 | u += h*N; 52 | gu += h*N; 53 | gw += h*N; 54 | wwww += h*N; 55 | 56 | __shared__ float state[N * N], vv[N], rr[N], kk[N], gyy[N]; 57 | 58 | #pragma unroll 59 | for (int j = 0; j < N; ++j){ 60 | state[j * N + i] = 0; 61 | } 62 | 63 | const float ww = w[i]; 64 | const float uu = u[i]; 65 | const float wwwww = wwww[i]; 66 | float saaaa[N] = {0.0f}, sbbbb[N] = {0.0f}; 67 | 68 | for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C) 69 | { 70 | const F kk = k[_t]; 71 | const F rr = r[_t]; 72 | F grr = 0; 73 | F guu = 0; 74 | 75 | vv[i] = v[_t]; 76 | gyy[i] = gy[_t]; 77 | 78 | __syncthreads(); 79 | 80 | #pragma unroll 81 | for (int j = 0; j < N; j++) 82 | { 83 | 84 | float x = vv[j] * kk; 85 | float s = state[j * N + i]; 86 | 87 | grr += gyy[j] * (uu * x + s); 88 | state[j * N + i] = s * ww + x; 89 | guu += rr * x * gyy[j]; 90 | 91 | } 92 | gr[_t] = grr; 93 | atomicAdd(gu + i, guu); 94 | 95 | __syncthreads(); 96 | if (_t < _tend - 2 * C){ 97 | const F rr_value = r[_t+2*C]; 98 | gyy[i] = gy[_t+2*C]; 99 | __syncthreads(); 100 | 101 | #pragma unroll 102 | for (int j = 0; j < N; j++){ 103 | float x = vv[j] * kk; 104 | saaaa[j] = ww * (saaaa[j] + sbbbb[j] + x); 105 | sbbbb[j] = ww * (sbbbb[j] + x); 106 | atomicAdd(gw+i, rr_value * wwwww * saaaa[j] * gyy[j]); 107 | } 108 | 109 | __syncthreads(); 110 | } 111 | } 112 | 113 | #pragma unroll 114 | for (int j = 0; j < N; ++j) 115 | state[j * N + i] = 0; 116 | 117 | for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C) 118 | { 119 | const F rr = r[_t]; 120 | F gkk = 0; 121 | 122 | vv[i] = v[_t]; 123 | gyy[i] = gy[_t]; 124 | 125 | __syncthreads(); 126 | 127 | #pragma unroll 128 | for (int j = 0; j < N; j++) 129 | { 130 | 131 | float x = gyy[j] * rr; 132 | float s = state[j * 
N + i]; 133 | 134 | gkk += vv[j] * (uu * x + s); 135 | state[j * N + i] = s * ww + x; 136 | } 137 | gk[_t] = gkk; 138 | __syncthreads(); 139 | } 140 | 141 | #pragma unroll 142 | for (int j = 0; j < N; ++j) 143 | state[j * N + i] = 0; 144 | 145 | for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C) 146 | { 147 | const F gy_value = gy[_t]; 148 | F gvv = 0; 149 | 150 | kk[i] = k[_t]; 151 | rr[i] = r[_t]; 152 | 153 | __syncthreads(); 154 | 155 | #pragma unroll 156 | for (int j = 0; j < N; j++) 157 | { 158 | 159 | float x = gy_value * rr[j]; 160 | float s = state[j * N + i]; 161 | 162 | gvv += kk[j] * (u[j] * x + s); 163 | state[j * N + i] = s * w[j] + x; 164 | } 165 | gv[_t] = gvv; 166 | __syncthreads(); 167 | } 168 | } 169 | 170 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 171 | { 172 | dim3 threadsPerBlock( min(B*C, 32) ); 173 | assert(B * C % threadsPerBlock.x == 0); 174 | dim3 numBlocks(B * C / threadsPerBlock.x); 175 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 176 | } 177 | 178 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 179 | { 180 | assert(H*N == C); 181 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 182 | } 183 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 10 | const int _b 
= idx / C; 11 | const int _h = (idx / N) % H; 12 | const int _i = idx % N; 13 | 14 | const int _o0 = _b * T * C + _h * N; 15 | const int _o1 = _h * N; 16 | 17 | const float4 *__restrict__ const k = (float4 *)(_k + _o0); 18 | const float4 *__restrict__ const r = (float4 *)(_r + _o0); 19 | const float4 *__restrict__ const w = (float4 *)(_w + _o1); 20 | const float4 *__restrict__ const u = (float4 *)(_u + _o1); 21 | const F *__restrict__ const v = _v + _o0 + _i; 22 | F *__restrict__ const y = _y + _o0 + _i; 23 | 24 | __align__(16) float4 state[N / 4] = { make_float4(0.0f, 0.0f, 0.0f, 0.0f) }; 25 | 26 | for (int __t = 0; __t < T; __t++) 27 | { 28 | const int _t = __t * (C >> 2); 29 | const int tt = __t * C; 30 | const F vv = v[tt]; 31 | 32 | for (int _j = 0; _j < N / 4; _j++) 33 | { 34 | const int j = _t + _j; 35 | 36 | const float4 k_val = k[j]; 37 | const float4 r_val = r[j]; 38 | float4 x; 39 | x.x = k_val.x * vv; 40 | x.y = k_val.y * vv; 41 | x.z = k_val.z * vv; 42 | x.w = k_val.w * vv; 43 | 44 | float4 s = state[_j]; 45 | 46 | float4 result; 47 | result.x = r_val.x * (u[_j].x * x.x + s.x); 48 | result.y = r_val.y * (u[_j].y * x.y + s.y); 49 | result.z = r_val.z * (u[_j].z * x.z + s.z); 50 | result.w = r_val.w * (u[_j].w * x.w + s.w); 51 | 52 | atomicAdd(&(y[tt]), result.x); 53 | atomicAdd(&(y[tt]), result.y); 54 | atomicAdd(&(y[tt]), result.z); 55 | atomicAdd(&(y[tt]), result.w); 56 | 57 | state[_j].x = s.x * w[_j].x + x.x; 58 | state[_j].y = s.y * w[_j].y + x.y; 59 | state[_j].z = s.z * w[_j].z + x.z; 60 | state[_j].w = s.w * w[_j].w + x.w; 61 | } 62 | } 63 | } 64 | 65 | template 66 | __global__ void kernel_backward (const int B, const int T, const int C, const int H, 67 | const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const F *__restrict__ const w, const F *__restrict__ const wwww, const F *__restrict__ const _u, const F *__restrict__ const gy, 68 | F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ 
const gv, F *__restrict__ const gw, F *__restrict__ const gu) 69 | { 70 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; // B * H * T * N 71 | const int b = idx / H / T / N; 72 | const int h = (idx / T / N) % H; 73 | const int t = (idx / N) % T; 74 | const int n = idx % N; 75 | const int index1 = b*T*H*N + t*H*N + h*N + n; 76 | const F& w_h_n = w[h*N+n]; 77 | const F& u_h_n = _u[h*N+n]; 78 | F& gu_h_n = gu[h*N+n]; 79 | F& gw_h_n = gw[h*N+n]; 80 | F w_pow[4096]; 81 | for (int t =0; t < T; t++){ 82 | w_pow[t] = pow(w_h_n, t); 83 | } 84 | F &gr_index1 = gr[index1]; 85 | F &gk_index1 = gk[index1]; 86 | F &gv_index1 = gv[index1]; 87 | const F& r_index1 = r[index1]; 88 | const F& k_index1 = k[index1]; 89 | 90 | for(int nn = 0; nn < N; nn++){ 91 | const F& u_h_nn = _u[h*N + nn]; 92 | const F& w_h_nn = w[h*N + nn]; 93 | const int index2 = b*T*H*N + t*H*N + h*N + nn; 94 | for (int tt = 0; tt <= t; tt++) { 95 | const int index3 = b*T*H*N + tt*H*N + h*N + n; 96 | const int index4 = b*T*H*N + tt*H*N + h*N + nn; 97 | F ww = (tt == t) ? u_h_n : (t-tt-1 >= 0 ? w_pow[t-tt-1] : pow(w_h_n, t-tt-1)); 98 | gr_index1 += ww * gy[index2] * k[index3] * v[index4]; 99 | } 100 | 101 | for (int tt = t; tt < T; tt++) { 102 | const int index3 = b*T*H*N + tt*H*N + h*N + n; 103 | const int index4 = b*T*H*N + tt*H*N + h*N + nn; 104 | F ww = (tt == t) ? u_h_n : (tt-t-1>=0 ? w_pow[tt-t-1] : pow(w_h_n, tt-t-1)); 105 | gk_index1 += ww * v[index2] * r[index3] * gy[index4]; 106 | ww = (tt == t) ? u_h_nn : pow(w_h_nn, tt-t-1); 107 | gv_index1 += ww * k[index2] * gy[index3] * r[index4]; 108 | } 109 | 110 | atomicAdd(&gu_h_n, r_index1 * k_index1 * v[index2] * gy[index2]); 111 | 112 | for (int tt = 0; tt < t-1; tt++) { 113 | const int index3 = b*T*H*N + tt*H*N + h*N + n; 114 | const int index4 = b*T*H*N + tt*H*N + h*N + nn; 115 | F ww = (t-tt-1) * wwww[h*N + n] * (t-tt-1 >= 0 ? 
w_pow[t-tt-1] : pow(w_h_n, t-tt-1)); 116 | 117 | atomicAdd(&gw_h_n, r_index1 * ww * k[index3] * v[index4] * gy[index2]); 118 | } 119 | } 120 | } 121 | 122 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 123 | { 124 | assert(H*N == C); 125 | const int SIZE = B*C; 126 | dim3 threadsPerBlock(min(SIZE, 32)); 127 | assert(SIZE % threadsPerBlock.x == 0); 128 | dim3 numBlocks(SIZE / threadsPerBlock.x); 129 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 130 | } 131 | 132 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 133 | { 134 | dim3 threadsPerBlock( min(B*H*T*N, 32) ); 135 | assert(B * H * T * N % threadsPerBlock.x == 0); 136 | dim3 numBlocks(B * H * T * N / threadsPerBlock.x); 137 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 138 | } 139 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/cuda/rwkv7_clampw.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef _FP32_ 4 | using bf = float; 5 | #define to_float(u) (u) 6 | #define to_bf(u) (u) 7 | #else 8 | #include 9 | using bf = __nv_bfloat16; 10 | #define to_float(u) (__bfloat162float(u)) 11 | #define to_bf(u) (__float2bfloat16_rn(u)) 12 | #endif 13 | 14 | using i64 = long long int; 15 | typedef bf * __restrict__ F_; 16 | constexpr float W_SCALE = -0.6065306597f; // -exp(-0.5) 17 | 18 | //###################################################################################################### 19 | 20 | template __launch_bounds__(N,2) 21 | __global__ void forward_kernel(int T,int H,F_ r_,F_ w_,F_ k_,F_ v_,F_ a_,F_ b_,bf* __restrict__ y_,float* s__,float* __restrict__ sa_) 22 | { 23 | const int bb=blockIdx.y, hh=blockIdx.x, i=threadIdx.x; 24 | float* __restrict__ s_ = s__ + i64(bb*H+hh) * 
i64((T/_CHUNK_LEN_)*N*N); 25 | float state[N]; 26 | #pragma unroll 27 | for (int j=0; j<<>>(T,H,r,w,k,v,a,b,y,s,sa); 79 | } 80 | 81 | //###################################################################################################### 82 | 83 | template 84 | __global__ void backward_kernel(int T, int H, F_ r_, F_ w_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s__, float * __restrict__ sa_, bf* dr_, bf* dw_, bf* dk_, bf* dv_, bf* da_, bf* db_) 85 | { 86 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 87 | float* __restrict__ s_ = s__ + i64(bb*H+hh) * i64((T/_CHUNK_LEN_)*N*N); 88 | 89 | float stateT[N] = {0}, dstate[N] = {0}, dstateT[N] = {0}; 90 | __shared__ float r[N], w[N], k[N], v[N], a[N], b[N], dy[N], sa[N], dSb_shared[N]; 91 | float ri, wi, ki, ai, bi, dyi; 92 | 93 | for (int t = T-1; t >= 0; t--) 94 | { 95 | int idx = bb*T*H*N + t*H*N + hh * N + i; 96 | 97 | __syncthreads(); 98 | r[i] = ri = to_float(r_[idx]); 99 | float w_sig = 1.0f / (1.0f + __expf(-to_float(w_[idx]))); 100 | w[i] = wi = __expf(W_SCALE * w_sig); 101 | k[i] = ki = to_float(k_[idx]); 102 | v[i] = to_float(v_[idx]); 103 | a[i] = ai = to_float(a_[idx]); 104 | b[i] = bi = to_float(b_[idx]); 105 | dy[i] = dyi = to_float(dy_[idx]); 106 | sa[i] = sa_[idx]; 107 | __syncthreads(); 108 | 109 | if ((t+1)%_CHUNK_LEN_ == 0) { 110 | int base = (t/_CHUNK_LEN_)*N*N + i*N; 111 | const float4* s4 = (const float4*)(s_ + base); 112 | #pragma unroll 113 | for (int j4 = 0; j4 < N/4; j4++) { 114 | float4 q = s4[j4]; 115 | const int j = j4<<2; 116 | stateT[j+0] = q.x; 117 | stateT[j+1] = q.y; 118 | stateT[j+2] = q.z; 119 | stateT[j+3] = q.w; 120 | } 121 | } 122 | 123 | float dr = 0; 124 | #pragma unroll 125 | for (int j = 0; j < N; j++) { 126 | dr += stateT[j] * dy[j]; 127 | } 128 | dr_[idx] = to_bf(dr); 129 | 130 | float iwi = 1.0f / wi; 131 | #pragma unroll 132 | for (int j = 0; j < N; j++) { 133 | stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi; 134 | dstate[j] += dyi * r[j]; 
135 | dstateT[j] += ri * dy[j]; 136 | } 137 | 138 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0; 139 | #pragma unroll 140 | for (int j = 0; j < N; j++) { 141 | dw += dstateT[j] * stateT[j]; 142 | dk += dstateT[j] * v[j]; 143 | dv += dstate[j] * k[j]; 144 | dSb += dstate[j] * b[j]; 145 | db += dstateT[j] * sa[j]; 146 | } 147 | dw_[idx] = to_bf(W_SCALE * dw * wi * w_sig * (1.0f - w_sig)); 148 | 149 | dk_[idx] = to_bf(dk); 150 | dv_[idx] = to_bf(dv); 151 | db_[idx] = to_bf(db); 152 | 153 | __syncthreads(); 154 | dSb_shared[i] = dSb; 155 | __syncthreads(); 156 | 157 | float da = 0; 158 | #pragma unroll 159 | for (int j = 0; j < N; j++) { 160 | da += stateT[j]*dSb_shared[j]; 161 | } 162 | da_[idx] = to_bf(da); 163 | 164 | #pragma unroll 165 | for (int j = 0; j < N; j++) { 166 | dstate[j] = dstate[j] * w[j] + dSb * a[j]; 167 | dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j]; 168 | } 169 | } 170 | } 171 | 172 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db) 173 | { 174 | assert(T%_CHUNK_LEN_ == 0); 175 | backward_kernel<_N_><<>>(T,H,r,w,k,v,a,b,dy,s,sa,dr,dw,dk,dv,da,db); 176 | } 177 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_cuda_v1b.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 18 | float state[_N_] = {0}; 
19 | 20 | __syncthreads(); 21 | u[i] = float(_u[i]); 22 | w[i] = float(_w[i]); 23 | __syncthreads(); 24 | 25 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 26 | { 27 | __syncthreads(); 28 | r[i] = float(_r[t]); 29 | k[i] = float(_k[t]); 30 | __syncthreads(); 31 | 32 | const float v = float(_v[t]); 33 | float y = 0; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < _N_; j+=4) 37 | { 38 | const float4& r_ = (float4&)(r[j]); 39 | const float4& k_ = (float4&)(k[j]); 40 | const float4& w_ = (float4&)(w[j]); 41 | const float4& u_ = (float4&)(u[j]); 42 | float4& s = (float4&)(state[j]); 43 | float4 x; 44 | 45 | x.x = k_.x * v; 46 | x.y = k_.y * v; 47 | x.z = k_.z * v; 48 | x.w = k_.w * v; 49 | 50 | y += r_.x * (u_.x * x.x + s.x); 51 | y += r_.y * (u_.y * x.y + s.y); 52 | y += r_.z * (u_.z * x.z + s.z); 53 | y += r_.w * (u_.w * x.w + s.w); 54 | 55 | s.x = s.x * w_.x + x.x; 56 | s.y = s.y * w_.y + x.y; 57 | s.z = s.z * w_.z + x.z; 58 | s.w = s.w * w_.w + x.w; 59 | } 60 | _y[t] = F(y); 61 | } 62 | } 63 | 64 | template 65 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 66 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 67 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 68 | { 69 | const int b = blockIdx.x / H; 70 | const int h = blockIdx.x % H; 71 | const int i = threadIdx.x; 72 | _w += h*_N_; 73 | _u += h*_N_; 74 | __w += h*_N_; 75 | const float w = _w[i]; 76 | const float u = float(_u[i]); 77 | const float ww = __w[i]; 78 | 79 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_], w_[_N_], u_[_N_]; 80 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}; 81 | 82 | float gw = 0, gu = 0; 83 | const int t000 = b*T*C + h*_N_ + i; 84 | const int t111 = 
(b+1)*T*C + h*_N_ + i; 85 | const int t222 = t111 - 2*C; 86 | 87 | for (int t = t000; t < t111; t += C) 88 | { 89 | __syncthreads(); 90 | v[i] = float(_v[t]); 91 | gy[i] = float(_gy[t]); 92 | __syncthreads(); 93 | 94 | const float k = float(_k[t]); 95 | float gr = 0, _gu_ = 0; 96 | 97 | #pragma unroll 98 | for (int j = 0; j < _N_; j++) 99 | { 100 | float& s = state[j]; 101 | float x = k * v[j]; 102 | 103 | gr += (u * x + s) * gy[j]; 104 | _gu_ += x * gy[j]; 105 | s = s * w + x; 106 | } 107 | _gr[t] = F(gr); 108 | gu += float(_r[t]) * _gu_; 109 | } 110 | _gu[b*C + h*_N_ + i] = F(gu); 111 | 112 | for (int t = t000; t < t222; t += C) 113 | { 114 | __syncthreads(); 115 | v[i] = float(_v[t]); 116 | gy[i] = float(_gy[t + 2*C]); 117 | __syncthreads(); 118 | 119 | const float k = float(_k[t]); 120 | const float r = float(_r[t + 2*C]); 121 | 122 | #pragma unroll 123 | for (int j = 0; j < _N_; j++) 124 | { 125 | float& s = saaaa[j]; 126 | float& s2 = sbbbb[j]; 127 | float x = k * v[j]; 128 | 129 | float tmp = w * (x + s); 130 | s = tmp; 131 | s2 = tmp + w * s2; 132 | gw += r * s2 * gy[j]; 133 | } 134 | } 135 | _gw[b*C + h*_N_ + i] = F(ww * gw); 136 | 137 | #pragma unroll 138 | for (int j = 0; j < _N_; ++j) { 139 | saaaa[j] = 0; 140 | sbbbb[j] = 0; 141 | } 142 | 143 | __syncthreads(); 144 | w_[i] = _w[i]; 145 | u_[i] = float(_u[i]); 146 | __syncthreads(); 147 | 148 | for (int t = t111 - C; t >= t000; t -= C) 149 | { 150 | __syncthreads(); 151 | r[i] = float(_r[t]); 152 | k[i] = float(_k[t]); 153 | v[i] = float(_v[t]); 154 | gy[i] = float(_gy[t]); 155 | __syncthreads(); 156 | 157 | const float rr = r[i]; 158 | const float gyy = gy[i]; 159 | float gk = 0, gv = 0; 160 | 161 | #pragma unroll 162 | for (int j = 0; j < _N_; j++) 163 | { 164 | float& s = saaaa[j]; 165 | float& s2 = sbbbb[j]; 166 | float x = rr * gy[j]; 167 | float x2 = gyy * r[j]; 168 | 169 | gk += (u * x + s) * v[j]; 170 | gv += (u_[j] * x2 + s2) * k[j]; 171 | s = x + s * w; 172 | s2 = x2 + s2 * w_[j]; 173 | } 174 
| _gk[t] = F(gk); 175 | _gv[t] = F(gv); 176 | } 177 | } 178 | 179 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 180 | { 181 | assert(H*_N_ == C); 182 | assert(_N_%4 == 0); 183 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 184 | } 185 | 186 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 187 | { 188 | assert(H*_N_ == C); 189 | assert(_N_%4 == 0); 190 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 191 | } 192 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/cuda/rwkv7_state_clampw.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef _FP32_ 4 | using bf = float; 5 | #define to_float(u) (u) 6 | #define to_bf(u) (u) 7 | #else 8 | #include 9 | using bf = __nv_bfloat16; 10 | #define to_float(u) (__bfloat162float(u)) 11 | #define to_bf(u) (__float2bfloat16_rn(u)) 12 | #endif 13 | 14 | using i64 = long long int; 15 | typedef bf * __restrict__ F_; 16 | constexpr float W_SCALE = -0.6065306597f; // -exp(-0.5) 17 | 18 | //###################################################################################################### 19 | 20 | template __launch_bounds__(N,2) 21 | __global__ void forward_kernel(int T,int H,float *__restrict__ s0_,F_ r_,F_ w_,F_ k_,F_ v_,F_ a_,F_ b_,bf* __restrict__ y_,float* s__,float* __restrict__ sa_) 22 | { 23 | const int bb=blockIdx.y, hh=blockIdx.x, i=threadIdx.x; 24 | float* __restrict__ s_ = s__ + i64(bb*H+hh) * i64((T/_CHUNK_LEN_)*N*N); 25 | s0_ += i64(bb*H+hh) * i64(N*N) + i64(i*N); 26 | float state[N]; 27 | #pragma unroll 28 | for (int j=0; j<<>>(T,H,s0,r,w,k,v,a,b,y,s,sa); 80 | } 81 | 82 | //###################################################################################################### 83 | 84 | template 85 | __global__ 
void backward_kernel(int T, int H, F_ r_, F_ w_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s__, float * __restrict__ sa_, float* ds0_, bf* dr_, bf* dw_, bf* dk_, bf* dv_, bf* da_, bf* db_) 86 | { 87 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 88 | float* __restrict__ s_ = s__ + i64(bb*H+hh) * i64((T/_CHUNK_LEN_)*N*N); 89 | ds0_ += i64(bb*H+hh) * i64(N*N) + i64(i*N); 90 | 91 | float stateT[N] = {0}, dstate[N] = {0}, dstateT[N] = {0}; 92 | __shared__ float r[N], w[N], k[N], v[N], a[N], b[N], dy[N], sa[N], dSb_shared[N]; 93 | float ri, wi, ki, ai, bi, dyi; 94 | 95 | for (int t = T-1; t >= 0; t--) 96 | { 97 | int idx = bb*T*H*N + t*H*N + hh * N + i; 98 | 99 | __syncthreads(); 100 | r[i] = ri = to_float(r_[idx]); 101 | float w_sig = 1.0f / (1.0f + __expf(-to_float(w_[idx]))); 102 | w[i] = wi = __expf(W_SCALE * w_sig); 103 | k[i] = ki = to_float(k_[idx]); 104 | v[i] = to_float(v_[idx]); 105 | a[i] = ai = to_float(a_[idx]); 106 | b[i] = bi = to_float(b_[idx]); 107 | dy[i] = dyi = to_float(dy_[idx]); 108 | sa[i] = sa_[idx]; 109 | __syncthreads(); 110 | 111 | if ((t+1)%_CHUNK_LEN_ == 0) { 112 | int base = (t/_CHUNK_LEN_)*N*N + i*N; 113 | const float4* s4 = (const float4*)(s_ + base); 114 | #pragma unroll 115 | for (int j4 = 0; j4 < N/4; j4++) { 116 | float4 q = s4[j4]; 117 | const int j = j4<<2; 118 | stateT[j+0] = q.x; 119 | stateT[j+1] = q.y; 120 | stateT[j+2] = q.z; 121 | stateT[j+3] = q.w; 122 | } 123 | } 124 | 125 | float dr = 0; 126 | #pragma unroll 127 | for (int j = 0; j < N; j++) { 128 | dr += stateT[j] * dy[j]; 129 | } 130 | dr_[idx] = to_bf(dr); 131 | 132 | float iwi = 1.0f / wi; 133 | #pragma unroll 134 | for (int j = 0; j < N; j++) { 135 | stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi; 136 | dstate[j] += dyi * r[j]; 137 | dstateT[j] += ri * dy[j]; 138 | } 139 | 140 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0; 141 | #pragma unroll 142 | for (int j = 0; j < N; j++) { 143 | dw += dstateT[j] * stateT[j]; 144 | dk += 
dstateT[j] * v[j]; 145 | dv += dstate[j] * k[j]; 146 | dSb += dstate[j] * b[j]; 147 | db += dstateT[j] * sa[j]; 148 | } 149 | dw_[idx] = to_bf(W_SCALE * dw * wi * w_sig * (1.0f - w_sig)); 150 | 151 | dk_[idx] = to_bf(dk); 152 | dv_[idx] = to_bf(dv); 153 | db_[idx] = to_bf(db); 154 | 155 | __syncthreads(); 156 | dSb_shared[i] = dSb; 157 | __syncthreads(); 158 | 159 | float da = 0; 160 | #pragma unroll 161 | for (int j = 0; j < N; j++) { 162 | da += stateT[j]*dSb_shared[j]; 163 | } 164 | da_[idx] = to_bf(da); 165 | 166 | #pragma unroll 167 | for (int j = 0; j < N; j++) { 168 | dstate[j] = dstate[j] * w[j] + dSb * a[j]; 169 | dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j]; 170 | } 171 | } 172 | #pragma unroll 173 | for (int j = 0; j < N; j++) { 174 | ds0_[j] = dstate[j]; 175 | } 176 | } 177 | 178 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, float*ds0, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db) 179 | { 180 | assert(T%_CHUNK_LEN_ == 0); 181 | backward_kernel<_N_><<>>(T,H,r,w,k,v,a,b,dy,s,sa,ds0,dr,dw,dk,dv,da,db); 182 | } 183 | -------------------------------------------------------------------------------- /wkv6/cuda/wkv5_cuda_v1b2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 18 | float state[_N_] = {0}; 19 | 20 | __syncthreads(); 21 | w[i] = _w[i]; 22 | u[i] = float(_u[i]); 23 | __syncthreads(); 24 | 
25 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 26 | { 27 | __syncthreads(); 28 | r[i] = float(_r[t]); 29 | k[i] = float(_k[t]); 30 | __syncthreads(); 31 | 32 | const float v = float(_v[t]); 33 | float y = 0; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < _N_; j+=4) 37 | { 38 | const float4& r_ = (float4&)(r[j]); 39 | const float4& k_ = (float4&)(k[j]); 40 | const float4& w_ = (float4&)(w[j]); 41 | const float4& u_ = (float4&)(u[j]); 42 | float4& s = (float4&)(state[j]); 43 | float4 x; 44 | 45 | x.x = k_.x * v; 46 | x.y = k_.y * v; 47 | x.z = k_.z * v; 48 | x.w = k_.w * v; 49 | 50 | y += r_.x * (u_.x * x.x + s.x); 51 | y += r_.y * (u_.y * x.y + s.y); 52 | y += r_.z * (u_.z * x.z + s.z); 53 | y += r_.w * (u_.w * x.w + s.w); 54 | 55 | s.x = s.x * w_.x + x.x; 56 | s.y = s.y * w_.y + x.y; 57 | s.z = s.z * w_.z + x.z; 58 | s.w = s.w * w_.w + x.w; 59 | } 60 | _y[t] = F(y); 61 | } 62 | } 63 | 64 | template 65 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 66 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 67 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 68 | { 69 | const int b = blockIdx.x / H; 70 | const int h = blockIdx.x % H; 71 | const int i = threadIdx.x; 72 | _w += h*_N_; 73 | _u += h*_N_; 74 | __w += h*_N_; 75 | 76 | __shared__ float w_[_N_], u_[_N_]; 77 | __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; 78 | __syncthreads(); 79 | w_[i] = _w[i]; 80 | u_[i] = float(_u[i]); 81 | __syncthreads(); 82 | 83 | const float w = w_[i]; 84 | const float ww = __w[i]; 85 | const float u = u_[i]; 86 | 87 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 88 | 89 | float gw = 0, gu = 0; 90 | const int t000 = 
b*T*C + h*_N_ + i; 91 | const int t111 = (b+1)*T*C + h*_N_ + i; 92 | const int t222 = t111 - 2*C; 93 | 94 | for (int t = t000; t < t111; t += C) 95 | { 96 | __syncthreads(); 97 | v[i] = float(_v[t]); 98 | gy[i] = float(_gy[t]); 99 | __syncthreads(); 100 | 101 | const float k = float(_k[t]); 102 | float gr = 0, gu_ = 0; 103 | 104 | #pragma unroll 105 | for (int j = 0; j < _N_; j++) 106 | { 107 | float& s = state[j]; 108 | float x = k * v[j]; 109 | 110 | gr += (u * x + s) * gy[j]; 111 | gu_ += x * gy[j]; 112 | s = s * w + x; 113 | } 114 | _gr[t] = F(gr); 115 | gu += float(_r[t]) * gu_; 116 | } 117 | _gu[b*C + h*_N_ + i] = F(gu); 118 | 119 | for (int t = t000; t < t222; t += C) 120 | { 121 | __syncthreads(); 122 | v[i] = float(_v[t]); 123 | gy[i] = float(_gy[t + 2*C]); 124 | __syncthreads(); 125 | 126 | const float k = float(_k[t]); 127 | float gw_ = 0; 128 | 129 | #pragma unroll 130 | for (int j = 0; j < _N_; j++) 131 | { 132 | float& s = saaaa[j]; 133 | float& s2 = sbbbb[j]; 134 | float x = k * v[j]; 135 | 136 | float tmp = w * (x + s); 137 | s = tmp; 138 | s2 = tmp + w * s2; 139 | gw_ += s2 * gy[j]; 140 | } 141 | gw += float(_r[t + 2*C]) * gw_; 142 | } 143 | _gw[b*C + h*_N_ + i] = F(ww * gw); 144 | 145 | for (int t = t111 - C; t >= t000; t -= C) 146 | { 147 | __syncthreads(); 148 | v[i] = float(_v[t]); 149 | gy[i] = float(_gy[t]); 150 | __syncthreads(); 151 | 152 | const float rr = float(_r[t]); 153 | float gk = 0; 154 | 155 | #pragma unroll 156 | for (int j = 0; j < _N_; j++) 157 | { 158 | float& s = scccc[j]; 159 | float x = rr * gy[j]; 160 | 161 | gk += (u * x + s) * v[j]; 162 | s = x + s * w; 163 | } 164 | _gk[t] = F(gk); 165 | } 166 | 167 | for (int t = t111 - C; t >= t000; t -= C) 168 | { 169 | __syncthreads(); 170 | r[i] = float(_r[t]); 171 | k[i] = float(_k[t]); 172 | __syncthreads(); 173 | 174 | const float gyy = float(_gy[t]); 175 | float gv = 0; 176 | 177 | #pragma unroll 178 | for (int j = 0; j < _N_; j++) 179 | { 180 | float& s = sdddd[j]; 181 | float 
x = gyy * r[j]; 182 | 183 | gv += (u_[j] * x + s) * k[j]; 184 | s = x + s * w_[j]; 185 | } 186 | _gv[t] = F(gv); 187 | } 188 | } 189 | 190 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 191 | { 192 | assert(H*_N_ == C); 193 | assert(_N_%4 == 0); 194 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 195 | } 196 | 197 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 198 | { 199 | assert(H*_N_ == C); 200 | assert(_N_%4 == 0); 201 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 202 | } 203 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_cuda_v1b2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 18 | float state[_N_] = {0}; 19 | 20 | __syncthreads(); 21 | w[i] = _w[i]; 22 | u[i] = float(_u[i]); 23 | __syncthreads(); 24 | 25 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 26 | { 27 | __syncthreads(); 28 | r[i] = float(_r[t]); 29 | k[i] = float(_k[t]); 30 | __syncthreads(); 31 | 32 | const float v = float(_v[t]); 33 | float y = 0; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < _N_; j+=4) 37 | { 38 | const float4& r_ = (float4&)(r[j]); 39 | const float4& k_ = (float4&)(k[j]); 40 | const 
float4& w_ = (float4&)(w[j]); 41 | const float4& u_ = (float4&)(u[j]); 42 | float4& s = (float4&)(state[j]); 43 | float4 x; 44 | 45 | x.x = k_.x * v; 46 | x.y = k_.y * v; 47 | x.z = k_.z * v; 48 | x.w = k_.w * v; 49 | 50 | y += r_.x * (u_.x * x.x + s.x); 51 | y += r_.y * (u_.y * x.y + s.y); 52 | y += r_.z * (u_.z * x.z + s.z); 53 | y += r_.w * (u_.w * x.w + s.w); 54 | 55 | s.x = s.x * w_.x + x.x; 56 | s.y = s.y * w_.y + x.y; 57 | s.z = s.z * w_.z + x.z; 58 | s.w = s.w * w_.w + x.w; 59 | } 60 | _y[t] = F(y); 61 | } 62 | } 63 | 64 | template 65 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 66 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 67 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 68 | { 69 | const int b = blockIdx.x / H; 70 | const int h = blockIdx.x % H; 71 | const int i = threadIdx.x; 72 | _w += h*_N_; 73 | _u += h*_N_; 74 | __w += h*_N_; 75 | 76 | __shared__ float w_[_N_], u_[_N_]; 77 | __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; 78 | __syncthreads(); 79 | w_[i] = _w[i]; 80 | u_[i] = float(_u[i]); 81 | __syncthreads(); 82 | 83 | const float w = w_[i]; 84 | const float ww = __w[i]; 85 | const float u = u_[i]; 86 | 87 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 88 | 89 | float gw = 0, gu = 0; 90 | const int t000 = b*T*C + h*_N_ + i; 91 | const int t111 = (b+1)*T*C + h*_N_ + i; 92 | const int t222 = t111 - 2*C; 93 | 94 | for (int t = t000; t < t111; t += C) 95 | { 96 | __syncthreads(); 97 | v[i] = float(_v[t]); 98 | gy[i] = float(_gy[t]); 99 | __syncthreads(); 100 | 101 | const float k = float(_k[t]); 102 | float gr = 0, gu_ = 0; 103 | 104 | #pragma unroll 105 | for (int j = 0; j < _N_; j++) 106 | { 107 | 
float& s = state[j]; 108 | float x = k * v[j]; 109 | 110 | gr += (u * x + s) * gy[j]; 111 | gu_ += x * gy[j]; 112 | s = s * w + x; 113 | } 114 | _gr[t] = F(gr); 115 | gu += float(_r[t]) * gu_; 116 | } 117 | _gu[b*C + h*_N_ + i] = F(gu); 118 | 119 | for (int t = t000; t < t222; t += C) 120 | { 121 | __syncthreads(); 122 | v[i] = float(_v[t]); 123 | gy[i] = float(_gy[t + 2*C]); 124 | __syncthreads(); 125 | 126 | const float k = float(_k[t]); 127 | float gw_ = 0; 128 | 129 | #pragma unroll 130 | for (int j = 0; j < _N_; j++) 131 | { 132 | float& s = saaaa[j]; 133 | float& s2 = sbbbb[j]; 134 | float x = k * v[j]; 135 | 136 | float tmp = w * (x + s); 137 | s = tmp; 138 | s2 = tmp + w * s2; 139 | gw_ += s2 * gy[j]; 140 | } 141 | gw += float(_r[t + 2*C]) * gw_; 142 | } 143 | _gw[b*C + h*_N_ + i] = F(ww * gw); 144 | 145 | for (int t = t111 - C; t >= t000; t -= C) 146 | { 147 | __syncthreads(); 148 | v[i] = float(_v[t]); 149 | gy[i] = float(_gy[t]); 150 | __syncthreads(); 151 | 152 | const float rr = float(_r[t]); 153 | float gk = 0; 154 | 155 | #pragma unroll 156 | for (int j = 0; j < _N_; j++) 157 | { 158 | float& s = scccc[j]; 159 | float x = rr * gy[j]; 160 | 161 | gk += (u * x + s) * v[j]; 162 | s = x + s * w; 163 | } 164 | _gk[t] = F(gk); 165 | } 166 | 167 | for (int t = t111 - C; t >= t000; t -= C) 168 | { 169 | __syncthreads(); 170 | r[i] = float(_r[t]); 171 | k[i] = float(_k[t]); 172 | __syncthreads(); 173 | 174 | const float gyy = float(_gy[t]); 175 | float gv = 0; 176 | 177 | #pragma unroll 178 | for (int j = 0; j < _N_; j++) 179 | { 180 | float& s = sdddd[j]; 181 | float x = gyy * r[j]; 182 | 183 | gv += (u_[j] * x + s) * k[j]; 184 | s = x + s * w_[j]; 185 | } 186 | _gv[t] = F(gv); 187 | } 188 | } 189 | 190 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 191 | { 192 | assert(H*_N_ == C); 193 | assert(_N_%4 == 0); 194 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 195 | } 196 | 197 | void 
cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 198 | { 199 | assert(H*_N_ == C); 200 | assert(_N_%4 == 0); 201 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 202 | } 203 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/cuda/rwkv7_statepassing_clampw.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef _FP32_ 4 | using bf = float; 5 | #define to_float(u) (u) 6 | #define to_bf(u) (u) 7 | #else 8 | #include 9 | using bf = __nv_bfloat16; 10 | #define to_float(u) (__bfloat162float(u)) 11 | #define to_bf(u) (__float2bfloat16_rn(u)) 12 | #endif 13 | 14 | using i64 = long long int; 15 | typedef bf * __restrict__ F_; 16 | constexpr float W_SCALE = -0.6065306597f; // -exp(-0.5) 17 | 18 | //###################################################################################################### 19 | 20 | template __launch_bounds__(N,2) 21 | __global__ void forward_kernel(int T,int H,float *__restrict__ s0_,F_ r_,F_ w_,F_ k_,F_ v_,F_ a_,F_ b_,bf* __restrict__ y_,float* sT_,float* s__,float* __restrict__ sa_) 22 | { 23 | const int bb=blockIdx.y, hh=blockIdx.x, i=threadIdx.x; 24 | float* __restrict__ s_ = s__ + i64(bb*H+hh) * i64((T/_CHUNK_LEN_)*N*N); 25 | const i64 s_shift = i64(bb*H+hh) * i64(N*N) + i64(i*N); 26 | s0_ += s_shift; 27 | sT_ += s_shift; 28 | float state[N]; 29 | #pragma unroll 30 | for (int j=0; j<<>>(T,H,s0,r,w,k,v,a,b,y,sT,s,sa); 85 | } 86 | 87 | //###################################################################################################### 88 | 89 | template 90 | __global__ void backward_kernel(int T, int H, F_ r_, F_ w_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float* dsT_, float * __restrict__ s__, float * __restrict__ sa_, float* ds0_, bf* dr_, bf* dw_, bf* dk_, bf* dv_, bf* da_, bf* db_) 91 | { 92 | int 
bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 93 | float* __restrict__ s_ = s__ + i64(bb*H+hh) * i64((T/_CHUNK_LEN_)*N*N); 94 | const i64 s_shift = i64(bb*H+hh) * i64(N*N); 95 | ds0_ += s_shift + i64(i*N); 96 | dsT_ += s_shift; 97 | float stateT[N], dstate[N], dstateT[N]; 98 | #pragma unroll 99 | for (int j = 0; j < N; j++) { 100 | dstate[j] = dsT_[i*N + j]; 101 | dstateT[j] = dsT_[j*N + i]; 102 | } 103 | __shared__ float r[N], w[N], k[N], v[N], a[N], b[N], dy[N], sa[N], dSb_shared[N]; 104 | float ri, wi, ki, ai, bi, dyi; 105 | 106 | for (int t = T-1; t >= 0; t--) 107 | { 108 | int idx = bb*T*H*N + t*H*N + hh * N + i; 109 | 110 | __syncthreads(); 111 | r[i] = ri = to_float(r_[idx]); 112 | float w_sig = 1.0f / (1.0f + __expf(-to_float(w_[idx]))); 113 | w[i] = wi = __expf(W_SCALE * w_sig); 114 | k[i] = ki = to_float(k_[idx]); 115 | v[i] = to_float(v_[idx]); 116 | a[i] = ai = to_float(a_[idx]); 117 | b[i] = bi = to_float(b_[idx]); 118 | dy[i] = dyi = to_float(dy_[idx]); 119 | sa[i] = sa_[idx]; 120 | __syncthreads(); 121 | 122 | if ((t+1)%_CHUNK_LEN_ == 0) { 123 | int base = (t/_CHUNK_LEN_)*N*N + i*N; 124 | const float4* s4 = (const float4*)(s_ + base); 125 | #pragma unroll 126 | for (int j4 = 0; j4 < N/4; j4++) { 127 | float4 q = s4[j4]; 128 | const int j = j4<<2; 129 | stateT[j+0] = q.x; 130 | stateT[j+1] = q.y; 131 | stateT[j+2] = q.z; 132 | stateT[j+3] = q.w; 133 | } 134 | } 135 | 136 | float dr = 0; 137 | #pragma unroll 138 | for (int j = 0; j < N; j++) { 139 | dr += stateT[j] * dy[j]; 140 | } 141 | dr_[idx] = to_bf(dr); 142 | 143 | float iwi = 1.0f / wi; 144 | #pragma unroll 145 | for (int j = 0; j < N; j++) { 146 | stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi; 147 | dstate[j] += dyi * r[j]; 148 | dstateT[j] += ri * dy[j]; 149 | } 150 | 151 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0; 152 | #pragma unroll 153 | for (int j = 0; j < N; j++) { 154 | dw += dstateT[j] * stateT[j]; 155 | dk += dstateT[j] * v[j]; 156 | dv += dstate[j] * k[j]; 157 | 
dSb += dstate[j] * b[j]; 158 | db += dstateT[j] * sa[j]; 159 | } 160 | dw_[idx] = to_bf(W_SCALE * dw * wi * w_sig * (1.0f - w_sig)); 161 | 162 | dk_[idx] = to_bf(dk); 163 | dv_[idx] = to_bf(dv); 164 | db_[idx] = to_bf(db); 165 | 166 | __syncthreads(); 167 | dSb_shared[i] = dSb; 168 | __syncthreads(); 169 | 170 | float da = 0; 171 | #pragma unroll 172 | for (int j = 0; j < N; j++) { 173 | da += stateT[j]*dSb_shared[j]; 174 | } 175 | da_[idx] = to_bf(da); 176 | 177 | #pragma unroll 178 | for (int j = 0; j < N; j++) { 179 | dstate[j] = dstate[j] * w[j] + dSb * a[j]; 180 | dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j]; 181 | } 182 | } 183 | #pragma unroll 184 | for (int j = 0; j < N; j++) { 185 | ds0_[j] = dstate[j]; 186 | } 187 | } 188 | 189 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*dsT, float*s, float*sa, float*ds0, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db) 190 | { 191 | assert(T%_CHUNK_LEN_ == 0); 192 | backward_kernel<_N_><<>>(T,H,r,w,k,v,a,b,dy,dsT,s,sa,ds0,dr,dw,dk,dv,da,db); 193 | } 194 | -------------------------------------------------------------------------------- /wkv5/cuda/wkv5_cuda_v1e.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, 7 | F *__restrict__ const _y) 8 | { 9 | const int b = blockIdx.x / H; 10 | const int h = blockIdx.x - (blockIdx.x / H) * H; 11 | const int i = threadIdx.x; 12 | _w += h*N; 13 | _u += h*N; 14 | 15 | const float4 *__restrict__ const k = (float4 *)(_k); 16 | const float4 *__restrict__ const r = (float4 *)(_r); 17 | const float4 *__restrict__ const w = (float4 *)(_w); 18 | const float4 *__restrict__ const u = (float4 *)(_u); 19 | const F *__restrict__ const v = _v; 
20 | F *__restrict__ const y = _y; 21 | 22 | __shared__ float state[N * N]; 23 | __shared__ float4 rr[N >> 2], kk[N >> 2]; 24 | 25 | for (int j = 0; j < N; ++j) 26 | state[j * N + i] = 0.0f; 27 | 28 | for (int _tt = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _tt < _tend; _tt += C) 29 | { 30 | const int _t = _tt >> 2; 31 | const F vv = v[_tt]; 32 | F yy = 0.0; 33 | 34 | rr[i >> 2] = r[_t]; 35 | kk[i >> 2] = k[_t]; 36 | 37 | __syncthreads(); 38 | 39 | #pragma unroll 40 | for (int j = 0; j < N >> 2; j++) 41 | { 42 | const int j4n = (j << 2) * N; 43 | 44 | const float4 ww = w[j]; 45 | const float4 uu = u[j]; 46 | const float4 rrr = rr[j]; 47 | const float4 kkk = kk[j]; 48 | 49 | float4 x; 50 | x.x = kkk.x * vv; 51 | x.y = kkk.y * vv; 52 | x.z = kkk.z * vv; 53 | x.w = kkk.w * vv; 54 | 55 | F &s0 = state[j4n + i]; 56 | F &s1 = state[j4n + N + i]; 57 | F &s2 = state[j4n + 2*N + i]; 58 | F &s3 = state[j4n + 3*N + i]; 59 | 60 | yy += rrr.x * (uu.x * x.x + s0) + rrr.y * (uu.y * x.y + s1) + rrr.z * (uu.z * x.z + s2) + rrr.w * (uu.w * x.w + s3); 61 | s0 = s0 * ww.x + x.x; 62 | s1 = s1 * ww.y + x.y; 63 | s2 = s2 * ww.z + x.z; 64 | s3 = s3 * ww.w + x.w; 65 | } 66 | y[_tt] = yy; 67 | __syncthreads(); 68 | } 69 | } 70 | 71 | template 72 | __global__ void kernel_backward (const int B, const int T, const int C, const int H, 73 | const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const F *__restrict__ w, const F *__restrict__ wwww, const F *__restrict__ u, const F *__restrict__ const gy, 74 | F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ const gv, F *__restrict__ gw, F *__restrict__ gu) 75 | { 76 | const int b = blockIdx.x / H; 77 | const int h = blockIdx.x % H; 78 | const int i = threadIdx.x; 79 | w += h*N; 80 | u += h*N; 81 | gu += h*N; 82 | gw += h*N; 83 | wwww += h*N; 84 | 85 | __shared__ float state[N * N], vv[N], rr[N], kk[N], gyy[N]; 86 | 87 | #pragma unroll 88 | for (int j = 0; j < N; ++j){ 89 | state[j * N + 
i] = 0; 90 | } 91 | 92 | const float ww = w[i]; 93 | const float uu = u[i]; 94 | const float wwwww = wwww[i]; 95 | float saaaa[N] = {0.0f}, sbbbb[N] = {0.0f}; 96 | 97 | for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C) 98 | { 99 | const F kk = k[_t]; 100 | const F rr = r[_t]; 101 | F grr = 0; 102 | F guu = 0; 103 | 104 | vv[i] = v[_t]; 105 | gyy[i] = gy[_t]; 106 | 107 | __syncthreads(); 108 | 109 | #pragma unroll 110 | for (int j = 0; j < N; j++) 111 | { 112 | 113 | float x = vv[j] * kk; 114 | float s = state[j * N + i]; 115 | 116 | grr += gyy[j] * (uu * x + s); 117 | state[j * N + i] = s * ww + x; 118 | guu += rr * x * gyy[j]; 119 | 120 | } 121 | gr[_t] = grr; 122 | atomicAdd(gu + i, guu); 123 | 124 | __syncthreads(); 125 | if (_t < _tend - 2 * C){ 126 | const F rr_value = r[_t+2*C]; 127 | gyy[i] = gy[_t+2*C]; 128 | __syncthreads(); 129 | 130 | #pragma unroll 131 | for (int j = 0; j < N; j++){ 132 | float x = vv[j] * kk; 133 | saaaa[j] = ww * (saaaa[j] + sbbbb[j] + x); 134 | sbbbb[j] = ww * (sbbbb[j] + x); 135 | atomicAdd(gw+i, rr_value * wwwww * saaaa[j] * gyy[j]); 136 | } 137 | 138 | __syncthreads(); 139 | } 140 | } 141 | 142 | #pragma unroll 143 | for (int j = 0; j < N; ++j) 144 | state[j * N + i] = 0; 145 | 146 | for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C) 147 | { 148 | const F rr = r[_t]; 149 | F gkk = 0; 150 | 151 | vv[i] = v[_t]; 152 | gyy[i] = gy[_t]; 153 | 154 | __syncthreads(); 155 | 156 | #pragma unroll 157 | for (int j = 0; j < N; j++) 158 | { 159 | 160 | float x = gyy[j] * rr; 161 | float s = state[j * N + i]; 162 | 163 | gkk += vv[j] * (uu * x + s); 164 | state[j * N + i] = s * ww + x; 165 | } 166 | gk[_t] = gkk; 167 | __syncthreads(); 168 | } 169 | 170 | #pragma unroll 171 | for (int j = 0; j < N; ++j) 172 | state[j * N + i] = 0; 173 | 174 | for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C) 175 | { 176 | const F gy_value = gy[_t]; 177 | F gvv 
= 0; 178 | 179 | kk[i] = k[_t]; 180 | rr[i] = r[_t]; 181 | 182 | __syncthreads(); 183 | 184 | #pragma unroll 185 | for (int j = 0; j < N; j++) 186 | { 187 | 188 | float x = gy_value * rr[j]; 189 | float s = state[j * N + i]; 190 | 191 | gvv += kk[j] * (u[j] * x + s); 192 | state[j * N + i] = s * w[j] + x; 193 | } 194 | gv[_t] = gvv; 195 | __syncthreads(); 196 | } 197 | } 198 | 199 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y) 200 | { 201 | assert(H*N == C); 202 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 203 | } 204 | 205 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu) 206 | { 207 | assert(H*N == C); 208 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 209 | } 210 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_cuda_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 18 | float state[_N_] = {0}; 19 | 20 | __syncthreads(); 21 | u[i] = float(_u[i]); 22 | w[i] = float(_w[i]); 23 | __syncthreads(); 24 | 25 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 26 | { 27 | __syncthreads(); 28 | r[i] = float(_r[t]); 29 | k[i] = float(_k[t]); 30 | __syncthreads(); 31 | 32 | 
const float v = float(_v[t]); 33 | float y = 0; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < _N_; j+=4) 37 | { 38 | const float4& r_ = (float4&)(r[j]); 39 | const float4& k_ = (float4&)(k[j]); 40 | const float4& w_ = (float4&)(w[j]); 41 | const float4& u_ = (float4&)(u[j]); 42 | float4& s = (float4&)(state[j]); 43 | float4 x; 44 | 45 | x.x = k_.x * v; 46 | x.y = k_.y * v; 47 | x.z = k_.z * v; 48 | x.w = k_.w * v; 49 | 50 | y += r_.x * (u_.x * x.x + s.x); 51 | y += r_.y * (u_.y * x.y + s.y); 52 | y += r_.z * (u_.z * x.z + s.z); 53 | y += r_.w * (u_.w * x.w + s.w); 54 | 55 | s.x = s.x * w_.x + x.x; 56 | s.y = s.y * w_.y + x.y; 57 | s.z = s.z * w_.z + x.z; 58 | s.w = s.w * w_.w + x.w; 59 | } 60 | _y[t] = F(y); 61 | } 62 | } 63 | 64 | template 65 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 66 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 67 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 68 | { 69 | const int b = blockIdx.x / H; 70 | const int h = blockIdx.x % H; 71 | const int i = threadIdx.x; 72 | _w += h*_N_; 73 | _u += h*_N_; 74 | __w += h*_N_; 75 | const float w = _w[i]; 76 | const float u = float(_u[i]); 77 | const float ww = __w[i]; 78 | 79 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_], gy2[_N_], w_[_N_], u_[_N_]; 80 | float state[_N_*2] = {0}; 81 | 82 | float gw = 0, gu = 0; 83 | const int t000 = b*T*C + h*_N_ + i; 84 | const int t111 = (b+1)*T*C + h*_N_ + i; 85 | const int t222 = t111 - 2*C; 86 | 87 | for (int _t = t000; _t < t111; _t += C) 88 | { 89 | __syncthreads(); 90 | v[i] = float(_v[_t]); 91 | gy[i] = float(_gy[_t]); 92 | __syncthreads(); 93 | 94 | const float k = float(_k[_t]); 95 | const float r = float(_r[_t]); 96 | 97 | float gr = 0; 98 | 99 
| #pragma unroll 100 | for (int j = 0; j < _N_; j++) 101 | { 102 | float x = v[j] * k; 103 | float &s = state[j]; 104 | 105 | gr += gy[j] * (u * x + s); 106 | gu += r * x * gy[j]; 107 | s = s * w + x; 108 | } 109 | _gr[_t] = F(gr); 110 | } 111 | _gu[b*C + h*_N_ + i] = F(gu); 112 | 113 | #pragma unroll 114 | for (int j = 0; j < _N_; ++j) { 115 | state[j] = 0; 116 | } 117 | 118 | for (int _t = t000; _t < t222; _t += C) 119 | { 120 | __syncthreads(); 121 | v[i] = float(_v[_t]); 122 | gy2[i] = float(_gy[_t + 2*C]); 123 | __syncthreads(); 124 | const float r2 = float(_r[_t + 2*C]); 125 | const float k = float(_k[_t]); 126 | 127 | #pragma unroll 128 | for (int j = 0; j < _N_; j++) 129 | { 130 | float x = v[j] * k; 131 | // accum[j] = w[h,i] * (accum[j] + accum[j+N] + x) 132 | // accum[j+N] = w[h,i] * (accum[j+N] + x) 133 | // gw[h,i] += r[b,t+2,h,i] * _w[h,i] * accum[j] * gy[b,t+2,h,j] 134 | float &s1 = state[j]; 135 | float &s2 = state[j + _N_]; 136 | s1 = w * (s1 + s2 + x); 137 | s2 = w * (s2 + x); 138 | gw += r2 * ww * s1 * gy2[j]; 139 | } 140 | } 141 | 142 | _gw[b*C + h*_N_ + i] = F(gw); 143 | 144 | #pragma unroll 145 | for (int j = 0; j < _N_; ++j) { 146 | state[j] = 0; 147 | } 148 | 149 | __syncthreads(); 150 | w_[i] = float(_w[i]); 151 | u_[i] = float(_u[i]); 152 | __syncthreads(); 153 | 154 | for (int _t = t111 - C; _t >= t000; _t -= C) 155 | { 156 | __syncthreads(); 157 | v[i] = float(_v[_t]); 158 | gy[i] = float(_gy[_t]); 159 | k[i] = float(_k[_t]); 160 | r[i] = float(_r[_t]); 161 | __syncthreads(); 162 | 163 | float gk = 0, x; 164 | 165 | #pragma unroll 166 | for (int j = 0; j < _N_; j++) 167 | { 168 | x = gy[j] * r[i]; 169 | float &s = state[j]; 170 | gk += v[j] * (u * x + s); 171 | s = s * w + x; 172 | } 173 | _gk[_t] = F(gk); 174 | } 175 | 176 | #pragma unroll 177 | for (int j = 0; j < _N_; ++j) { 178 | state[j] = 0; 179 | } 180 | 181 | for (int _t = t111 - C; _t >= t000; _t -= C) 182 | { 183 | __syncthreads(); 184 | gy[i] = float(_gy[_t]); 185 | r[i] = 
float(_r[_t]); 186 | k[i] = float(_k[_t]); 187 | __syncthreads(); 188 | 189 | float gv = 0, x; 190 | 191 | #pragma unroll 192 | for (int j = 0; j < _N_; j++) 193 | { 194 | x = gy[i] * r[j]; 195 | float &s = state[j]; 196 | gv += k[j] * (u_[j] * x + s); 197 | s = s * w_[j] + x; 198 | } 199 | _gv[_t] = F(gv); 200 | } 201 | } 202 | 203 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 204 | { 205 | assert(H*_N_ == C); 206 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 207 | } 208 | 209 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 210 | { 211 | assert(H*_N_ == C); 212 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 213 | } 214 | -------------------------------------------------------------------------------- /wkv6/cuda/wkv6_cuda_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _u += h*_N_; 15 | 16 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 17 | float state[_N_] = {0}; 18 | 19 | __syncthreads(); 20 | u[i] = float(_u[i]); 21 | __syncthreads(); 22 | 23 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 24 | { 25 | __syncthreads(); 26 | w[i] = exp(_w[t]); 27 | r[i] = float(_r[t]); 28 | k[i] = float(_k[t]); 29 | __syncthreads(); 30 | 31 | const float v = float(_v[t]); 32 | float y = 0; 33 | 34 | #pragma unroll 35 | for (int j 
= 0; j < _N_; j+=4) 36 | { 37 | const float4& r_ = (float4&)(r[j]); 38 | const float4& k_ = (float4&)(k[j]); 39 | const float4& w_ = (float4&)(w[j]); 40 | const float4& u_ = (float4&)(u[j]); 41 | float4& s = (float4&)(state[j]); 42 | float4 x; 43 | 44 | x.x = k_.x * v; 45 | x.y = k_.y * v; 46 | x.z = k_.z * v; 47 | x.w = k_.w * v; 48 | 49 | y += r_.x * (u_.x * x.x + s.x); 50 | y += r_.y * (u_.y * x.y + s.y); 51 | y += r_.z * (u_.z * x.z + s.z); 52 | y += r_.w * (u_.w * x.w + s.w); 53 | 54 | s.x = s.x * w_.x + x.x; 55 | s.y = s.y * w_.y + x.y; 56 | s.z = s.z * w_.z + x.z; 57 | s.w = s.w * w_.w + x.w; 58 | } 59 | _y[t] = F(y); 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 66 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 67 | { 68 | const int b = blockIdx.x / H; 69 | const int h = blockIdx.x % H; 70 | const int i = threadIdx.x; 71 | _u += h*_N_; 72 | 73 | __shared__ float u_[_N_]; 74 | __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; 75 | __syncthreads(); 76 | u_[i] = float(_u[i]); 77 | __syncthreads(); 78 | 79 | const float u = u_[i]; 80 | 81 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 82 | float buf[_T_*_N_] = {0}; 83 | 84 | const int t_0 = b*T*C + h*_N_ + i; 85 | const int t_1 = t_0 + C; 86 | const int t_2 = t_0 + 2*C; 87 | const int t_T_2 = t_0 + (T-2)*C; 88 | const int t_T_1 = t_0 + (T-1)*C; 89 | const int t_T = t_0 + T*C; 90 | 91 | float gu = 0; 92 | for (int t = t_0; t < t_T; t += C) 93 | { 94 | __syncthreads(); 95 | v[i] = float(_v[t]); 96 | gy[i] = float(_gy[t]); 97 | __syncthreads(); 98 | 99 | const float k = float(_k[t]); 100 | const float w = 
exp(_w[t]); 101 | float gr = 0, gu_ = 0; 102 | 103 | #pragma unroll 104 | for (int j = 0; j < _N_; j++) 105 | { 106 | float& s = state[j]; 107 | float x = k * v[j]; 108 | 109 | gr += (u * x + s) * gy[j]; 110 | gu_ += x * gy[j]; 111 | s = s * w + x; 112 | } 113 | _gr[t] = F(gr); 114 | gu += float(_r[t]) * gu_; 115 | } 116 | _gu[b*C + h*_N_ + i] = F(gu); 117 | 118 | for (int t = t_0; t < t_T_2; t += C) 119 | { 120 | __syncthreads(); 121 | v[i] = float(_v[t]); 122 | __syncthreads(); 123 | 124 | const float k = float(_k[t]); 125 | const float w = exp(_w[t]); 126 | const int tt = (t-t_0)/C*_N_; 127 | 128 | #pragma unroll 129 | for (int j = 0; j < _N_; j++) 130 | { 131 | float& s = saaaa[j]; 132 | float x = k * v[j]; 133 | 134 | float tmp = w * s + x; 135 | s = tmp; 136 | buf[tt + j] = tmp; 137 | } 138 | } 139 | 140 | for (int t = t_T_1; t > t_1; t -= C) 141 | { 142 | __syncthreads(); 143 | gy[i] = float(_gy[t]); 144 | __syncthreads(); 145 | 146 | const float r = float(_r[t]); 147 | const float w = exp(_w[t]); 148 | float sum = 0.0f; 149 | const int tt = (t-t_2)/C*_N_; 150 | 151 | #pragma unroll 152 | for (int j = 0; j < _N_; j++) 153 | { 154 | float& s = sbbbb[j]; 155 | float x = r * gy[j]; 156 | 157 | float tmp = w * s + x; 158 | s = tmp; 159 | sum += buf[tt + j] * tmp; 160 | } 161 | _gw[t-C] = F(sum * _w[t-C] * exp(_w[t-C])); 162 | } 163 | 164 | for (int t = t_T_1; t >= t_0; t -= C) 165 | { 166 | __syncthreads(); 167 | v[i] = float(_v[t]); 168 | gy[i] = float(_gy[t]); 169 | __syncthreads(); 170 | 171 | const float rr = float(_r[t]); 172 | const float w = exp(_w[t]); 173 | float gk = 0; 174 | 175 | #pragma unroll 176 | for (int j = 0; j < _N_; j++) 177 | { 178 | float& s = scccc[j]; 179 | float x = rr * gy[j]; 180 | 181 | gk += (u * x + s) * v[j]; 182 | s = x + s * w; 183 | } 184 | _gk[t] = F(gk); 185 | } 186 | 187 | for (int t = t_T_1; t >= t_0; t -= C) 188 | { 189 | __syncthreads(); 190 | r[i] = float(_r[t]); 191 | k[i] = float(_k[t]); 192 | w_[i] = exp(_w[t]); 193 | 
__syncthreads(); 194 | 195 | const float gyy = float(_gy[t]); 196 | float gv = 0; 197 | 198 | #pragma unroll 199 | for (int j = 0; j < _N_; j++) 200 | { 201 | float& s = sdddd[j]; 202 | float x = gyy * r[j]; 203 | 204 | gv += (u_[j] * x + s) * k[j]; 205 | s = x + s * w_[j]; 206 | } 207 | _gv[t] = F(gv); 208 | } 209 | } 210 | 211 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 212 | { 213 | assert(H*_N_ == C); 214 | assert(_N_%4 == 0); 215 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 216 | } 217 | 218 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 219 | { 220 | assert(H*_N_ == C); 221 | assert(_N_%4 == 0); 222 | kernel_backward<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); 223 | } 224 | -------------------------------------------------------------------------------- /wkv5a/cuda/wkv5a_cuda_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | typedef float DTYPE; 6 | 7 | template 8 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 9 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w1, const F *__restrict__ _u1, const float *__restrict__ _w2, const F *__restrict__ _u2, 10 | F *__restrict__ const _y) 11 | { 12 | const int b = blockIdx.x / H; 13 | const int h = blockIdx.x % H; 14 | const int i = threadIdx.x; 15 | _w1 += h*_N_; 16 | _u1 += h*_N_; 17 | 18 | __shared__ float r[_N_], k[_N_], u__[_N_], w__[_N_]; 19 | 20 | __syncthreads(); 21 | w__[i] = _w1[i]; 22 | u__[i] = float(_u1[i]); 23 | __syncthreads(); 24 | 25 | float state[_N_] = {0}; 26 | float u[_N_], w[_N_]; 27 | 28 | #pragma unroll 29 | for (int j = 0; j < _N_; j++) { 30 | w[j] = w__[j] * _w2[h*_N_+i]; 31 | 
u[j] = u__[j] + _u2[h*_N_+i]; 32 | } 33 | 34 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 35 | { 36 | __syncthreads(); 37 | r[i] = float(_r[t]); 38 | k[i] = float(_k[t]); 39 | __syncthreads(); 40 | 41 | const float v = float(_v[t]); 42 | float y = 0; 43 | 44 | #pragma unroll 45 | for (int j = 0; j < _N_; j++) 46 | { 47 | float x = k[j] * v; 48 | 49 | float s = state[j]; 50 | state[j] = s * w[j] + x; 51 | 52 | y += r[j] * (u[j] * x + s); 53 | } 54 | _y[t] = F(y); 55 | } 56 | } 57 | 58 | template 59 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 60 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w1, const float *__restrict__ __w1, const F *__restrict__ _u1, const float *__restrict__ _w2, const float *__restrict__ __w2, const F *__restrict__ _u2, const F *__restrict__ const _gy, 61 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw1, F *__restrict__ const _gu1, F *__restrict__ const _gw2, F *__restrict__ const _gu2) 62 | { 63 | const int b = blockIdx.x / H; 64 | const int h = blockIdx.x % H; 65 | const int i = threadIdx.x; 66 | 67 | _w1 += h*_N_; 68 | _u1 += h*_N_; 69 | __w1 += h*_N_; 70 | _w2 += h*_N_; 71 | _u2 += h*_N_; 72 | __w2 += h*_N_; 73 | const float w1 = _w1[i]; 74 | const float u1 = float(_u1[i]); 75 | const float ww1 = __w1[i]; 76 | const float w2 = _w2[i]; 77 | const float u2 = float(_u2[i]); 78 | const float ww2 = __w2[i]; 79 | 80 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_]; 81 | 82 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 83 | float gu1 = 0, gu2 = 0, gw1 = 0, gw2 = 0; 84 | 85 | for (int t = b*T*C + h*_N_ + i, tend = (b+1)*T*C + h*_N_ + i; t < tend; t += C) 86 | { 87 | __syncthreads(); 88 | v[i] = float(_v[t]); 89 | k[i] = float(_k[t]); 90 | r[i] = float(_r[t]); 91 | gy[i] = float(_gy[t]); 92 | 
__syncthreads(); 93 | 94 | float ki = k[i]; 95 | float vi = v[i]; 96 | float ri = r[i]; 97 | float gyi = gy[i]; 98 | 99 | float gr = 0; 100 | 101 | #pragma unroll 102 | for (int j = 0; j < _N_; j++) 103 | { 104 | float x = v[j] * ki; 105 | float x2 = vi * k[j]; 106 | float s = state[j]; 107 | state[j] = s * (w1*_w2[j]) + x; 108 | 109 | gr += gy[j] * ((u1+_u2[j]) * x + s); 110 | gu1 += ri * x * gy[j]; 111 | gu2 += r[j] * x2 * gyi; 112 | } 113 | 114 | _gr[t] = F(gr); 115 | 116 | if (t < tend - 2*C) 117 | { 118 | __syncthreads(); 119 | gy[i] = float(_gy[t + 2*C]); 120 | r[i] = float(_r[t + 2*C]); 121 | __syncthreads(); 122 | 123 | float ri = r[i]; 124 | float gyi = gy[i]; 125 | 126 | #pragma unroll 127 | for (int j = 0; j < _N_; j++) 128 | { 129 | float x = v[j] * ki; 130 | float x2 = vi * k[j]; 131 | saaaa[j] = (w1*_w2[j]) * (saaaa[j] + sbbbb[j] + x); 132 | sbbbb[j] = (w1*_w2[j]) * (sbbbb[j] + x); 133 | scccc[j] = (_w1[j]*w2) * (scccc[j] + sdddd[j] + x2); 134 | sdddd[j] = (_w1[j]*w2) * (sdddd[j] + x2); 135 | 136 | gw1 += ri * ww1 * saaaa[j] * gy[j]; 137 | gw2 += r[j] * ww2 * scccc[j] * gyi; 138 | } 139 | } 140 | } 141 | _gu1[b*C + h*_N_ + i] = F(gu1); 142 | _gu2[b*C + h*_N_ + i] = F(gu2); 143 | _gw1[b*C + h*_N_ + i] = F(gw1); 144 | _gw2[b*C + h*_N_ + i] = F(gw2); 145 | 146 | #pragma unroll 147 | for (int j = 0; j < _N_; ++j) 148 | state[j] = 0; 149 | 150 | for (int t = (b+1)*T*C + h*_N_ + i - C, tend = b*T*C + h*_N_ + i; t >= tend; t -= C) 151 | { 152 | __syncthreads(); 153 | v[i] = float(_v[t]); 154 | gy[i] = float(_gy[t]); 155 | __syncthreads(); 156 | 157 | const float rr = float(_r[t]); 158 | float gk = 0; 159 | 160 | #pragma unroll 161 | for (int j = 0; j < _N_; j++) 162 | { 163 | float x = gy[j] * rr; 164 | float s = state[j]; 165 | state[j] = s * (w1*_w2[j]) + x; 166 | 167 | gk += v[j] * ((u1+_u2[j]) * x + s); 168 | } 169 | _gk[t] = F(gk); 170 | } 171 | 172 | #pragma unroll 173 | for (int j = 0; j < _N_; ++j) 174 | state[j] = 0; 175 | 176 | for (int t = 
(b+1)*T*C + h*_N_ + i - C, tend = b*T*C + h*_N_ + i; t >= tend; t -= C) 177 | { 178 | __syncthreads(); 179 | k[i] = float(_k[t]); 180 | r[i] = float(_r[t]); 181 | __syncthreads(); 182 | 183 | const float gy = float(_gy[t]); 184 | float gv = 0; 185 | 186 | #pragma unroll 187 | for (int j = 0; j < _N_; j++) 188 | { 189 | float x = gy * r[j]; 190 | float s = state[j]; 191 | state[j] = s * float(_w1[j]*w2) + x; 192 | 193 | gv += k[j] * (float(_u1[j]+u2) * x + s); 194 | } 195 | _gv[t] = F(gv); 196 | } 197 | } 198 | 199 | void cuda_forward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, DTYPE *u1, float *w2, DTYPE *u2, DTYPE *y) 200 | { 201 | assert(H*_N_ == C); 202 | kernel_forward<<>>(B, T, C, H, r, k, v, w1, u1, w2, u2, y); 203 | } 204 | 205 | void cuda_backward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, float *ww1, DTYPE *u1, float *w2, float *ww2, DTYPE *u2, DTYPE *gy, DTYPE *gr, DTYPE *gk, DTYPE *gv, DTYPE *gw1, DTYPE *gu1, DTYPE *gw2, DTYPE *gu2) 206 | { 207 | assert(H*_N_ == C); 208 | kernel_backward<<>>(B, T, C, H, r, k, v, w1, ww1, u1, w2, ww2, u2, gy, gr, gk, gv, gw1, gu1, gw2, gu2); 209 | } 210 | -------------------------------------------------------------------------------- /depthwise_conv1d/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.utils.cpp_extension import load 5 | import numpy as np 6 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 7 | torch.backends.cudnn.benchmark = True 8 | # turn off TF32 for higher accuracy 9 | torch.backends.cudnn.allow_tf32 = False 10 | torch.backends.cuda.matmul.allow_tf32 = False 11 | 12 | ###################################################################################################### 13 | # From https://github.com/BlinkDL/RWKV-CUDA 14 | # On GTX1070 mobile: 15 | # pytorch = fwd 94ms bwd 529ms 16 | # CUDA kernel v0 
= fwd 45ms bwd 84ms (simple) 17 | # CUDA kernel v1 = fwd 17ms bwd 43ms (shared memory) 18 | # CUDA kernel v2 = fwd 13ms bwd 31ms (float4) 19 | # CUDA kernel v3 = fwd 3.4ms bwd 23ms (B-group) 20 | ###################################################################################################### 21 | 22 | CUDA_KERNEL_VERSION = 3 # CUDA kernel version = 0,1,2 23 | 24 | 25 | def set_seed(seed): 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | 30 | 31 | def get_err_ratio(x, y): 32 | err = (x-y).flatten().square().mean().sqrt().item() 33 | base = (x).flatten().square().mean().sqrt().item() 34 | return err / base 35 | 36 | ###################################################################################################### 37 | # The formula: 38 | # w.shape = (C, T) 39 | # k.shape = (B, C, T) 40 | # out.shape = (B, C, T) 41 | # out[b][c][t] = sum_u{ w[c][(T-1)-(t-u)] * k[b][c][u] } 42 | ###################################################################################################### 43 | 44 | 45 | def RUN_FORMULA_VERY_SLOW(w, k, B, C, T, eps): 46 | # this is the formula (very slow) 47 | out = torch.empty((B, C, T), device='cuda') 48 | for b in range(B): 49 | for c in range(C): 50 | for t in range(T): 51 | s = eps 52 | for u in range(0, t+1): 53 | s += w[c][(T-1)-(t-u)] * k[b][c][u] 54 | out[b][c][t] = s 55 | return out 56 | 57 | 58 | def RUN_PYTORCH(w, k, B, C, T, eps): 59 | # this shall equal the formula 60 | return F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w.unsqueeze(1), groups=C) + eps 61 | 62 | 63 | ###################################################################################################### 64 | # Load the CUDA kernel 65 | ###################################################################################################### 66 | 67 | T_MAX = 768 68 | B_GROUP_FORWARD = 8 69 | B_GROUP_BACKWARD = 2 70 | 71 | timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda_v" + 
str(CUDA_KERNEL_VERSION) + ".cu"], 72 | verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'], extra_cflags=['/wd4624']) 73 | 74 | 75 | # we call it the "TimeX" operator because it's used for time-mixing in my RWKV language model 76 | class TimeX(torch.autograd.Function): 77 | @staticmethod 78 | def forward(ctx, w, k, B, C, T, eps): 79 | ctx.B = B 80 | ctx.C = C 81 | ctx.T = T 82 | assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0, "require T % 4 == 0 and T <= T_MAX and B % B_GROUP_* == 0" 83 | w = w.contiguous() 84 | k = k.contiguous() 85 | ctx.save_for_backward(w, k) 86 | wk = torch.empty((B, C, T), device='cuda', 87 | memory_format=torch.contiguous_format) 88 | timex_cuda.forward(w, k, wk, eps, B, C, T) 89 | return wk 90 | 91 | @staticmethod 92 | def backward(ctx, gwk): 93 | assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0, "require T % 4 == 0 and T <= T_MAX and B % B_GROUP_* == 0" 94 | w, k = ctx.saved_tensors 95 | gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda', 96 | memory_format=torch.contiguous_format) 97 | gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda', 98 | memory_format=torch.contiguous_format) 99 | timex_cuda.backward(w, k, gwk.contiguous(), gw, 100 | gk, ctx.B, ctx.C, ctx.T) 101 | # actually pytorch will do gw.sum(dim=0) but we will do it anyway just to be safe 102 | return (gw.sum(dim=0), gk, None, None, None, None) 103 | 104 | 105 | def RUN_CUDA(w, k, B, C, T, eps): 106 | return TimeX.apply(w.cuda(), k.cuda(), B, C, T, eps) 107 | 108 | 109 | ###################################################################################################### 110 | # Check correctness & speed benchmark 111 | ###################################################################################################### 112 | 113 | def 
CHECK_PYTORCH(): 114 | B = 3 115 | C = 5 116 | T = 11 117 | eps = 0.1 118 | 119 | set_seed(42) 120 | w = torch.rand(C, T, requires_grad=True, device='cuda') 121 | k = torch.rand(B, C, T, requires_grad=True, device='cuda') 122 | 123 | r0 = RUN_FORMULA_VERY_SLOW(w, k, B, C, T, eps) 124 | r1 = RUN_PYTORCH(w, k, B, C, T, eps) 125 | 126 | print('--> pytorch correct =', torch.allclose(r0, r1), 127 | ', err ratio =', get_err_ratio(r0, r1)) 128 | 129 | 130 | def CHECK_CUDA(silent=False): 131 | B = 32 132 | C = 768 133 | T = 768 134 | eps = 0.1 135 | 136 | set_seed(42) 137 | w = torch.rand(C, T, requires_grad=True, device='cuda') 138 | k = torch.rand(B, C, T, requires_grad=True, device='cuda') 139 | 140 | # check forward 141 | 142 | with torch.autograd.profiler.profile(use_cuda=True) as prof: 143 | r1 = RUN_PYTORCH(w, k, B, C, T, eps) 144 | if not silent: 145 | print('pytorch forward\n', prof.key_averages(group_by_stack_n=5).table( 146 | sort_by='self_cuda_time_total', row_limit=5)) 147 | 148 | with torch.autograd.profiler.profile(use_cuda=True) as prof: 149 | r2 = RUN_CUDA(w, k, B, C, T, eps) 150 | if not silent: 151 | print('CUDA forward\n', prof.key_averages(group_by_stack_n=5).table( 152 | sort_by='self_cuda_time_total', row_limit=5)) 153 | 154 | print('--> fwd correct =', torch.allclose(r1, r2), 155 | ', err ratio =', get_err_ratio(r1, r2)) 156 | 157 | # check backward 158 | 159 | # a strange loss for better verification 160 | loss1 = ((r1 * r1) - torch.tanh(r1)).sum() 161 | with torch.autograd.profiler.profile(use_cuda=True) as prof: 162 | loss1.backward() 163 | if not silent: 164 | print('pytorch backward\n', prof.key_averages(group_by_stack_n=5).table( 165 | sort_by='self_cuda_time_total', row_limit=5)) 166 | gw1 = w.grad.data.clone() 167 | gk1 = k.grad.data.clone() 168 | 169 | w.grad.data.zero_() 170 | k.grad.data.zero_() 171 | 172 | loss2 = ((r2 * r2) - torch.tanh(r2)).sum() 173 | with torch.autograd.profiler.profile(use_cuda=True) as prof: 174 | loss2.backward() 
175 | if not silent: 176 | print('CUDA backward\n', prof.key_averages(group_by_stack_n=5).table( 177 | sort_by='self_cuda_time_total', row_limit=5)) 178 | gw2 = w.grad.data.clone() 179 | gk2 = k.grad.data.clone() 180 | 181 | print('--> bwd gradW correct =', torch.allclose(gw1, gw2), 182 | ', err ratio =', get_err_ratio(gw1, gw2)) 183 | print('--> bwd gradK correct =', torch.allclose(gk1, gk2), 184 | ', err ratio =', get_err_ratio(gk1, gk2)) 185 | 186 | 187 | if __name__ == "__main__": 188 | print('\n\nVerify pytorch...') 189 | CHECK_PYTORCH() 190 | print('\n\nCUDA warmup...') 191 | CHECK_CUDA(silent=True) # warmup 192 | CHECK_CUDA(silent=True) # warmup 193 | print('\n\nCUDA benchmark...') 194 | CHECK_CUDA(silent=False) # benchmark 195 | -------------------------------------------------------------------------------- /wkv5a/cuda/wkv5a_cuda_v1a.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | typedef float DTYPE; 6 | 7 | template 8 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 9 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w1, const F *__restrict__ _u1, const float *__restrict__ _w2, const F *__restrict__ _u2, 10 | F *__restrict__ const _y) 11 | { 12 | const int b = blockIdx.x / H; 13 | const int h = blockIdx.x % H; 14 | const int i = threadIdx.x; 15 | _w1 += h*_N_; 16 | _u1 += h*_N_; 17 | const float w2 = _w2[h*_N_+i]; 18 | const float u2 = float(_u2[h*_N_+i]); 19 | 20 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 21 | 22 | __syncthreads(); 23 | w[i] = _w1[i]; 24 | u[i] = float(_u1[i]); 25 | __syncthreads(); 26 | 27 | float state[_N_] = {0}; 28 | 29 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 30 | { 31 | __syncthreads(); 32 | r[i] = float(_r[t]); 33 | k[i] = float(_k[t]); 34 | __syncthreads(); 35 | 
36 | const float v = float(_v[t]); 37 | float y = 0; 38 | 39 | #pragma unroll 40 | for (int j = 0; j < _N_; j+=4) 41 | { 42 | const float4& r_ = (float4&)(r[j]); 43 | const float4& k_ = (float4&)(k[j]); 44 | const float4& w_ = (float4&)(w[j]); 45 | const float4& u_ = (float4&)(u[j]); 46 | float4& s = (float4&)(state[j]); 47 | float4 x; 48 | 49 | x.x = k_.x * v; 50 | x.y = k_.y * v; 51 | x.z = k_.z * v; 52 | x.w = k_.w * v; 53 | 54 | y += r_.x * ((u2+u_.x) * x.x + s.x); 55 | y += r_.y * ((u2+u_.y) * x.y + s.y); 56 | y += r_.z * ((u2+u_.z) * x.z + s.z); 57 | y += r_.w * ((u2+u_.w) * x.w + s.w); 58 | 59 | s.x = s.x * (w2*w_.x) + x.x; 60 | s.y = s.y * (w2*w_.y) + x.y; 61 | s.z = s.z * (w2*w_.z) + x.z; 62 | s.w = s.w * (w2*w_.w) + x.w; 63 | } 64 | _y[t] = F(y); 65 | } 66 | } 67 | 68 | template 69 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 70 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w1, const float *__restrict__ __w1, const F *__restrict__ _u1, const float *__restrict__ _w2, const float *__restrict__ __w2, const F *__restrict__ _u2, const F *__restrict__ const _gy, 71 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw1, F *__restrict__ const _gu1, F *__restrict__ const _gw2, F *__restrict__ const _gu2) 72 | { 73 | const int b = blockIdx.x / H; 74 | const int h = blockIdx.x % H; 75 | const int i = threadIdx.x; 76 | _w1 += h*_N_; 77 | _u1 += h*_N_; 78 | __w1 += h*_N_; 79 | _w2 += h*_N_; 80 | _u2 += h*_N_; 81 | __w2 += h*_N_; 82 | 83 | __shared__ float w1_[_N_], u1_[_N_], w2_[_N_], u2_[_N_]; 84 | __syncthreads(); 85 | w1_[i] = _w1[i]; 86 | u1_[i] = float(_u1[i]); 87 | w2_[i] = _w2[i]; 88 | u2_[i] = float(_u2[i]); 89 | __syncthreads(); 90 | 91 | const float w1 = w1_[i]; 92 | const float u1 = u1_[i]; 93 | const float ww1 = __w1[i]; 94 | const float w2 = w2_[i]; 95 | const float u2 = u2_[i]; 96 
| const float ww2 = __w2[i]; 97 | 98 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_]; 99 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 100 | 101 | float gu1 = 0, gu2 = 0, gw1 = 0, gw2 = 0; 102 | const int t000 = b*T*C + h*_N_ + i; 103 | const int t111 = (b+1)*T*C + h*_N_ + i; 104 | const int t222 = t111 - 2*C; 105 | 106 | for (int t = t000; t < t111; t += C) 107 | { 108 | __syncthreads(); 109 | r[i] = float(_r[t]); 110 | k[i] = float(_k[t]); 111 | v[i] = float(_v[t]); 112 | gy[i] = float(_gy[t]); 113 | __syncthreads(); 114 | 115 | const float ki = k[i]; 116 | const float vi = v[i]; 117 | float gr = 0, gu1_ = 0, gu2_ = 0; 118 | 119 | #pragma unroll 120 | for (int j = 0; j < _N_; j++) 121 | { 122 | float& s = state[j]; 123 | float x = ki * v[j]; 124 | float x2 = vi * k[j]; 125 | 126 | gr += ((u1+u2_[j]) * x + s) * gy[j]; 127 | gu1_ += x * gy[j]; 128 | gu2_ += x2 * r[j]; 129 | 130 | s = s * (w1*w2_[j]) + x; 131 | } 132 | _gr[t] = F(gr); 133 | gu1 += gu1_ * r[i]; 134 | gu2 += gu2_ * gy[i]; 135 | } 136 | _gu1[b*C + h*_N_ + i] = F(gu1); 137 | _gu2[b*C + h*_N_ + i] = F(gu2); 138 | 139 | for (int t = t000; t < t222; t += C) 140 | { 141 | __syncthreads(); 142 | r[i] = float(_r[t + 2*C]); 143 | k[i] = float(_k[t]); 144 | v[i] = float(_v[t]); 145 | gy[i] = float(_gy[t + 2*C]); 146 | __syncthreads(); 147 | 148 | const float ki = k[i]; 149 | const float vi = v[i]; 150 | const float ri = r[i]; 151 | const float gyi = gy[i]; 152 | 153 | #pragma unroll 154 | for (int j = 0; j < _N_; j++) 155 | { 156 | float& sa = saaaa[j]; 157 | float& sb = sbbbb[j]; 158 | float& sc = scccc[j]; 159 | float& sd = sdddd[j]; 160 | float x = ki * v[j]; 161 | float x2 = vi * k[j]; 162 | 163 | float tmp = (w1*w2_[j]) * (sa + x); 164 | sa = tmp; 165 | sb = tmp + (w1*w2_[j]) * sb; 166 | tmp = (w1_[j]*w2) * (sc + x2); 167 | sc = tmp; 168 | sd = tmp + (w1_[j]*w2) * sd; 169 | 170 | gw1 += ri * sb * gy[j]; 171 | gw2 += gyi * sd * r[j]; 172 | } 173 | } 174 
| _gw1[b*C + h*_N_ + i] = F(ww1 * gw1); 175 | _gw2[b*C + h*_N_ + i] = F(ww2 * gw2); 176 | 177 | #pragma unroll 178 | for (int j = 0; j < _N_; ++j) { 179 | saaaa[j] = 0; 180 | sbbbb[j] = 0; 181 | } 182 | 183 | for (int t = t111 - C; t >= t000; t -= C) 184 | { 185 | __syncthreads(); 186 | r[i] = float(_r[t]); 187 | k[i] = float(_k[t]); 188 | v[i] = float(_v[t]); 189 | gy[i] = float(_gy[t]); 190 | __syncthreads(); 191 | 192 | const float rr = r[i]; 193 | const float gyy = gy[i]; 194 | float gk = 0, gv = 0; 195 | 196 | #pragma unroll 197 | for (int j = 0; j < _N_; j++) 198 | { 199 | float& s = saaaa[j]; 200 | float& s2 = sbbbb[j]; 201 | float x = rr * gy[j]; 202 | float x2 = gyy * r[j]; 203 | 204 | gk += ((u1+u2_[j]) * x + s) * v[j]; 205 | gv += ((u2+u1_[j]) * x2 + s2) * k[j]; 206 | s = x + s * (w1*w2_[j]); 207 | s2 = x2 + s2 * (w2*w1_[j]); 208 | } 209 | _gk[t] = F(gk); 210 | _gv[t] = F(gv); 211 | } 212 | } 213 | 214 | void cuda_forward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, DTYPE *u1, float *w2, DTYPE *u2, DTYPE *y) 215 | { 216 | assert(H*_N_ == C); 217 | assert(_N_%4 == 0); 218 | kernel_forward<<>>(B, T, C, H, r, k, v, w1, u1, w2, u2, y); 219 | } 220 | 221 | void cuda_backward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, float *ww1, DTYPE *u1, float *w2, float *ww2, DTYPE *u2, DTYPE *gy, DTYPE *gr, DTYPE *gk, DTYPE *gv, DTYPE *gw1, DTYPE *gu1, DTYPE *gw2, DTYPE *gu2) 222 | { 223 | assert(H*_N_ == C); 224 | assert(_N_%4 == 0); 225 | kernel_backward<<>>(B, T, C, H, r, k, v, w1, ww1, u1, w2, ww2, u2, gy, gr, gk, gv, gw1, gu1, gw2, gu2); 226 | } 227 | -------------------------------------------------------------------------------- /wkv5a/cuda/wkv5a_cuda_v1a2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | typedef float DTYPE; 6 | 7 | template 8 | __global__ void kernel_forward(const int B, const 
int T, const int C, const int H, 9 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w1, const F *__restrict__ _u1, const float *__restrict__ _w2, const F *__restrict__ _u2, 10 | F *__restrict__ const _y) 11 | { 12 | const int b = blockIdx.x / H; 13 | const int h = blockIdx.x % H; 14 | const int i = threadIdx.x; 15 | _w1 += h*_N_; 16 | _u1 += h*_N_; 17 | const float w2 = _w2[h*_N_+i]; 18 | const float u2 = float(_u2[h*_N_+i]); 19 | 20 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 21 | 22 | __syncthreads(); 23 | w[i] = _w1[i]; 24 | u[i] = float(_u1[i]); 25 | __syncthreads(); 26 | 27 | float state[_N_] = {0}; 28 | 29 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 30 | { 31 | __syncthreads(); 32 | r[i] = float(_r[t]); 33 | k[i] = float(_k[t]); 34 | __syncthreads(); 35 | 36 | const float v = float(_v[t]); 37 | float y = 0; 38 | 39 | #pragma unroll 40 | for (int j = 0; j < _N_; j+=4) 41 | { 42 | const float4& r_ = (float4&)(r[j]); 43 | const float4& k_ = (float4&)(k[j]); 44 | const float4& w_ = (float4&)(w[j]); 45 | const float4& u_ = (float4&)(u[j]); 46 | float4& s = (float4&)(state[j]); 47 | float4 x; 48 | 49 | x.x = k_.x * v; 50 | x.y = k_.y * v; 51 | x.z = k_.z * v; 52 | x.w = k_.w * v; 53 | 54 | y += r_.x * ((u2+u_.x) * x.x + s.x); 55 | y += r_.y * ((u2+u_.y) * x.y + s.y); 56 | y += r_.z * ((u2+u_.z) * x.z + s.z); 57 | y += r_.w * ((u2+u_.w) * x.w + s.w); 58 | 59 | s.x = s.x * (w2*w_.x) + x.x; 60 | s.y = s.y * (w2*w_.y) + x.y; 61 | s.z = s.z * (w2*w_.z) + x.z; 62 | s.w = s.w * (w2*w_.w) + x.w; 63 | } 64 | _y[t] = F(y); 65 | } 66 | } 67 | 68 | template 69 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 70 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w1, const float *__restrict__ __w1, const F *__restrict__ _u1, const float *__restrict__ _w2, const 
float *__restrict__ __w2, const F *__restrict__ _u2, const F *__restrict__ const _gy, 71 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw1, F *__restrict__ const _gu1, F *__restrict__ const _gw2, F *__restrict__ const _gu2) 72 | { 73 | const int b = blockIdx.x / H; 74 | const int h = blockIdx.x % H; 75 | const int i = threadIdx.x; 76 | _w1 += h*_N_; 77 | _u1 += h*_N_; 78 | __w1 += h*_N_; 79 | _w2 += h*_N_; 80 | _u2 += h*_N_; 81 | __w2 += h*_N_; 82 | 83 | __shared__ float w1_[_N_], u1_[_N_], w2_[_N_], u2_[_N_]; 84 | __syncthreads(); 85 | w1_[i] = _w1[i]; 86 | u1_[i] = float(_u1[i]); 87 | w2_[i] = _w2[i]; 88 | u2_[i] = float(_u2[i]); 89 | __syncthreads(); 90 | 91 | const float w1 = w1_[i]; 92 | const float u1 = u1_[i]; 93 | const float ww1 = __w1[i]; 94 | const float w2 = w2_[i]; 95 | const float u2 = u2_[i]; 96 | const float ww2 = __w2[i]; 97 | 98 | float _w1_[_N_], _u1_[_N_], _w2_[_N_], _u2_[_N_]; 99 | #pragma unroll 100 | for (int j = 0; j < _N_; j++) { 101 | _w1_[j] = w2 * w1_[j]; 102 | _w2_[j] = w1 * w2_[j]; 103 | _u1_[j] = u2 + u1_[j]; 104 | _u2_[j] = u1 + u2_[j]; 105 | } 106 | 107 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_]; 108 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 109 | 110 | float gu1 = 0, gu2 = 0, gw1 = 0, gw2 = 0; 111 | const int t000 = b*T*C + h*_N_ + i; 112 | const int t111 = (b+1)*T*C + h*_N_ + i; 113 | const int t222 = t111 - 2*C; 114 | 115 | for (int t = t000; t < t111; t += C) 116 | { 117 | __syncthreads(); 118 | r[i] = float(_r[t]); 119 | k[i] = float(_k[t]); 120 | v[i] = float(_v[t]); 121 | gy[i] = float(_gy[t]); 122 | __syncthreads(); 123 | 124 | const float ki = k[i]; 125 | const float vi = v[i]; 126 | float gr = 0, gu1_ = 0, gu2_ = 0; 127 | 128 | #pragma unroll 129 | for (int j = 0; j < _N_; j++) 130 | { 131 | float& s = state[j]; 132 | float x = ki * v[j]; 133 | float x2 = vi * k[j]; 134 | 135 | gr += (_u2_[j] * x + s) 
* gy[j]; 136 | gu1_ += x * gy[j]; 137 | gu2_ += x2 * r[j]; 138 | 139 | s = s * _w2_[j] + x; 140 | } 141 | _gr[t] = F(gr); 142 | gu1 += gu1_ * r[i]; 143 | gu2 += gu2_ * gy[i]; 144 | } 145 | _gu1[b*C + h*_N_ + i] = F(gu1); 146 | _gu2[b*C + h*_N_ + i] = F(gu2); 147 | 148 | for (int t = t000; t < t222; t += C) 149 | { 150 | __syncthreads(); 151 | r[i] = float(_r[t + 2*C]); 152 | k[i] = float(_k[t]); 153 | v[i] = float(_v[t]); 154 | gy[i] = float(_gy[t + 2*C]); 155 | __syncthreads(); 156 | 157 | const float ki = k[i]; 158 | const float vi = v[i]; 159 | const float ri = r[i]; 160 | const float gyi = gy[i]; 161 | 162 | #pragma unroll 163 | for (int j = 0; j < _N_; j++) 164 | { 165 | float& sa = saaaa[j]; 166 | float& sb = sbbbb[j]; 167 | float& sc = scccc[j]; 168 | float& sd = sdddd[j]; 169 | float x = ki * v[j]; 170 | float x2 = vi * k[j]; 171 | 172 | float tmp = _w2_[j] * (sa + x); 173 | sa = tmp; 174 | sb = tmp + _w2_[j] * sb; 175 | tmp = _w1_[j] * (sc + x2); 176 | sc = tmp; 177 | sd = tmp + _w1_[j] * sd; 178 | 179 | gw1 += ri * sb * gy[j]; 180 | gw2 += gyi * sd * r[j]; 181 | } 182 | } 183 | _gw1[b*C + h*_N_ + i] = F(ww1 * gw1); 184 | _gw2[b*C + h*_N_ + i] = F(ww2 * gw2); 185 | 186 | #pragma unroll 187 | for (int j = 0; j < _N_; ++j) { 188 | saaaa[j] = 0; 189 | sbbbb[j] = 0; 190 | } 191 | 192 | for (int t = t111 - C; t >= t000; t -= C) 193 | { 194 | __syncthreads(); 195 | r[i] = float(_r[t]); 196 | k[i] = float(_k[t]); 197 | v[i] = float(_v[t]); 198 | gy[i] = float(_gy[t]); 199 | __syncthreads(); 200 | 201 | const float rr = r[i]; 202 | const float gyy = gy[i]; 203 | float gk = 0, gv = 0; 204 | 205 | #pragma unroll 206 | for (int j = 0; j < _N_; j++) 207 | { 208 | float& s = saaaa[j]; 209 | float& s2 = sbbbb[j]; 210 | float x = rr * gy[j]; 211 | float x2 = gyy * r[j]; 212 | 213 | gk += (_u2_[j] * x + s) * v[j]; 214 | gv += (_u1_[j] * x2 + s2) * k[j]; 215 | s = x + s * _w2_[j]; 216 | s2 = x2 + s2 * _w1_[j]; 217 | } 218 | _gk[t] = F(gk); 219 | _gv[t] = F(gv); 220 | } 221 
| } 222 | 223 | void cuda_forward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, DTYPE *u1, float *w2, DTYPE *u2, DTYPE *y) 224 | { 225 | assert(H*_N_ == C); 226 | assert(_N_%4 == 0); 227 | kernel_forward<<>>(B, T, C, H, r, k, v, w1, u1, w2, u2, y); 228 | } 229 | 230 | void cuda_backward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, float *ww1, DTYPE *u1, float *w2, float *ww2, DTYPE *u2, DTYPE *gy, DTYPE *gr, DTYPE *gk, DTYPE *gv, DTYPE *gw1, DTYPE *gu1, DTYPE *gw2, DTYPE *gu2) 231 | { 232 | assert(H*_N_ == C); 233 | assert(_N_%4 == 0); 234 | kernel_backward<<>>(B, T, C, H, r, k, v, w1, ww1, u1, w2, ww2, u2, gy, gr, gk, gv, gw1, gu1, gw2, gu2); 235 | } 236 | -------------------------------------------------------------------------------- /wkv6/cuda/wkv6_cuda_v1a.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _u += h*_N_; 15 | 16 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 17 | float state[_N_] = {0}; 18 | 19 | __syncthreads(); 20 | u[i] = float(_u[i]); 21 | __syncthreads(); 22 | 23 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 24 | { 25 | __syncthreads(); 26 | w[i] = exp(_w[t]); 27 | r[i] = float(_r[t]); 28 | k[i] = float(_k[t]); 29 | __syncthreads(); 30 | 31 | const float v = float(_v[t]); 32 | float y = 0; 33 | 34 | #pragma unroll 35 | for (int j = 0; j < _N_; j+=4) 36 | { 37 | const float4& r_ = (float4&)(r[j]); 38 | const float4& k_ = (float4&)(k[j]); 39 | const float4& 
w_ = (float4&)(w[j]); 40 | const float4& u_ = (float4&)(u[j]); 41 | float4& s = (float4&)(state[j]); 42 | float4 x; 43 | 44 | x.x = k_.x * v; 45 | x.y = k_.y * v; 46 | x.z = k_.z * v; 47 | x.w = k_.w * v; 48 | 49 | y += r_.x * (u_.x * x.x + s.x); 50 | y += r_.y * (u_.y * x.y + s.y); 51 | y += r_.z * (u_.z * x.z + s.z); 52 | y += r_.w * (u_.w * x.w + s.w); 53 | 54 | s.x = s.x * w_.x + x.x; 55 | s.y = s.y * w_.y + x.y; 56 | s.z = s.z * w_.z + x.z; 57 | s.w = s.w * w_.w + x.w; 58 | } 59 | _y[t] = F(y); 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_backward_111(const int B, const int T, const int C, const int H, 65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 66 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu) 67 | { 68 | const int b = blockIdx.x / H; 69 | const int h = blockIdx.x % H; 70 | const int i = threadIdx.x; 71 | _u += h*_N_; 72 | 73 | __shared__ float u_[_N_]; 74 | __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; 75 | __syncthreads(); 76 | u_[i] = float(_u[i]); 77 | __syncthreads(); 78 | 79 | const float u = u_[i]; 80 | 81 | float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 82 | 83 | const int t_0 = b*T*C + h*_N_ + i; 84 | const int t_T_1 = t_0 + (T-1)*C; 85 | const int t_T = t_0 + T*C; 86 | 87 | float gu = 0; 88 | for (int t = t_0; t < t_T; t += C) 89 | { 90 | __syncthreads(); 91 | v[i] = float(_v[t]); 92 | gy[i] = float(_gy[t]); 93 | __syncthreads(); 94 | 95 | const float k = float(_k[t]); 96 | const float w = exp(_w[t]); 97 | float gr = 0, gu_ = 0; 98 | 99 | #pragma unroll 100 | for (int j = 0; j < _N_; j++) 101 | { 102 | float& s = state[j]; 103 | float x = k * v[j]; 104 | 105 | gr += (u * x + s) * gy[j]; 106 | gu_ += x * gy[j]; 107 | s = s * w + x; 108 | } 109 | _gr[t] = F(gr); 110 | gu += float(_r[t]) * gu_; 111 | } 112 
| _gu[b*C + h*_N_ + i] = F(gu); 113 | 114 | for (int t = t_T_1; t >= t_0; t -= C) 115 | { 116 | __syncthreads(); 117 | v[i] = float(_v[t]); 118 | gy[i] = float(_gy[t]); 119 | __syncthreads(); 120 | 121 | const float rr = float(_r[t]); 122 | const float w = exp(_w[t]); 123 | float gk = 0; 124 | 125 | #pragma unroll 126 | for (int j = 0; j < _N_; j++) 127 | { 128 | float& s = scccc[j]; 129 | float x = rr * gy[j]; 130 | 131 | gk += (u * x + s) * v[j]; 132 | s = x + s * w; 133 | } 134 | _gk[t] = F(gk); 135 | } 136 | 137 | for (int t = t_T_1; t >= t_0; t -= C) 138 | { 139 | __syncthreads(); 140 | r[i] = float(_r[t]); 141 | k[i] = float(_k[t]); 142 | w_[i] = exp(_w[t]); 143 | __syncthreads(); 144 | 145 | const float gyy = float(_gy[t]); 146 | float gv = 0; 147 | 148 | #pragma unroll 149 | for (int j = 0; j < _N_; j++) 150 | { 151 | float& s = sdddd[j]; 152 | float x = gyy * r[j]; 153 | 154 | gv += (u_[j] * x + s) * k[j]; 155 | s = x + s * w_[j]; 156 | } 157 | _gv[t] = F(gv); 158 | } 159 | } 160 | 161 | template 162 | __global__ void kernel_backward_222(const int B, const int T, const int C, const int H, float *__restrict__ const _buf, 163 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 164 | F *__restrict__ const _gw) 165 | { 166 | const int b = blockIdx.x; 167 | for (int h = 0; h < H; h++) { 168 | const int i = threadIdx.x; 169 | float *__restrict__ const buf = _buf + b*(_N_*_T_*_N_) + i*(_T_*_N_); 170 | 171 | __shared__ float v[_N_], gy[_N_]; 172 | float saaaa[_N_] = {0}, sbbbb[_N_] = {0}; 173 | 174 | const int t_0 = b*T*C + h*_N_ + i; 175 | const int t_1 = t_0 + C; 176 | const int t_2 = t_0 + 2*C; 177 | const int t_T_2 = t_0 + (T-2)*C; 178 | const int t_T_1 = t_0 + (T-1)*C; 179 | 180 | for (int t = t_0; t < t_T_2; t += C) 181 | { 182 | __syncthreads(); 183 | v[i] = float(_v[t]); 184 | __syncthreads(); 185 | 186 | const float k = 
float(_k[t]); 187 | const float w = exp(_w[t]); 188 | const int tt = (t-t_0)/C*_N_; 189 | 190 | #pragma unroll 191 | for (int j = 0; j < _N_; j++) 192 | { 193 | float& s = saaaa[j]; 194 | float x = k * v[j]; 195 | 196 | float tmp = w * s + x; 197 | s = tmp; 198 | buf[tt + j] = tmp; 199 | // printf("b %d h %d i %d t %d j %d buf %f\n", b, h, i, tt/_N_, j, tmp); 200 | } 201 | } 202 | 203 | for (int t = t_T_1; t > t_1; t -= C) 204 | { 205 | __syncthreads(); 206 | gy[i] = float(_gy[t]); 207 | __syncthreads(); 208 | 209 | const float r = float(_r[t]); 210 | const float w = exp(_w[t]); 211 | float sum = 0.0f; 212 | const int tt = (t-t_2)/C*_N_; 213 | 214 | #pragma unroll 215 | for (int j = 0; j < _N_; j++) 216 | { 217 | float& s = sbbbb[j]; 218 | float x = r * gy[j]; 219 | 220 | float tmp = w * s + x; 221 | s = tmp; 222 | sum += buf[tt + j] * tmp; 223 | // printf("b %d h %d i %d t %d j %d buf %f tmp %f\n", b, h, i, tt/_N_, j, buf[tt + j], tmp); 224 | } 225 | _gw[t-C] = F(sum * _w[t-C] * exp(_w[t-C])); 226 | } 227 | } 228 | } 229 | 230 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 231 | { 232 | assert(H*_N_ == C); 233 | assert(_N_%4 == 0); 234 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 235 | } 236 | 237 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 238 | { 239 | assert(H*_N_ == C); 240 | assert(_N_%4 == 0); 241 | assert(T <= _T_); // kernel_backward_222 gives each (b, i) a _T_*_N_ slice of buf and writes up to T*_N_ floats into it 242 | kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); 243 | 244 | // kernel_backward_222's loop bounds only produce _gw for time steps 1..T-2 (nothing when T < 3). For this recurrence the 245 | // gradient at t=0 (state starts at zero) and t=T-1 (no later output) is zero, so clear gw instead of leaving those slots unwritten. 246 | cudaMemset(gw, 0, sizeof(bf16) * (size_t)B * T * C); 247 | 248 | // Scratch buffer of B*_N_*_T_*_N_ floats; the byte count is computed in size_t because an int product overflows for large B. 249 | const size_t buf_bytes = sizeof(float) * (size_t)B * _N_ * _T_ * _N_; 250 | void* buf = 0; 251 | const cudaError_t alloc_err = cudaMalloc(&buf, buf_bytes); 252 | assert(alloc_err == cudaSuccess); 253 | (void)alloc_err; 254 | kernel_backward_222<<>>(B, T, C, H, (float *)(buf), r, k, v, w, u, gy, gw); 255 | cudaFree(buf); 256 | } 257 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/rwkv7_cuda_benchmark.py: -------------------------------------------------------------------------------- 1 | import time, sys 2 | import
numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.cpp_extension import load 6 | 7 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 8 | torch.backends.cudnn.allow_tf32 = False 9 | torch.backends.cuda.matmul.allow_tf32 = False 10 | 11 | ''' 12 | cd /mnt/program/_RWKV_/_ref_/RWKV-CUDA/rwkv7_fast_fused; python rwkv7_cuda_benchmark.py fp32 0; python rwkv7_cuda_benchmark.py fp32 1 13 | cd /mnt/program/_RWKV_/_ref_/RWKV-CUDA/rwkv7_fast_fused; python rwkv7_cuda_benchmark.py bf16 0; python rwkv7_cuda_benchmark.py bf16 1 14 | ''' 15 | print('\n### RWKV7_fused_clamp_w vanilla fwd+bwd kernel ###\n') 16 | 17 | DTYPE = torch.float if sys.argv[1].strip()=='fp32' else torch.bfloat16 18 | BENCHMARK_SPEED = True if int(sys.argv[2].strip()) == 1 else 0 19 | 20 | ###################################################################################################### 21 | 22 | DEVICE = "cuda" 23 | 24 | B, T, CHUNK_LEN, C, N = 2, 64, 16, 32, 16 25 | if BENCHMARK_SPEED: 26 | B, T, CHUNK_LEN, C, N = 8, 4096, 16, 4096, 64 27 | 28 | H, HEAD_SIZE = C // N, N 29 | print(f"\n\nB={B} T={T} C={C} HEAD_SIZE={HEAD_SIZE} DTYPE={str(DTYPE).replace('torch.','')}\n\n") 30 | 31 | def set_seed(seed): 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | def val(x): 37 | return x.detach().float().cpu().numpy() 38 | 39 | def err_ratio(x, y): 40 | err = (x-y).flatten().square().mean().sqrt().item() 41 | base = (x).flatten().square().mean().sqrt().item() 42 | return err / base 43 | 44 | ###################################################################################################### 45 | 46 | def RWKV7_CLAMPW_REF(r, w, k, v, a, b): 47 | r = r.view(B, T, H, N) 48 | k = k.view(B, T, H, N) 49 | v = v.view(B, T, H, N) 50 | a = a.view(B, T, H, N) 51 | b = b.view(B, T, H, N) 52 | 53 | w = -F.softplus(-w) - 0.5 # soft-clamp, after exp becomes sigmoid in CUDA kernel 54 | w = torch.exp(-torch.exp(w.view(B, T, H, 
N))) 55 | 56 | out = torch.zeros((B, T, H, N), device=DEVICE) 57 | state = torch.zeros((B, H, N, N), device=DEVICE) 58 | 59 | for t in range(T): 60 | rr = r[:, t, :] 61 | kk = k[:, t, :] 62 | vv = v[:, t, :] 63 | aa = a[:, t, :] 64 | bb = b[:, t, :] 65 | sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb) 66 | state = state * w[: , t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv) 67 | out[:, t, :] = torch.einsum('bhj,bhij->bhi', rr, state) 68 | 69 | return out.view((B, T, C)) 70 | 71 | ###################################################################################################### 72 | 73 | if DTYPE == torch.bfloat16: 74 | flags = ['-res-usage', f'-D_N_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 75 | load(name="rwkv7_clampw", sources=[f'cuda/rwkv7_clampw.cu', 'cuda/rwkv7_clampw.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) 76 | class RWKV7_CLAMPW_CUDA_OP(torch.autograd.Function): 77 | @staticmethod 78 | def forward(ctx,r,w,k,v,a,b): 79 | B,T,H,C = r.shape 80 | assert T%CHUNK_LEN == 0 81 | assert all(i.dtype==torch.bfloat16 for i in [r,w,k,v,a,b]) 82 | assert all(i.is_contiguous() for i in [r,w,k,v,a,b]) 83 | y = torch.empty_like(v) 84 | s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) 85 | sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) 86 | torch.ops.rwkv7_clampw.forward(r,w,k,v,a,b,y,s,sa) 87 | ctx.save_for_backward(r,w,k,v,a,b,s,sa) 88 | return y 89 | @staticmethod 90 | def backward(ctx,dy): 91 | assert all(i.dtype==torch.bfloat16 for i in [dy]) 92 | assert all(i.is_contiguous() for i in [dy]) 93 | r,w,k,v,a,b,s,sa = ctx.saved_tensors 94 | dr,dw,dk,dv,da,db = [torch.empty_like(x) for x in [r,w,k,v,a,b]] 95 | torch.ops.rwkv7_clampw.backward(r,w,k,v,a,b,dy,s,sa,dr,dw,dk,dv,da,db) 96 | return dr,dw,dk,dv,da,db 97 | def RWKV7_CLAMPW_CUDA(r,w,k,v,a,b): 98 | B,T,HC = r.shape 99 | r,w,k,v,a,b = 
[i.view(B,T,HC//HEAD_SIZE,HEAD_SIZE) for i in [r,w,k,v,a,b]] 100 | return RWKV7_CLAMPW_CUDA_OP.apply(r,w,k,v,a,b).view(B,T,HC) 101 | 102 | elif DTYPE == torch.float: 103 | flags = ['-res-usage', f'-D_N_={HEAD_SIZE}', "-D_FP32_", f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 104 | load(name="rwkv7_clampw", sources=[f'cuda/rwkv7_clampw.cu', 'cuda/rwkv7_clampw.cpp'], is_python_module=False, verbose=True, extra_cflags=["-D_FP32_"], extra_cuda_cflags=flags) 105 | class RWKV7_CLAMPW_CUDA_OP(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx,r,w,k,v,a,b): 108 | B,T,H,C = r.shape 109 | assert T%CHUNK_LEN == 0 110 | assert all(i.dtype==torch.float32 for i in [r,w,k,v,a,b]) 111 | assert all(i.is_contiguous() for i in [r,w,k,v,a,b]) 112 | y = torch.empty_like(v) 113 | s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) 114 | sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) 115 | torch.ops.rwkv7_clampw.forward(r,w,k,v,a,b,y,s,sa) 116 | ctx.save_for_backward(r,w,k,v,a,b,s,sa) 117 | return y 118 | @staticmethod 119 | def backward(ctx,dy): 120 | assert all(i.dtype==torch.float32 for i in [dy]) 121 | assert all(i.is_contiguous() for i in [dy]) 122 | r,w,k,v,a,b,s,sa = ctx.saved_tensors 123 | dr,dw,dk,dv,da,db = [torch.empty_like(x) for x in [r,w,k,v,a,b]] 124 | torch.ops.rwkv7_clampw.backward(r,w,k,v,a,b,dy,s,sa,dr,dw,dk,dv,da,db) 125 | return dr,dw,dk,dv,da,db 126 | def RWKV7_CLAMPW_CUDA(r,w,k,v,a,b): 127 | B,T,HC = r.shape 128 | r,w,k,v,a,b = [i.view(B,T,HC//HEAD_SIZE,HEAD_SIZE) for i in [r,w,k,v,a,b]] 129 | return RWKV7_CLAMPW_CUDA_OP.apply(r,w,k,v,a,b).view(B,T,HC) 130 | 131 | ###################################################################################################### 132 | 133 | with torch.no_grad(): 134 | r = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 3 135 | w = torch.empty(B, T, C, device=DEVICE).uniform_(-6, 0) 136 | k = torch.empty(B, T, C, 
device=DEVICE).uniform_(-1, 1) * 3 137 | v = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 3 138 | a = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 2 139 | b = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 2 140 | a = F.normalize(a, dim=-1, p=2.0) 141 | b = F.normalize(b, dim=-1, p=2.0) 142 | 143 | params = (r,w,k,v,a,b) 144 | 145 | def LOSS(y): 146 | return ((y * y) - torch.tanh(y)).sum() 147 | 148 | def clear_grad(): 149 | for t in params: 150 | t.requires_grad_(True) 151 | if t.grad is not None: 152 | t.grad.zero_() 153 | 154 | ###################################################################################################### 155 | 156 | if not BENCHMARK_SPEED: 157 | clear_grad() 158 | y = RWKV7_CLAMPW_REF(*params) 159 | LOSS(y).backward() 160 | grad_ref = [t.grad.detach().clone() for t in params] 161 | 162 | clear_grad() 163 | if DTYPE == torch.float: 164 | y_cuda = RWKV7_CLAMPW_CUDA(*params) 165 | else: 166 | rr,ww,kk,vv,aa,bb = r.bfloat16(),w.bfloat16(),k.bfloat16(),v.bfloat16(),a.bfloat16(),b.bfloat16() 167 | y_cuda = RWKV7_CLAMPW_CUDA(rr,ww,kk,vv,aa,bb).float() 168 | LOSS(y_cuda).backward() 169 | grad_cuda = [t.grad.detach().clone() for t in params] 170 | 171 | print('!!! y err !!!', err_ratio(y, y_cuda)) 172 | for name, g_ref, g_cuda in zip('rwkvab', grad_ref, grad_cuda): 173 | print(f'!!! 
g_{name} err !!!', err_ratio(g_ref, g_cuda)) 174 | 175 | else: 176 | print('benchmark speed...') 177 | repeats = 10 178 | fwd_times = [] 179 | bwd_times = [] 180 | for _ in range(repeats): 181 | clear_grad() 182 | 183 | if DTYPE == torch.float: 184 | torch.cuda.synchronize(); t0 = time.perf_counter() 185 | y_cuda = RWKV7_CLAMPW_CUDA(*params) 186 | torch.cuda.synchronize(); fwd_times.append(time.perf_counter() - t0) 187 | 188 | torch.cuda.synchronize(); t0 = time.perf_counter() 189 | LOSS(y_cuda).backward() 190 | torch.cuda.synchronize(); bwd_times.append(time.perf_counter() - t0) 191 | else: 192 | rr,ww,kk,vv,aa,bb = r.bfloat16(),w.bfloat16(),k.bfloat16(),v.bfloat16(),a.bfloat16(),b.bfloat16() 193 | torch.cuda.synchronize(); t0 = time.perf_counter() 194 | y_cuda = RWKV7_CLAMPW_CUDA(rr,ww,kk,vv,aa,bb) 195 | torch.cuda.synchronize(); fwd_times.append(time.perf_counter() - t0) 196 | 197 | torch.cuda.synchronize(); t0 = time.perf_counter() 198 | LOSS(y_cuda).backward() 199 | torch.cuda.synchronize(); bwd_times.append(time.perf_counter() - t0) 200 | 201 | print('fwd time =', min(fwd_times)) 202 | print('bwd time =', min(bwd_times)) 203 | -------------------------------------------------------------------------------- /wkv5_bf16/cuda/wkv5_cuda_v3.cu: -------------------------------------------------------------------------------- 1 | // Forward Origin Author: Bleatan 2 | #include 3 | #include 4 | #include 5 | 6 | typedef at::BFloat16 bf16; 7 | 8 | constexpr int N = _N_; 9 | constexpr int PARALLEL_SCAN_BLOCK = 512; 10 | constexpr int TILE_T = 16; 11 | 12 | __global__ void kernel_forward_accumulate_state(const int nblocks_per_sample, 13 | float *__restrict__ state, 14 | const float *__restrict__ w) { 15 | const int b = blockIdx.x, B = gridDim.x; 16 | const int h = blockIdx.y / N, H = gridDim.y / N, HNN = H * N * N; 17 | const int i = blockIdx.y % N; 18 | const int j = threadIdx.x; 19 | 20 | float wp = powf(w[h * N + i], PARALLEL_SCAN_BLOCK); 21 | state += (((b * 
nblocks_per_sample * H + h) * N) + j) * N + i; 22 | float s = 0; 23 | 24 | for (int block = 0; block < nblocks_per_sample; ++block) { 25 | const float next_s = s * wp + *state; 26 | *state = s; 27 | s = next_s; 28 | state += HNN; 29 | } 30 | } 31 | 32 | template 33 | __global__ void 34 | kernel_forward(const int T, const F *__restrict__ const _r, 35 | const F *__restrict__ const _k, const F *__restrict__ const _v, 36 | const float *__restrict__ _w, const F *__restrict__ _u, 37 | F *__restrict__ const _y, float *__restrict__ state) { 38 | const int b = blockIdx.x, B = gridDim.x; 39 | if (!DoY && b == B - 1) 40 | return; 41 | 42 | const int h = blockIdx.y, H = gridDim.y; 43 | const int i = threadIdx.x; 44 | _w += h * N; 45 | _u += h * N; 46 | 47 | __shared__ float rr[TILE_T][N], kk[TILE_T][N]; 48 | 49 | state += (b * H + h) * N * N; 50 | 51 | if constexpr (!DoY) 52 | for (int j = 0; j < N; ++j) 53 | state[j * N + i] = 0; 54 | 55 | for (int _t = b * T * H * N + h * N + i, 56 | _tend = (b + 1) * T * H * N + h * N + i; 57 | _t < _tend; _t += TILE_T * H * N) { 58 | float yy[TILE_T], vv[TILE_T]; 59 | #pragma unroll(TILE_T) 60 | for (int tt = 0, t = _t; tt < TILE_T; (++tt), (t += H * N)) { 61 | vv[tt] = (float)_v[t]; 62 | if constexpr (DoY) { 63 | yy[tt] = 0.f; 64 | rr[tt][i] = (float)_r[t]; 65 | } 66 | kk[tt][i] = (float)_k[t]; 67 | } 68 | 69 | __syncthreads(); 70 | 71 | for (int j = 0; j < N; j += 1) { 72 | float s = state[j * N + i]; 73 | float w = _w[j]; 74 | float u = (float)_u[j]; 75 | 76 | #pragma unroll(TILE_T) 77 | for (int tt = 0, t = _t; tt < TILE_T; (++tt), (t += H * N)) { 78 | float x = kk[tt][j] * vv[tt]; 79 | if constexpr (DoY) { 80 | float yyy = rr[tt][j] * (u * x + s); 81 | yy[tt] += yyy; 82 | } 83 | s = s * w + x; 84 | } 85 | state[j * N + i] = s; 86 | } 87 | 88 | if constexpr (DoY) 89 | #pragma unroll(TILE_T) 90 | for (int tt = 0, t = _t; tt < TILE_T; (++tt), (t += H * N)) 91 | _y[t] = (F)yy[tt]; 92 | 93 | __syncthreads(); 94 | } 95 | } 96 | 97 | void 
forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, 98 | torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, 99 | torch::Tensor &u, torch::Tensor &y) { 100 | assert(H * N == C); 101 | assert(T % PARALLEL_SCAN_BLOCK == 0); 102 | const int blocksz = PARALLEL_SCAN_BLOCK; 103 | const int nblocks_per_sample = T / blocksz; 104 | const int nblocks = B * nblocks_per_sample; 105 | cudaEvent_t events[5]; 106 | for (int i = 0; i < 5; ++i) 107 | cudaEventCreate(&events[i]); 108 | // events[0..4] are recorded around the three launches below and must be destroyed before returning. 109 | cudaEventRecord(events[0]); 110 | torch::Tensor states = 111 | torch::empty({B, nblocks_per_sample, H, N, N}, w.options()); 112 | cudaEventRecord(events[1]); 113 | kernel_forward<<>>( 114 | blocksz, r.data_ptr(), k.data_ptr(), v.data_ptr(), 115 | w.data_ptr(), u.data_ptr(), y.data_ptr(), 116 | states.data_ptr()); 117 | cudaEventRecord(events[2]); 118 | kernel_forward_accumulate_state<<>>( 119 | nblocks_per_sample, states.data_ptr(), w.data_ptr()); 120 | cudaEventRecord(events[3]); 121 | kernel_forward<<>>( 122 | blocksz, r.data_ptr(), k.data_ptr(), v.data_ptr(), 123 | w.data_ptr(), u.data_ptr(), y.data_ptr(), 124 | states.data_ptr()); 125 | cudaEventRecord(events[4]); 126 | // Block until the last launch has finished, then release the events so repeated calls do not leak them. 127 | cudaEventSynchronize(events[4]); for (int i = 0; i < 5; ++i) cudaEventDestroy(events[i]); 128 | } 129 | 130 | template 131 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 132 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 133 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 134 | { 135 | const int b = blockIdx.x / H; 136 | const int h = blockIdx.x % H; 137 | const int i = threadIdx.x; 138 | _w += h*_N_; 139 | _u += h*_N_; 140 | __w += h*_N_; 141 | const float w = _w[i]; 142 | const float u = float(_u[i]); 143 | const float ww = __w[i]; 144 | 145 | __shared__ float
v[_N_], r[_N_], k[_N_], gy[_N_], gy2[_N_], w_[_N_], u_[_N_]; 146 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}; 147 | 148 | float gw = 0, gu = 0; 149 | const int t000 = b*T*C + h*_N_ + i; 150 | const int t111 = (b+1)*T*C + h*_N_ + i; 151 | const int t222 = t111 - 2*C; 152 | 153 | for (int _t = t000; _t < t111; _t += C) 154 | { 155 | __syncthreads(); 156 | v[i] = float(_v[_t]); 157 | gy[i] = float(_gy[_t]); 158 | __syncthreads(); 159 | 160 | const float k = float(_k[_t]); 161 | const float r = float(_r[_t]); 162 | 163 | float gr = 0; 164 | 165 | #pragma unroll 166 | for (int j = 0; j < _N_; j++) 167 | { 168 | float x = v[j] * k; 169 | float s = state[j]; 170 | state[j] = s * w + x; 171 | 172 | gr += gy[j] * (u * x + s); 173 | gu += r * x * gy[j]; 174 | } 175 | _gr[_t] = F(gr); 176 | } 177 | _gu[b*C + h*_N_ + i] = F(gu); 178 | 179 | for (int _t = t000; _t < t222; _t += C) 180 | { 181 | __syncthreads(); 182 | v[i] = float(_v[_t]); 183 | gy2[i] = float(_gy[_t + 2*C]); 184 | __syncthreads(); 185 | const float r2 = float(_r[_t + 2*C]); 186 | const float k = float(_k[_t]); 187 | 188 | #pragma unroll 189 | for (int j = 0; j < _N_; j++) 190 | { 191 | float x = v[j] * k; 192 | saaaa[j] = w * (saaaa[j] + sbbbb[j] + x); 193 | sbbbb[j] = w * (sbbbb[j] + x); 194 | 195 | gw += r2 * ww * saaaa[j] * gy2[j]; 196 | } 197 | } 198 | 199 | _gw[b*C + h*_N_ + i] = F(gw); 200 | 201 | #pragma unroll 202 | for (int j = 0; j < _N_; ++j) { 203 | saaaa[j] = 0; 204 | sbbbb[j] = 0; 205 | } 206 | 207 | __syncthreads(); 208 | w_[i] = float(_w[i]); 209 | u_[i] = float(_u[i]); 210 | __syncthreads(); 211 | 212 | for (int _t = t111 - C; _t >= t000; _t -= C) 213 | { 214 | __syncthreads(); 215 | v[i] = float(_v[_t]); 216 | gy[i] = float(_gy[_t]); 217 | k[i] = float(_k[_t]); 218 | r[i] = float(_r[_t]); 219 | __syncthreads(); 220 | 221 | float gk = 0, gv = 0, x, s; 222 | 223 | #pragma unroll 224 | for (int j = 0; j < _N_; j++) 225 | { 226 | x = gy[j] * r[i]; 227 | s = saaaa[j]; 228 | 
saaaa[j] = s * w + x; 229 | gk += v[j] * (u * x + s); 230 | 231 | x = gy[i] * r[j]; 232 | s = sbbbb[j]; 233 | sbbbb[j] = s * w_[j] + x; 234 | gv += k[j] * (u_[j] * x + s); 235 | } 236 | _gk[_t] = F(gk); 237 | _gv[_t] = F(gv); 238 | } 239 | } 240 | 241 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 242 | { 243 | assert(H*_N_ == C); 244 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 245 | } 246 | 247 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 248 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 249 | } 250 | 251 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 252 | m.def("forward", &forward, "wkv5 forward"); 253 | m.def("backward", &backward, "wkv5 backward"); 254 | } 255 | 256 | TORCH_LIBRARY(wkv5, m) { 257 | m.def("forward", forward); 258 | m.def("backward", backward); 259 | } 260 | -------------------------------------------------------------------------------- /rwkv7_fast_fused/rwkv7_cuda_benchmark_state.py: -------------------------------------------------------------------------------- 1 | import time, sys 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.cpp_extension import load 6 | 7 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 8 | torch.backends.cudnn.allow_tf32 = False 9 | torch.backends.cuda.matmul.allow_tf32 = False 10 | 11 | ''' 12 | cd /mnt/program/_RWKV_/_ref_/RWKV-CUDA/rwkv7_fast_fused; python rwkv7_cuda_benchmark_state.py fp32 0; python 
rwkv7_cuda_benchmark_state.py fp32 1 13 | cd /mnt/program/_RWKV_/_ref_/RWKV-CUDA/rwkv7_fast_fused; python rwkv7_cuda_benchmark_state.py bf16 0; python rwkv7_cuda_benchmark_state.py bf16 1 14 | ''' 15 | print('\n### RWKV7_fused_clamp_w state-tuning fwd+bwd kernel ###\n') 16 | 17 | DTYPE = torch.float if sys.argv[1].strip()=='fp32' else torch.bfloat16 18 | BENCHMARK_SPEED = True if int(sys.argv[2].strip()) == 1 else 0 19 | 20 | ###################################################################################################### 21 | 22 | DEVICE = "cuda" 23 | 24 | B, T, CHUNK_LEN, C, N = 2, 64, 16, 32, 16 25 | if BENCHMARK_SPEED: 26 | B, T, CHUNK_LEN, C, N = 8, 4096, 16, 4096, 64 27 | 28 | H, HEAD_SIZE = C // N, N 29 | print(f"\n\nB={B} T={T} C={C} HEAD_SIZE={HEAD_SIZE} DTYPE={str(DTYPE).replace('torch.','')}\n\n") 30 | 31 | def set_seed(seed): 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | def val(x): 37 | return x.detach().float().cpu().numpy() 38 | 39 | def err_ratio(x, y): 40 | err = (x-y).flatten().square().mean().sqrt().item() 41 | base = (x).flatten().square().mean().sqrt().item() 42 | return err / base 43 | 44 | ###################################################################################################### 45 | 46 | def RWKV7_STATE_CLAMPW_REF(state, r, w, k, v, a, b): 47 | r = r.view(B, T, H, N) 48 | k = k.view(B, T, H, N) 49 | v = v.view(B, T, H, N) 50 | a = a.view(B, T, H, N) 51 | b = b.view(B, T, H, N) 52 | 53 | w = -F.softplus(-w) - 0.5 # soft-clamp, after exp becomes sigmoid in CUDA kernel 54 | w = torch.exp(-torch.exp(w.view(B, T, H, N))) 55 | 56 | out = torch.zeros((B, T, H, N), device=DEVICE) 57 | 58 | for t in range(T): 59 | rr = r[:, t, :] 60 | kk = k[:, t, :] 61 | vv = v[:, t, :] 62 | aa = a[:, t, :] 63 | bb = b[:, t, :] 64 | sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb) 65 | state = state * w[: , t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv) 66 | out[:, t, :] = 
torch.einsum('bhj,bhij->bhi', rr, state) 67 | 68 | return out.view((B, T, C)) 69 | 70 | ###################################################################################################### 71 | 72 | if DTYPE == torch.bfloat16: 73 | flags = ['-res-usage', f'-D_N_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 74 | load(name="rwkv7_state_clampw", sources=[f'cuda/rwkv7_state_clampw.cu', 'cuda/rwkv7_state_clampw.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) 75 | class RWKV7_STATE_CLAMPW_CUDA_OP(torch.autograd.Function): 76 | @staticmethod 77 | def forward(ctx,s0,r,w,k,v,a,b): 78 | B,T,H,N = r.shape 79 | assert T%CHUNK_LEN == 0 80 | assert all(i.dtype==torch.bfloat16 for i in [r,w,k,v,a,b]) 81 | assert all(i.is_contiguous() for i in [s0,r,w,k,v,a,b]) 82 | assert s0.dtype==torch.float 83 | y = torch.empty_like(v) 84 | s = torch.empty(B,H,T//CHUNK_LEN,N,N, dtype=torch.float32,device=w.device) 85 | sa = torch.empty(B,T,H,N,dtype=torch.float32,device=w.device) 86 | torch.ops.rwkv7_state_clampw.forward(s0,r,w,k,v,a,b,y,s,sa) 87 | ctx.save_for_backward(r,w,k,v,a,b,s,sa) 88 | return y 89 | @staticmethod 90 | def backward(ctx,dy): 91 | assert all(i.dtype==torch.bfloat16 for i in [dy]) 92 | assert all(i.is_contiguous() for i in [dy]) 93 | r,w,k,v,a,b,s,sa = ctx.saved_tensors 94 | B,T,H,N = r.shape 95 | dr,dw,dk,dv,da,db = [torch.empty_like(x) for x in [r,w,k,v,a,b]] 96 | ds0 = torch.empty(B,H,N,N,dtype=torch.float32,device=r.device) 97 | torch.ops.rwkv7_state_clampw.backward(r,w,k,v,a,b,dy,s,sa,ds0,dr,dw,dk,dv,da,db) 98 | return ds0,dr,dw,dk,dv,da,db 99 | def RWKV7_STATE_CLAMPW_CUDA(s0,r,w,k,v,a,b): 100 | B,T,HC = r.shape 101 | r,w,k,v,a,b = [i.view(B,T,HC//HEAD_SIZE,HEAD_SIZE) for i in [r,w,k,v,a,b]] 102 | return RWKV7_STATE_CLAMPW_CUDA_OP.apply(s0,r,w,k,v,a,b).view(B,T,HC) 103 | 104 | elif DTYPE == torch.float: 105 | flags = ['-res-usage', f'-D_N_={HEAD_SIZE}', "-D_FP32_", 
f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] 106 | load(name="rwkv7_state_clampw", sources=[f'cuda/rwkv7_state_clampw.cu', 'cuda/rwkv7_state_clampw.cpp'], is_python_module=False, verbose=True, extra_cflags=["-D_FP32_"], extra_cuda_cflags=flags) 107 | class RWKV7_STATE_CLAMPW_CUDA_OP(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx,s0,r,w,k,v,a,b): 110 | B,T,H,C = r.shape 111 | assert T%CHUNK_LEN == 0 112 | assert all(i.dtype==torch.float32 for i in [s0,r,w,k,v,a,b]) 113 | assert all(i.is_contiguous() for i in [s0,r,w,k,v,a,b]) 114 | y = torch.empty_like(v) 115 | s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) 116 | sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) 117 | torch.ops.rwkv7_state_clampw.forward(s0,r,w,k,v,a,b,y,s,sa) 118 | ctx.save_for_backward(r,w,k,v,a,b,s,sa) 119 | return y 120 | @staticmethod 121 | def backward(ctx,dy): 122 | assert all(i.dtype==torch.float32 for i in [dy]) 123 | assert all(i.is_contiguous() for i in [dy]) 124 | r,w,k,v,a,b,s,sa = ctx.saved_tensors 125 | dr,dw,dk,dv,da,db = [torch.empty_like(x) for x in [r,w,k,v,a,b]] 126 | ds0 = torch.empty(B,H,N,N,dtype=torch.float32,device=r.device) 127 | torch.ops.rwkv7_state_clampw.backward(r,w,k,v,a,b,dy,s,sa,ds0,dr,dw,dk,dv,da,db) 128 | return ds0,dr,dw,dk,dv,da,db 129 | def RWKV7_STATE_CLAMPW_CUDA(s0,r,w,k,v,a,b): 130 | B,T,HC = r.shape 131 | r,w,k,v,a,b = [i.view(B,T,HC//HEAD_SIZE,HEAD_SIZE) for i in [r,w,k,v,a,b]] 132 | return RWKV7_STATE_CLAMPW_CUDA_OP.apply(s0,r,w,k,v,a,b).view(B,T,HC) 133 | 134 | ###################################################################################################### 135 | 136 | with torch.no_grad(): 137 | r = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 3 138 | w = torch.empty(B, T, C, device=DEVICE).uniform_(-6, 0) 139 | k = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 3 140 | v = torch.empty(B, T, C, 
device=DEVICE).uniform_(-1, 1) * 3 141 | a = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 2 142 | b = torch.empty(B, T, C, device=DEVICE).uniform_(-1, 1) * 2 143 | a = F.normalize(a, dim=-1, p=2.0) 144 | b = F.normalize(b, dim=-1, p=2.0) 145 | s = torch.empty(B, H, N, N, device=DEVICE).uniform_(-1, 1) * 10 146 | 147 | params = (s,r,w,k,v,a,b) 148 | 149 | def LOSS(y): 150 | return ((y * y) - torch.tanh(y)).sum() 151 | 152 | def clear_grad(): 153 | for t in params: 154 | t.requires_grad_(True) 155 | if t.grad is not None: 156 | t.grad.zero_() 157 | 158 | ###################################################################################################### 159 | 160 | if not BENCHMARK_SPEED: 161 | clear_grad() 162 | y = RWKV7_STATE_CLAMPW_REF(*params) 163 | LOSS(y).backward() 164 | grad_ref = [t.grad.detach().clone() for t in params] 165 | 166 | clear_grad() 167 | if DTYPE == torch.float: 168 | y_cuda = RWKV7_STATE_CLAMPW_CUDA(*params) 169 | else: 170 | ss,rr,ww,kk,vv,aa,bb = s,r.bfloat16(),w.bfloat16(),k.bfloat16(),v.bfloat16(),a.bfloat16(),b.bfloat16() 171 | y_cuda = RWKV7_STATE_CLAMPW_CUDA(ss,rr,ww,kk,vv,aa,bb).float() 172 | LOSS(y_cuda).backward() 173 | grad_cuda = [t.grad.detach().clone() for t in params] 174 | 175 | print('!!! y err !!!', err_ratio(y, y_cuda)) 176 | for name, g_ref, g_cuda in zip('srwkvab', grad_ref, grad_cuda): 177 | print(f'!!! 
g_{name} err !!!', err_ratio(g_ref, g_cuda)) 178 | 179 | else: 180 | print('benchmark speed...') 181 | repeats = 10 182 | fwd_times = [] 183 | bwd_times = [] 184 | for _ in range(repeats): 185 | clear_grad() 186 | 187 | if DTYPE == torch.float: 188 | torch.cuda.synchronize(); t0 = time.perf_counter() 189 | y_cuda = RWKV7_STATE_CLAMPW_CUDA(*params) 190 | torch.cuda.synchronize(); fwd_times.append(time.perf_counter() - t0) 191 | 192 | torch.cuda.synchronize(); t0 = time.perf_counter() 193 | LOSS(y_cuda).backward() 194 | torch.cuda.synchronize(); bwd_times.append(time.perf_counter() - t0) 195 | else: 196 | ss,rr,ww,kk,vv,aa,bb = s,r.bfloat16(),w.bfloat16(),k.bfloat16(),v.bfloat16(),a.bfloat16(),b.bfloat16() 197 | torch.cuda.synchronize(); t0 = time.perf_counter() 198 | y_cuda = RWKV7_STATE_CLAMPW_CUDA(ss,rr,ww,kk,vv,aa,bb) 199 | torch.cuda.synchronize(); fwd_times.append(time.perf_counter() - t0) 200 | 201 | torch.cuda.synchronize(); t0 = time.perf_counter() 202 | LOSS(y_cuda).backward() 203 | torch.cuda.synchronize(); bwd_times.append(time.perf_counter() - t0) 204 | 205 | print('fwd time =', min(fwd_times)) 206 | print('bwd time =', min(bwd_times)) 207 | --------------------------------------------------------------------------------