├── README.md
├── depthwise_conv1d
│   ├── cuda
│   │   ├── timex_op.cpp
│   │   ├── timex_cuda_v0.cu
│   │   ├── timex_cuda_v1.cu
│   │   ├── timex_cuda_v2.cu
│   │   └── timex_cuda_v3.cu
│   └── run.py
├── wkv
│   └── cuda
│       ├── wkv_op.cpp
│       ├── wkv_cuda_v1.cu
│       ├── wkv_cuda_v2.cu
│       └── wkv_cuda_v0.cu
├── wkv6
│   └── cuda
│       ├── wkv6_op.cpp
│       ├── wkv5_op.cpp
│       ├── wkv5_cuda_v1b2.cu
│       ├── wkv6_cuda_v1.cu
│       └── wkv6_cuda_v1a.cu
├── wkv5
│   └── cuda
│       ├── wkv5_op.cpp
│       ├── wkv5_ref.cpp
│       ├── wkv5_cuda_v1c.cu
│       ├── wkv5_cuda_v1a.cu
│       ├── wkv5_cuda_v1d.cu
│       ├── wkv5_cuda_v3.cu
│       ├── wkv5_cuda_v1b.cu
│       ├── wkv5_cuda_ref.cu
│       ├── wkv5_cuda_v1.cu
│       ├── wkv5_cuda_v2.cu
│       └── wkv5_cuda_v1e.cu
├── wkv5_bf16
│   └── cuda
│       ├── wkv5_op.cpp
│       ├── wkv5_cuda_v1a.cu
│       ├── wkv5_cuda_v1.cu
│       ├── wkv5_cuda_v1b.cu
│       ├── wkv5_cuda_v1b2.cu
│       ├── wkv5_cuda_v2.cu
│       └── wkv5_cuda_v3.cu
├── wkv6_state
│   └── cuda
│       └── wkv6state_op.cpp
├── rwkv7_fast_fused
│   ├── cuda
│   │   ├── rwkv7_clampw.cpp
│   │   ├── rwkv7_state_clampw.cpp
│   │   ├── rwkv7_statepassing_clampw.cpp
│   │   ├── rwkv7_clampw.cu
│   │   ├── rwkv7_state_clampw.cu
│   │   └── rwkv7_statepassing_clampw.cu
│   ├── rwkv7_cuda_benchmark.py
│   └── rwkv7_cuda_benchmark_state.py
├── wkv5a
│   └── cuda
│       ├── wkv5a_op.cpp
│       ├── wkv5a_cuda_v1.cu
│       ├── wkv5a_cuda_v1a.cu
│       └── wkv5a_cuda_v1a2.cu
└── .gitignore
/README.md:
--------------------------------------------------------------------------------
1 | # RWKV-CUDA
2 | The CUDA version of the RWKV language model ( https://github.com/BlinkDL/RWKV-LM )
3 |
4 | ## Towards RWKV-4 (see the wkv folder)
5 |
6 | I have a basic RWKV-4 kernel in the wkv folder. Let's optimize it.
7 |
8 |
9 |
10 | ## Experiment 1 - depthwise_conv1d - 20x faster than pytorch
11 |
12 | The formula:
13 | ```
14 | w.shape = (C, T)
15 | k.shape = (B, C, T)
16 | out.shape = (B, C, T)
17 | out[b][c][t] = sum_{u=0..t}{ w[c][(T-1)-(t-u)] * k[b][c][u] }
18 | ```
19 |
20 | pytorch = fwd 94ms bwd 529ms
21 |
22 | CUDA kernel v0 = fwd 45ms bwd 84ms (simple)
23 |
24 | CUDA kernel v1 = fwd 17ms bwd 43ms (shared memory)
25 |
26 | CUDA kernel v2 = fwd 13ms bwd 31ms (float4)
27 |
28 | CUDA kernel v3 = fwd 3.4ms bwd 23ms (B-group)
29 |
30 | More tests on an RTX 3090:
31 |
32 | pytorch = fwd 14ms bwd 65ms
33 |
34 | CUDA kernel v3 = fwd 0.8ms bwd 5.5ms
35 |
36 | How to use: run ```python run.py``` and it will compile everything for you (```pip install Ninja``` first if you don't have it).
37 |
--------------------------------------------------------------------------------
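For reference, here is a minimal, unoptimized PyTorch sketch of the formula above. The kernels restrict the sum to `u <= t` and seed each output with `eps`; the function and variable names are illustrative, not taken from run.py:

```python
import torch

def timex_reference(w: torch.Tensor, k: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # out[b][c][t] = eps + sum_{u=0..t} w[c][(T-1)-(t-u)] * k[b][c][u]
    B, C, T = k.shape
    out = torch.full((B, C, T), eps, dtype=k.dtype, device=k.device)
    for t in range(T):
        for u in range(t + 1):
            out[:, :, t] += w[:, (T - 1) - (t - u)] * k[:, :, u]
    return out
```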
/depthwise_conv1d/cuda/timex_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
4 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
5 |
6 | void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
7 | cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
8 | }
9 | void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
10 | cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
11 | }
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("forward", &forward, "timex forward");
15 | m.def("backward", &backward, "timex backward");
16 | }
17 |
18 | TORCH_LIBRARY(timex, m) {
19 | m.def("forward", forward);
20 | m.def("backward", backward);
21 | }
22 |
--------------------------------------------------------------------------------
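The README's `python run.py` step presumably JIT-compiles this binding with torch.utils.cpp_extension; a hypothetical sketch (file names come from the tree above, the flags and the Tmax value are assumptions, not taken from run.py):

```python
from torch.utils.cpp_extension import load

# JIT-compile the op; Ninja must be installed (pip install Ninja)
timex = load(
    name="timex",
    sources=["cuda/timex_op.cpp", "cuda/timex_cuda_v3.cu"],
    verbose=True,
    # Tmax bounds T at compile time (see "require T <= Tmax" in the kernels);
    # 768 here is an assumed example value
    extra_cuda_cflags=["-O3", "-DTmax=768"],
)
# forward signature per timex_op.cpp: timex.forward(w, k, x, eps, B, C, T)
```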
/wkv/cuda/wkv_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
5 |
6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
7 | cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr());
8 | }
9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
10 | cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr());
11 | }
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("forward", &forward, "wkv forward");
15 | m.def("backward", &backward, "wkv backward");
16 | }
17 |
18 | TORCH_LIBRARY(wkv, m) {
19 | m.def("forward", forward);
20 | m.def("backward", backward);
21 | }
22 |
--------------------------------------------------------------------------------
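For orientation, the operator these wkv kernels compute is the RWKV-4 WKV mix. Per channel, with decay w and bonus u (this is the formula from the RWKV paper; the kernels evaluate it with the running-max trick visible in wkv_cuda_v1.cu):

```latex
\mathrm{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}
                      {\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}
```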
/wkv6/cuda/wkv6_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | typedef at::BFloat16 bf16;
4 |
5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
7 |
8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 | m.def("forward", &forward, "wkv6 forward");
16 | m.def("backward", &backward, "wkv6 backward");
17 | }
18 |
19 | TORCH_LIBRARY(wkv6, m) {
20 | m.def("forward", forward);
21 | m.def("backward", backward);
22 | }
23 |
--------------------------------------------------------------------------------
/wkv5/cuda/wkv5_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y);
4 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu);
5 |
6 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
7 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr());
8 | }
9 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
10 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr());
11 | }
12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
13 | m.def("forward", &forward, "wkv5 forward");
14 | m.def("backward", &backward, "wkv5 backward");
15 | }
16 |
17 | TORCH_LIBRARY(wkv5, m) {
18 | m.def("forward", forward);
19 | m.def("backward", backward);
20 | }
21 |
--------------------------------------------------------------------------------
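Reading the forward kernels in the wkv5 folder (e.g. wkv5_cuda_v1a.cu below), the per-head recurrence appears to be the following, with N the head size, S the N x N running state, and r_t, k_t, v_t, w, u per-head N-vectors:

```latex
S_t = \operatorname{diag}(w)\, S_{t-1} + k_t v_t^{\top}, \qquad
y_t^{\top} = r_t^{\top} \left( \operatorname{diag}(u)\, k_t v_t^{\top} + S_{t-1} \right)
```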
/wkv5/cuda/wkv5_ref.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y);
4 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu);
5 |
6 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
7 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr());
8 | }
9 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
10 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr());
11 | }
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("forward", &forward, "wkv5_ref forward");
15 | m.def("backward", &backward, "wkv5_ref backward");
16 | }
17 |
18 | TORCH_LIBRARY(wkv5_ref, m) {
19 | m.def("forward", forward);
20 | m.def("backward", backward);
21 | }
22 |
--------------------------------------------------------------------------------
/wkv6/cuda/wkv5_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | typedef at::BFloat16 bf16;
4 |
5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
7 |
8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 | m.def("forward", &forward, "wkv5 forward");
16 | m.def("backward", &backward, "wkv5 backward");
17 | }
18 |
19 | TORCH_LIBRARY(wkv5, m) {
20 | m.def("forward", forward);
21 | m.def("backward", backward);
22 | }
23 |
--------------------------------------------------------------------------------
/wkv5_bf16/cuda/wkv5_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | typedef at::BFloat16 bf16;
4 |
5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
7 |
8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 | m.def("forward", &forward, "wkv5 forward");
16 | m.def("backward", &backward, "wkv5 backward");
17 | }
18 |
19 | TORCH_LIBRARY(wkv5, m) {
20 | m.def("forward", forward);
21 | m.def("backward", backward);
22 | }
23 |
--------------------------------------------------------------------------------
/wkv6_state/cuda/wkv6state_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | typedef at::BFloat16 bf16;
4 |
5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
7 |
8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 | m.def("forward", &forward, "wkv6b forward");
16 | m.def("backward", &backward, "wkv6b backward");
17 | }
18 |
19 | TORCH_LIBRARY(wkv6b, m) {
20 | m.def("forward", forward);
21 | m.def("backward", backward);
22 | }
23 |
--------------------------------------------------------------------------------
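Unlike the plain wkv ops, this one threads an initial state s through the kernel. A hypothetical call sketch once the extension is built and loaded (all shapes are assumptions inferred from the signatures; the op name comes from the TORCH_LIBRARY(wkv6b, ...) block above):

```python
import torch

B, T, C, H = 1, 16, 64, 1                  # C = H * head_size (assumed)
dev, dt = "cuda", torch.bfloat16
r, k, v, w = (torch.randn(B, T, C, device=dev, dtype=dt) for _ in range(4))
u = torch.randn(C, device=dev, dtype=dt)                      # per-channel bonus (assumed shape)
s = torch.zeros(B, H, C // H, C // H, device=dev, dtype=dt)   # initial state (assumed layout)
y = torch.empty_like(r)
torch.ops.wkv6b.forward(B, T, C, H, r, k, v, w, u, s, y)      # fills y in place
```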
/rwkv7_fast_fused/cuda/rwkv7_clampw.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #ifdef _FP32_
4 | using bf = float;
5 | #else
6 | #include <cuda_bf16.h>
7 | using bf = __nv_bfloat16;
8 | #endif
9 |
10 | void cuda_forward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*y, float*s, float*sa);
11 |
12 | void forward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
13 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2];
14 | cuda_forward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
15 | }
16 |
17 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db);
18 |
19 | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy,
20 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dr, torch::Tensor &dw, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db) {
21 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2];
22 | cuda_backward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(),
23 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dr.data_ptr(), (bf*)dw.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr());
24 | }
25 |
26 | TORCH_LIBRARY(rwkv7_clampw, m) {
27 | m.def("forward", forward);
28 | m.def("backward", backward);
29 | }
30 |
--------------------------------------------------------------------------------
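Note these rwkv7 bindings register only a TORCH_LIBRARY entry (no PYBIND11_MODULE), so after compilation they are reached through torch.ops. A hypothetical sketch: B, T, H are read from r's first three dims per the binding, while the head size and the fp32 buffer layouts are assumptions, not confirmed here:

```python
import torch

B, T, H, Z = 2, 32, 4, 64                  # Z = head size (assumed 64)
dev, dt = "cuda", torch.bfloat16
r, w, k, v, a, b = (torch.randn(B, T, H, Z, device=dev, dtype=dt) for _ in range(6))
y = torch.empty_like(r)
s = torch.empty(B, H, T // 16, Z, Z, device=dev)   # fp32 state checkpoints (assumed layout)
sa = torch.empty(B, T, H, Z, device=dev)           # fp32 intermediate (assumed)
torch.ops.rwkv7_clampw.forward(r, w, k, v, a, b, y, s, sa)
```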
/rwkv7_fast_fused/cuda/rwkv7_state_clampw.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #ifdef _FP32_
4 | using bf = float;
5 | #else
6 | #include <cuda_bf16.h>
7 | using bf = __nv_bfloat16;
8 | #endif
9 |
10 | void cuda_forward(int B, int T, int H, float*s0, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*y, float*s, float*sa);
11 |
12 | void forward(torch::Tensor &s0, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
13 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2];
14 | cuda_forward(B, T, H, (float*)s0.data_ptr(), (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
15 | }
16 |
17 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*s, float*sa, float*ds0, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db);
18 |
19 | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy,
20 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &ds0, torch::Tensor &dr, torch::Tensor &dw, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db) {
21 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2];
22 | cuda_backward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(),
23 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)ds0.data_ptr(), (bf*)dr.data_ptr(), (bf*)dw.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr());
24 | }
25 |
26 | TORCH_LIBRARY(rwkv7_state_clampw, m) {
27 | m.def("forward", forward);
28 | m.def("backward", backward);
29 | }
30 |
--------------------------------------------------------------------------------
/rwkv7_fast_fused/cuda/rwkv7_statepassing_clampw.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #ifdef _FP32_
4 | using bf = float;
5 | #else
6 | #include <cuda_bf16.h>
7 | using bf = __nv_bfloat16;
8 | #endif
9 |
10 | void cuda_forward(int B, int T, int H, float*s0, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*y, float*sT, float*s, float*sa);
11 |
12 | void forward(torch::Tensor &s0, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &sT, torch::Tensor &s, torch::Tensor &sa) {
13 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2];
14 | cuda_forward(B, T, H, (float*)s0.data_ptr(), (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)y.data_ptr(), (float*)sT.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
15 | }
16 |
17 | void cuda_backward(int B, int T, int H, bf*r, bf*w, bf*k, bf*v, bf*a, bf*b, bf*dy, float*dsT, float*s, float*sa, float*ds0, bf*dr, bf*dw, bf*dk, bf*dv, bf*da, bf*db);
18 |
19 | void backward(torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &dy, torch::Tensor &dsT,
20 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &ds0, torch::Tensor &dr, torch::Tensor &dw, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &da, torch::Tensor &db) {
21 | int B = r.sizes()[0], T = r.sizes()[1], H = r.sizes()[2];
22 | cuda_backward(B, T, H, (bf*)r.data_ptr(), (bf*)w.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)a.data_ptr(), (bf*)b.data_ptr(), (bf*)dy.data_ptr(), (float*)dsT.data_ptr(),
23 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)ds0.data_ptr(), (bf*)dr.data_ptr(), (bf*)dw.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)da.data_ptr(), (bf*)db.data_ptr());
24 | }
25 |
26 | TORCH_LIBRARY(rwkv7_statepassing_clampw, m) {
27 | m.def("forward", forward);
28 | m.def("backward", backward);
29 | }
30 |
--------------------------------------------------------------------------------
/wkv5a/cuda/wkv5a_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | typedef at::BFloat16 bf16;
4 | typedef float DTYPE;
5 |
6 | void cuda_forward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, DTYPE *u1, float *w2, DTYPE *u2, DTYPE *y);
7 | void cuda_backward(int B, int T, int C, int H, DTYPE *r, DTYPE *k, DTYPE *v, float *w1, float *ww1, DTYPE *u1, float *w2, float *ww2, DTYPE *u2, DTYPE *gy, DTYPE *gr, DTYPE *gk, DTYPE *gv, DTYPE *gw1, DTYPE *gu1, DTYPE *gw2, DTYPE *gu2);
8 |
9 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w1, torch::Tensor &u1, torch::Tensor &w2, torch::Tensor &u2, torch::Tensor &y) {
10 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w1.data_ptr(), u1.data_ptr(), w2.data_ptr(), u2.data_ptr(), y.data_ptr());
11 | }
12 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w1, torch::Tensor &ww1, torch::Tensor &u1, torch::Tensor &w2, torch::Tensor &ww2, torch::Tensor &u2,
13 | torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw1, torch::Tensor &gu1, torch::Tensor &gw2, torch::Tensor &gu2) {
14 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w1.data_ptr(), ww1.data_ptr(), u1.data_ptr(), w2.data_ptr(), ww2.data_ptr(), u2.data_ptr(),
15 | gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw1.data_ptr(), gu1.data_ptr(), gw2.data_ptr(), gu2.data_ptr());
16 | }
17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
18 | m.def("forward", &forward, "wkv5a forward");
19 | m.def("backward", &backward, "wkv5a backward");
20 | }
21 |
22 | TORCH_LIBRARY(wkv5a, m) {
23 | m.def("forward", forward);
24 | m.def("backward", backward);
25 | }
26 |
--------------------------------------------------------------------------------
/depthwise_conv1d/cuda/timex_cuda_v0.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | template <typename F>
4 | __global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,
5 | const F eps, const int B, const int C, const int T)
6 | {
7 | const int i = blockIdx.y;
8 | const int t = threadIdx.x;
9 |
10 | F s = eps;
11 | const F *__restrict__ const www = w + (i % C) * T + (T - 1) - t;
12 | const F *__restrict__ const kk = k + i * T;
13 | for (int u = 0; u <= t; u++)
14 | {
15 | s += www[u] * kk[u];
16 | }
17 | x[i * T + t] = s;
18 | }
19 |
20 | template <typename F>
21 | __global__ void kernel_backward(const F *__restrict__ const w, const F *__restrict__ const k, const F *__restrict__ const gwk,
22 | F *__restrict__ const gw, F *__restrict__ const gk,
23 | const int B, const int C, const int T)
24 | {
25 | const int i = blockIdx.y;
26 | const int t = threadIdx.x;
27 |
28 | F s = 0;
29 | const F *__restrict__ const ggk = gwk + i * T + (T - 1) - t;
30 | const F *__restrict__ const kk = k + i * T;
31 | for (int u = 0; u <= t; u++)
32 | {
33 | s += ggk[u] * kk[u];
34 | }
35 | gw[i * T + t] = s;
36 |
37 | s = 0;
38 | const F *__restrict__ const ggw = gwk + i * T + (T - 1) + t;
39 | const F *__restrict__ const ww = w + (i % C) * T;
40 | for (int u = t; u < T; u++)
41 | {
42 | s += ggw[-u] * ww[u];
43 | }
44 | gk[i * T + t] = s;
45 | }
46 |
47 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T)
48 | {
49 | dim3 gridDim(1, B * C);
50 | dim3 blockDim(T);
51 | kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
52 | }
53 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T)
54 | {
55 | dim3 gridDim(1, B * C);
56 | dim3 blockDim(T);
57 | kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
58 | }
59 |
--------------------------------------------------------------------------------
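A restatement of what kernel_backward above computes, with g = gwk the incoming gradient w.r.t. the output and i ranging over the B*C rows (my reading of the two loops):

```latex
gw[i][t] = \sum_{u=0}^{t} g[i][(T{-}1){-}t{+}u] \cdot k[i][u], \qquad
gk[i][t] = \sum_{u=t}^{T-1} g[i][(T{-}1){+}t{-}u] \cdot w[c][u]
```

So gw holds a per-(batch, channel) gradient image for w, presumably reduced over the batch dimension on the Python side.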
/depthwise_conv1d/cuda/timex_cuda_v1.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | // requires T <= Tmax (Tmax passed by the compiler)
4 |
5 | template <typename F>
6 | __global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,
7 | const F eps, const int B, const int C, const int T)
8 | {
9 | const int i = blockIdx.y;
10 | const int t = threadIdx.x;
11 |
12 | __shared__ F ww[Tmax];
13 | __shared__ F kk[Tmax];
14 | ww[t] = w[(i % C) * T + t];
15 | kk[t] = k[i * T + t];
16 |
17 | __syncthreads();
18 |
19 | F s = eps;
20 | const F *__restrict__ const www = ww + (T - 1) - t;
21 | for (int u = 0; u <= t; u++)
22 | {
23 | s += www[u] * kk[u];
24 | }
25 | x[i * T + t] = s;
26 | }
27 |
28 | template <typename F>
29 | __global__ void kernel_backward(const F *__restrict__ const w, const F *__restrict__ const k, const F *__restrict__ const gwk,
30 | F *__restrict__ const gw, F *__restrict__ const gk,
31 | const int B, const int C, const int T)
32 | {
33 | const int i = blockIdx.y;
34 | const int t = threadIdx.x;
35 |
36 | __shared__ F gg[Tmax];
37 | __shared__ F kk[Tmax];
38 | __shared__ F ww[Tmax];
39 | gg[t] = gwk[i * T + t];
40 | kk[t] = k[i * T + t];
41 | ww[t] = w[(i % C) * T + t];
42 |
43 | __syncthreads();
44 |
45 | F s = 0;
46 | const F *__restrict__ const ggk = gg + (T - 1) - t;
47 | for (int u = 0; u <= t; u++)
48 | {
49 | s += ggk[u] * kk[u];
50 | }
51 | gw[i * T + t] = s;
52 |
53 | s = 0;
54 | const F *__restrict__ const ggw = gg + (T - 1) + t;
55 | for (int u = t; u < T; u++)
56 | {
57 | s += ggw[-u] * ww[u];
58 | }
59 | gk[i * T + t] = s;
60 | }
61 |
62 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T)
63 | {
64 | dim3 gridDim(1, B * C);
65 | dim3 blockDim(T);
66 | kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
67 | }
68 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T)
69 | {
70 | dim3 gridDim(1, B * C);
71 | dim3 blockDim(T);
72 | kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
73 | }
74 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/wkv5/cuda/wkv5_cuda_v1c.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u,
7 | F *__restrict__ const _y)
8 | {
9 | const int b = blockIdx.x / H;
10 | const int h = blockIdx.x % H;
11 | const int i = threadIdx.x;
12 | _w += h*N;
13 | _u += h*N;
14 |
15 | __shared__ float state[N * N], rr[N], kk[N];
16 |
17 | for (int j = 0; j < N; ++j)
18 | state[j * N + i] = 0;
19 |
20 | for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C)
21 | {
22 | const F vv = _v[_t];
23 | F yy = 0;
24 |
25 | rr[i] = _r[_t];
26 | kk[i] = _k[_t];
27 |
28 | __syncthreads();
29 |
30 | for (int j = 0; j < N; j++)
31 | {
32 | const float ww = _w[j];
33 | const float uu = _u[j];
34 |
35 | float x = kk[j] * vv;
36 |
37 | float s = state[j * N + i];
38 | yy += rr[j] * (uu * x + s);
39 | state[j * N + i] = s * ww + x;
40 | }
41 | _y[_t] = yy;
42 |
43 | __syncthreads();
44 | }
45 | }
46 |
47 | template <typename F>
48 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
49 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy,
50 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
51 | {
52 |
53 | }
54 |
55 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
56 | {
57 | assert(H*N == C);
58 | kernel_forward<<<B*H, N>>>(B, T, C, H, r, k, v, w, u, y);
59 | }
60 |
61 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
62 | {
63 | assert(H*N == C);
64 | const int SIZE = B*C;
65 | dim3 threadsPerBlock(min(SIZE, 32));
66 | assert(SIZE % threadsPerBlock.x == 0);
67 | dim3 numBlocks(SIZE / threadsPerBlock.x);
68 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
69 | }
70 |
--------------------------------------------------------------------------------
/wkv5/cuda/wkv5_cuda_v1a.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
7 | F *__restrict__ const _y)
8 | {
9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
10 | const int _b = idx / C;
11 | const int _h = (idx / N) % H;
12 | const int _i = idx % N;
13 |
14 | const int _o0 = _b*T*C + _h*N;
15 | const int _o1 = _h*N;
16 | const F *__restrict__ const k = _k + _o0;
17 | const F *__restrict__ const v = _v + _o0 + _i;
18 | const F *__restrict__ const r = _r + _o0;
19 | F *__restrict__ const y = _y + _o0 + _i;
20 |
21 | float state[N] = {0};
22 |
23 | for (int _t = 0; _t < T; _t++)
24 | {
25 | const int tt = _t*C;
26 | const F vv = v[tt];
27 | F yy = 0;
28 |
29 | #pragma unroll
30 | for (int _j = 0; _j < N; _j++)
31 | {
32 | const int j = tt + _j;
33 | const int m = _o1 + _j;
34 |
35 | const float x = k[j] * vv;
36 | const float s = state[_j];
37 |
38 | yy += r[j] * (_u[m] * x + s);
39 | state[_j] = s * _w[m] + x;
40 | }
41 | y[tt] = yy;
42 | }
43 | }
44 |
45 | template <typename F>
46 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
47 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy,
48 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
49 | {
50 |
51 | }
52 |
53 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
54 | {
55 | assert(H*N == C);
56 | const int SIZE = B*C;
57 | dim3 threadsPerBlock(min(SIZE, 32));
58 | assert(SIZE % threadsPerBlock.x == 0);
59 | dim3 numBlocks(SIZE / threadsPerBlock.x);
60 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, y);
61 | }
62 |
63 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
64 | {
65 | assert(H*N == C);
66 | const int SIZE = B*C;
67 | dim3 threadsPerBlock(min(SIZE, 32));
68 | assert(SIZE % threadsPerBlock.x == 0);
69 | dim3 numBlocks(SIZE / threadsPerBlock.x);
70 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
71 | }
72 |
--------------------------------------------------------------------------------
/wkv5/cuda/wkv5_cuda_v1d.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
7 | F *__restrict__ const _y)
8 | {
9 | const int b = blockIdx.x / H;
10 | const int h = blockIdx.x % H;
11 | const int i = threadIdx.x;
12 | const float4 *__restrict__ const ww = (float4 *)(_w + h*N);
13 | const float4 *__restrict__ const uu = (float4 *)(_u + h*N);
14 |
15 | __shared__ float state[N*N], rr[N], kk[N];
16 |
17 | #pragma unroll
18 | for (int j = 0; j < N; ++j)
19 | state[j*N + i] = 0; // will __syncthreads soon
20 |
21 | for (int bthi = b*T*H*N + 0*H*N + h*N + i; bthi < b*T*H*N + T*H*N + h*N + i; bthi += C)
22 | {
23 | __syncthreads();
24 | rr[i] = _r[bthi]; // rr[0:N] = _r[b,t,h,0:N]
25 | kk[i] = _k[bthi]; // kk[0:N] = _k[b,t,h,0:N]
26 | __syncthreads();
27 |
28 | const float v = _v[bthi];
29 | float y = 0;
30 |
31 | const float4 *__restrict__ const rrr = (float4 *)(rr);
32 | const float4 *__restrict__ const kkk = (float4 *)(kk);
33 | float4 x, s;
34 |
35 | #pragma unroll
36 | for (int j = 0; j < N/4; ++j)
37 | {
38 | const float4 r = rrr[j];
39 | const float4 k = kkk[j];
40 | const float4 w = ww[j];
41 | const float4 u = uu[j];
42 |
43 | x.x = k.x * v;
44 | x.y = k.y * v;
45 | x.z = k.z * v;
46 | x.w = k.w * v;
47 |
48 | const int jj = (j<<2)*N + i;
49 | s.x = state[jj + 0*N];
50 | s.y = state[jj + 1*N];
51 | s.z = state[jj + 2*N];
52 | s.w = state[jj + 3*N];
53 |
54 | y += r.x * (u.x * x.x + s.x);
55 | y += r.y * (u.y * x.y + s.y);
56 | y += r.z * (u.z * x.z + s.z);
57 | y += r.w * (u.w * x.w + s.w);
58 |
59 | state[jj + 0*N] = s.x * w.x + x.x;
60 | state[jj + 1*N] = s.y * w.y + x.y;
61 | state[jj + 2*N] = s.z * w.z + x.z;
62 | state[jj + 3*N] = s.w * w.w + x.w;
63 | }
64 | _y[bthi] = y;
65 | }
66 | }
67 |
68 | template <typename F>
69 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
70 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy,
71 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
72 | {
73 |
74 | }
75 |
76 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
77 | {
78 | assert(H*N == C);
79 | assert(N%4 == 0);
80 | kernel_forward<<<B*H, N>>>(B, T, C, H, r, k, v, w, u, y);
81 | }
82 |
83 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
84 | {
85 |
86 | }
87 |
--------------------------------------------------------------------------------
/wkv5/cuda/wkv5_cuda_v3.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
7 | F *__restrict__ const _y)
8 | {
9 |
10 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
11 | const int _b = idx / C;
12 | const int _h = (idx / N) % H; // equivalent to (idx/N) - ((idx/N)/H)*H
13 | const int _i = idx % N;       // equivalent to idx - (idx/N)*N
14 |
15 | const int _o0 = _b * T * C + _h * N;
16 | const int _o1 = _h * N;
17 |
18 | const float4 *__restrict__ const k = (float4 *)(_k + _o0);
19 | const float4 *__restrict__ const r = (float4 *)(_r + _o0);
20 | const float4 *__restrict__ const w = (float4 *)(_w + _o1);
21 | const float4 *__restrict__ const u = (float4 *)(_u + _o1);
22 | const F *__restrict__ const v = _v + _o0 + _i;
23 | F *__restrict__ const y = _y + _o0 + _i;
24 |
25 | __align__(16) float4 state[N >> 2] = { make_float4(0.0f, 0.0f, 0.0f, 0.0f) };
26 |
27 | for (int __t = 0; __t < T; __t++)
28 | {
29 | const int _t = __t * (C >> 2);
30 | const int tt = __t * C;
31 | const F vv = v[tt];
32 | float yy = 0.0f;
33 |
34 | #pragma unroll
35 | for (int _j = 0; _j < N >> 2; _j++)
36 | {
37 | const int j = _t + _j;
38 |
39 | const float4 k_val = k[j];
40 | const float4 r_val = r[j];
41 | const float4 ww = w[_j];
42 | const float4 uu = u[_j];
43 | float4 x;
44 | x.x = k_val.x * vv;
45 | x.y = k_val.y * vv;
46 | x.z = k_val.z * vv;
47 | x.w = k_val.w * vv;
48 |
49 | float4 &s = state[_j];
50 |
51 | yy += r_val.x * (uu.x * x.x + s.x) + r_val.y * (uu.y * x.y + s.y) + r_val.z * (uu.z * x.z + s.z) + r_val.w * (uu.w * x.w + s.w);
52 |
53 | s.x = s.x * ww.x + x.x;
54 | s.y = s.y * ww.y + x.y;
55 | s.z = s.z * ww.z + x.z;
56 | s.w = s.w * ww.w + x.w;
57 | }
58 |
59 | y[tt] = yy;
60 | }
61 | }
62 |
63 | template <typename F>
64 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy,
66 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
67 | {
68 |
69 | }
70 |
71 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
72 | {
73 | assert(H*N == C);
74 | const int SIZE = B*C;
75 | dim3 threadsPerBlock(min(SIZE, 32));
76 | assert(SIZE % threadsPerBlock.x == 0);
77 | dim3 numBlocks(SIZE / threadsPerBlock.x);
78 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, y);
79 | }
80 |
81 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
82 | {
83 | assert(H*N == C);
84 | const int SIZE = B*C;
85 | dim3 threadsPerBlock(min(SIZE, 32));
86 | assert(SIZE % threadsPerBlock.x == 0);
87 | dim3 numBlocks(SIZE / threadsPerBlock.x);
88 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
89 | }
90 |
--------------------------------------------------------------------------------
/wkv5/cuda/wkv5_cuda_v1b.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
7 | F *__restrict__ const _y)
8 | {
9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
10 | const int _b = idx / C;
11 | const int _h = (idx / N) % H;
12 | const int i = idx % N;
13 |
14 | const int _o0 = _b*T*C + _h*N;
15 | const int _o1 = _h*N;
16 | const float4 *__restrict__ const r = (float4 *)(_r + _o0);
17 | const float4 *__restrict__ const k = (float4 *)(_k + _o0);
18 | const float4 *__restrict__ const w = (float4 *)(_w + _o1);
19 | const float4 *__restrict__ const u = (float4 *)(_u + _o1);
20 |
21 | const F *__restrict__ const v = _v + _o0 + i;
22 | F *__restrict__ const y = _y + _o0 + i;
23 |
24 | __align__(16) float4 state[N/4] = { make_float4(0.0f, 0.0f, 0.0f, 0.0f) };
25 |
26 | for (int _t = 0; _t < T; _t++)
27 | {
28 | const int tt = _t*C;
29 | const int ttt = tt >> 2;
30 | const F vv = v[tt];
31 | F yy = 0;
32 |
33 | #pragma unroll
34 | for (int j = 0; j < N/4; j++)
35 | {
36 | const float4 rr = r[ttt + j];
37 | const float4 kk = k[ttt + j];
38 | const float4 ww = w[j];
39 | const float4 uu = u[j];
40 |
41 | float4 x;
42 | x.x = kk.x * vv;
43 | x.y = kk.y * vv;
44 | x.z = kk.z * vv;
45 | x.w = kk.w * vv;
46 |
47 | float4 s = state[j];
48 | yy += rr.x * (uu.x * x.x + s.x) + rr.y * (uu.y * x.y + s.y) + rr.z * (uu.z * x.z + s.z) + rr.w * (uu.w * x.w + s.w);
49 |
50 | float4* ss = state + j;
51 | ss->x = s.x * ww.x + x.x;
52 | ss->y = s.y * ww.y + x.y;
53 | ss->z = s.z * ww.z + x.z;
54 | ss->w = s.w * ww.w + x.w;
55 | }
56 | y[tt] = yy;
57 | }
58 | }
59 |
60 | template <typename F>
61 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
62 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _gy,
63 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
64 | {
65 |
66 | }
67 |
68 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
69 | {
70 | assert(H*N == C);
71 | const int SIZE = B*C;
72 | dim3 threadsPerBlock(min(SIZE, 32));
73 | assert(SIZE % threadsPerBlock.x == 0);
74 | dim3 numBlocks(SIZE / threadsPerBlock.x);
75 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, y);
76 | }
77 |
78 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
79 | {
80 | assert(H*N == C);
81 | const int SIZE = B*C;
82 | dim3 threadsPerBlock(min(SIZE, 32));
83 | assert(SIZE % threadsPerBlock.x == 0);
84 | dim3 numBlocks(SIZE / threadsPerBlock.x);
85 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
86 | }
87 |
--------------------------------------------------------------------------------
/depthwise_conv1d/cuda/timex_cuda_v2.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | // requires T % 4 == 0 and T <= Tmax (Tmax passed by the compiler)
4 |
5 | #define F4(A, B) ((float4 *)(A))[(B) >> 2]
6 |
7 | template <typename F>
8 | __global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
9 | const F eps, const int B, const int C, const int T) {
10 | const int i = blockIdx.y;
11 | const int t = threadIdx.x << 2;
12 |
13 | __shared__ F ww[Tmax];
14 | __shared__ F k[Tmax];
15 | F4(ww, t) = F4(__w, t + T * (i % C));
16 | F4(k, t) = F4(__k, t + T * i);
17 | __syncthreads();
18 |
19 | float4 s = {eps, eps, eps, eps};
20 |
21 | const F *__restrict__ const w = ww + T - t - 4;
22 | for (int u = 0; u <= t; u++) {
23 | F x = k[u];
24 | s.x += w[u + 3] * x;
25 | s.y += w[u + 2] * x;
26 | s.z += w[u + 1] * x;
27 | s.w += w[u + 0] * x;
28 | }
29 | s.y += w[t + 3] * k[t + 1];
30 | s.z += w[t + 2] * k[t + 1];
31 | s.z += w[t + 3] * k[t + 2];
32 | s.w += w[t + 1] * k[t + 1];
33 | s.w += w[t + 2] * k[t + 2];
34 | s.w += w[t + 3] * k[t + 3];
35 |
36 | F4(x, t + T * i) = s;
37 | }
38 |
39 | template <typename F>
40 | __global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
41 | F *__restrict__ const gw, F *__restrict__ const gk,
42 | const int B, const int C, const int T) {
43 | const int i = blockIdx.y;
44 | const int t = threadIdx.x << 2;
45 |
46 | __shared__ F w[Tmax];
47 | __shared__ F k[Tmax];
48 | __shared__ F gg[Tmax];
49 | F4(w, t) = F4(__w, t + T * (i % C));
50 | F4(k, t) = F4(__k, t + T * i);
51 | F4(gg, t) = F4(__gwk, t + T * i);
52 | __syncthreads();
53 |
54 | float4 s = {0, 0, 0, 0};
55 | const F *__restrict__ const ga = gg + T - t - 4;
56 |
57 | for (int u = 0; u <= t; u++) {
58 | F x = k[u];
59 | s.x += ga[u + 3] * x;
60 | s.y += ga[u + 2] * x;
61 | s.z += ga[u + 1] * x;
62 | s.w += ga[u] * x;
63 | }
64 | s.y += ga[t + 3] * k[t + 1];
65 | s.z += ga[t + 2] * k[t + 1];
66 | s.z += ga[t + 3] * k[t + 2];
67 | s.w += ga[t + 1] * k[t + 1];
68 | s.w += ga[t + 2] * k[t + 2];
69 | s.w += ga[t + 3] * k[t + 3];
70 |
71 | F4(gw, t + T * i) = s;
72 |
73 | s.x = 0;
74 | s.y = 0;
75 | s.z = 0;
76 | s.w = 0;
77 | const F *__restrict__ const gb = gg + T + t - 3;
78 |
79 | for (int u = t + 3; u < T; u++) {
80 | F x = w[u];
81 | s.x += gb[2 - u] * x;
82 | s.y += gb[3 - u] * x;
83 | s.z += gb[4 - u] * x;
84 | s.w += gb[5 - u] * x;
85 | }
86 | s.x += gb[2 - t] * w[t + 0];
87 | s.x += gb[1 - t] * w[t + 1];
88 | s.x += gb[0 - t] * w[t + 2];
89 | s.y += gb[2 - t] * w[t + 1];
90 | s.y += gb[1 - t] * w[t + 2];
91 | s.z += gb[2 - t] * w[t + 2];
92 |
93 | F4(gk, t + T * i) = s;
94 | }
95 |
96 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
97 | dim3 gridDim(1, B * C);
98 | dim3 blockDim(T >> 2);
99 | kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
100 | }
101 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
102 | dim3 gridDim(1, B * C);
103 | dim3 blockDim(T >> 2);
104 | kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
105 | }
106 |
--------------------------------------------------------------------------------
/wkv/cuda/wkv_cuda_v1.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | template <typename F>
4 | __global__ void kernel_forward(const int B, const int T, const int C,
5 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
6 | F *__restrict__ const _y)
7 | {
8 | const int _b = blockIdx.x;
9 | const int _c = threadIdx.x;
10 | const int _offset = _b * T * C + _c;
11 |
12 | F u = _u[_c];
13 | F w = _w[_c];
14 | const F *__restrict__ const k = _k + _offset;
15 | const F *__restrict__ const v = _v + _offset;
16 | F *__restrict__ const y = _y + _offset;
17 |
18 | F p = 0, q = 0, o = -65500;
19 | // p and q are running sums divided by exp(o) (to avoid overflows)
20 | for (int i = 0; i < T; i++)
21 | {
22 | const int ii = i * C;
23 |
24 | F no = max(o, u + k[ii]);
25 | F A = exp(o - no);
26 | F B = exp(u + k[ii] - no);
27 | y[ii] = (A * p + B * v[ii]) / (A * q + B);
28 |
29 | no = max(w + o, k[ii]);
30 | A = exp(w + o - no);
31 | B = exp(k[ii] - no);
32 | p = A * p + B * v[ii];
33 | q = A * q + B;
34 | o = no;
35 | }
36 | }
37 |
38 | template <typename F>
39 | __global__ void kernel_backward(const int B, const int T, const int C,
40 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
41 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv)
42 | {
43 | const int _b = blockIdx.x;
44 | const int _c = threadIdx.x;
45 | const int _offset = _b * T * C + _c;
46 |
47 | F u = _u[_c];
48 | F w = _w[_c];
49 | const F *__restrict__ const k = _k + _offset;
50 | const F *__restrict__ const v = _v + _offset;
51 | const F *__restrict__ const gy = _gy + _offset;
52 |
53 | F *__restrict__ const gk = _gk + _offset;
54 | F *__restrict__ const gv = _gv + _offset;
55 |
56 | F y[1024], z[1024], zexp[1024]; // per-thread history buffers: requires T <= 1024
57 |
58 | F gw = 0, gu = 0;
59 | F p = 0, q = 0;
60 | F dpdw = 0, dqdw = 0;
61 | F o = -65500;
62 | for (int i = 0; i < T; i++)
63 | {
64 | const int ii = i * C;
65 | F no = max(o, k[ii] + u);
66 | F A = exp(o - no);
67 | F B = exp(k[ii] + u - no);
68 |
69 | F num = A * p + B * v[ii];
70 | F iden = 1 / (A * q + B);
71 |
72 | y[i] = num * iden;
73 | z[i] = iden;
74 | zexp[i] = k[ii] + u - no;
75 |
76 | gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
77 | gu += gy[ii] * (v[ii] - y[i]) * B * iden;
78 |
79 | no = max(w + o, k[ii]);
80 | A = exp(w + o - no);
81 | B = exp(k[ii] - no);
82 | dpdw = A * (p + dpdw);
83 | dqdw = A * (q + dqdw);
84 | p = A * p + B * v[ii];
85 | q = A * q + B;
86 | o = no;
87 | }
88 |
89 | F gp = 0, gq = 0;
90 | o = -65500;
91 | for (int i = T - 1; i >= 0; i--)
92 | {
93 | const int ii = i * C;
94 | F A = gy[ii] * z[i] * exp(zexp[i]);
95 | F B = exp(k[ii] + o);
96 | gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
97 | gv[ii] = A + B * gp;
98 |
99 | F no = max(w + o, zexp[i] - k[ii] - u);
100 | A = exp(w + o - no);
101 | B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
102 | gp = A * gp + B;
103 | gq = A * gq - B * y[i];
104 | o = no;
105 | }
106 |
107 | // Multiply by w: the host preprocesses w -> -exp(w) before launch, so the backward pass
108 | // must apply the chain-rule factor d(-exp(w))/dw = -exp(w) (i.e. the w seen here),
109 | // even though the forward pass never needs it
108 | const int _offsetBC = _b * C + _c;
109 | _gw[_offsetBC] += gw * _w[_c];
110 | _gu[_offsetBC] += gu;
111 | }
112 |
113 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y)
114 | {
115 | dim3 numBlocks(B);
116 | dim3 threadsPerBlock(C);
117 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
118 | }
119 |
120 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv)
121 | {
122 | dim3 numBlocks(B);
123 | dim3 threadsPerBlock(C);
124 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
125 | }
126 |
--------------------------------------------------------------------------------
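The p/q/o bookkeeping above is a streaming log-sum-exp: p and q are the numerator and denominator sums kept scaled by exp(-o), where o tracks the running maximum exponent, so every exp() argument stays non-positive. A standalone sketch of the same update, mirroring kernel_forward in plain Python (names are illustrative):

```python
import math

def wkv_forward_reference(w, u, ks, vs):
    # p/q are numerator/denominator scaled by exp(-o); o is the running max exponent
    p, q, o = 0.0, 0.0, -65500.0
    ys = []
    for k, v in zip(ks, vs):
        no = max(o, u + k)
        A, Bv = math.exp(o - no), math.exp(u + k - no)
        ys.append((A * p + Bv * v) / (A * q + Bv))
        # decay the running sums by w and fold in the current token
        no = max(w + o, k)
        A, Bv = math.exp(w + o - no), math.exp(k - no)
        p, q, o = A * p + Bv * v, A * q + Bv, no
    return ys
```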
/wkv5/cuda/wkv5_cuda_ref.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
7 | F *__restrict__ const _y)
8 | {
9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
10 | const int _b = idx / C;
11 | const int _h = (idx / N) % H;
12 | const int _i = idx % N;
13 |
14 | const int _o0 = _b*T*C + _h*N;
15 | const int _o1 = _h*N;
16 | const F *__restrict__ const k = _k + _o0;
17 | const F *__restrict__ const v = _v + _o0 + _i;
18 | const F *__restrict__ const r = _r + _o0;
19 | F *__restrict__ const y = _y + _o0 + _i;
20 |
21 | float state[N] = {0};
22 |
23 | for (int __t = 0; __t < T; __t++)
24 | {
25 | const int _t = __t*C;
26 | const F vv = v[_t];
27 |
28 | for (int _j = 0; _j < N; _j++)
29 | {
30 | const int j = _t + _j;
31 | const int m = _o1 + _j;
32 |
33 | const float x = k[j] * vv;
34 | const float s = state[_j];
35 |
36 | atomicAdd(y + _t, r[j] * (_u[m] * x + s)); // assumes y was zero-initialized by the caller
37 | state[_j] = s * _w[m] + x;
38 | }
39 | }
40 | }
41 |
42 | template <typename F>
43 | __global__ void kernel_backward (const int B, const int T, const int C, const int H,
44 | const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const F *__restrict__ const w, const F *__restrict__ const wwww, const F *__restrict__ const _u, const F *__restrict__ const gy,
45 | F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ const gv, F *__restrict__ const gw, F *__restrict__ const gu)
46 | {
47 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
48 | const int b = idx / C;
49 | const int h = (idx / N) % H;
50 | const int i = idx % N;
51 |
52 | for (int t = 0; t < T; t++) {
53 | for (int j = 0; j < N; j++) {
54 | for (int tt = 0; tt <= t; tt++) {
55 | F ww = (tt == t) ? _u[h*N + i] : pow(w[h*N + i], t-tt-1);
56 |
57 | gr[b*T*H*N + t*H*N + h*N + i] += ww * k[b*T*H*N + tt*H*N + h*N + i] *
58 | v[b*T*H*N + tt*H*N + h*N + j] * gy[b*T*H*N + t*H*N + h*N + j];
59 | }
60 |
61 | for (int tt = t; tt < T; tt++) {
62 | F ww = (tt == t) ? _u[h*N + i] : pow(w[h*N + i], tt-t-1);
63 |
64 | gk[b*T*H*N + t*H*N + h*N + i] += r[b*T*H*N + tt*H*N + h*N + i] * ww *
65 | v[b*T*H*N + t*H*N + h*N + j] * gy[b*T*H*N + tt*H*N + h*N + j];
66 |
67 | ww = (tt == t) ? _u[h*N + j] : pow(w[h*N + j], tt-t-1);
68 |
69 | gv[b*T*H*N + t*H*N + h*N + i] += r[b*T*H*N + tt*H*N + h*N + j] * ww *
70 | k[b*T*H*N + t*H*N + h*N + j] * gy[b*T*H*N + tt*H*N + h*N + i];
71 | }
72 |
73 | atomicAdd(gu + h*N + i, r[b*T*H*N + t*H*N + h*N + i] * k[b*T*H*N + t*H*N + h*N + i] *
74 | v[b*T*H*N + t*H*N + h*N + j] * gy[b*T*H*N + t*H*N + h*N + j]);
75 |
76 | for (int tt = 0; tt < t-1; tt++) {
77 | F ww = (t-tt-1) * wwww[h*N + i] * pow(w[h*N + i], t-tt-1);
78 |
79 | atomicAdd(gw + h*N + i, r[b*T*H*N + t*H*N + h*N + i] * ww * k[b*T*H*N + tt*H*N + h*N + i] *
80 | v[b*T*H*N + tt*H*N + h*N + j] * gy[b*T*H*N + t*H*N + h*N + j]);
81 | }
82 | }
83 | }
84 | }
85 |
86 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
87 | {
88 | dim3 threadsPerBlock( min(B*C, 32) );
89 | assert(B * C % threadsPerBlock.x == 0);
90 | dim3 numBlocks(B * C / threadsPerBlock.x);
91 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, y);
92 | }
93 |
94 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
95 | {
96 | dim3 threadsPerBlock( min(B*C, 32) );
97 | assert(B * C % threadsPerBlock.x == 0);
98 | dim3 numBlocks(B * C / threadsPerBlock.x);
99 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
100 | }
101 |
--------------------------------------------------------------------------------
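For checking the fast kernels against something dense, a brute-force PyTorch restatement of the same per-head recurrence is handy. A sketch under the shape assumptions (B, T, H, N) for r, k, v and (H, N) for w, u; this is not the repo's own test harness:

```python
import torch

def wkv5_reference(r, k, v, w, u):
    # S is the per-head running state; S[j, i] accumulates k_j * v_i
    B, T, H, N = r.shape
    y = torch.zeros_like(r)
    S = torch.zeros(B, H, N, N, dtype=r.dtype, device=r.device)
    uu = u.unsqueeze(0).unsqueeze(-1)   # (1, H, N, 1), broadcasts over the output index i
    ww = w.unsqueeze(0).unsqueeze(-1)   # (1, H, N, 1), per-j decay
    for t in range(T):
        x = k[:, t].unsqueeze(-1) * v[:, t].unsqueeze(-2)          # outer product k_t v_t^T
        y[:, t] = torch.einsum('bhj,bhji->bhi', r[:, t], uu * x + S)
        S = ww * S + x
    return y
```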
/wkv/cuda/wkv_cuda_v2.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C,
6 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
7 | F *__restrict__ const _y)
8 | {
9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
10 | const int _b = idx / C;
11 | const int _c = idx % C;
12 | const int _offset = _b * T * C + _c;
13 |
14 | F u = _u[_c];
15 | F w = _w[_c];
16 | const F *__restrict__ const k = _k + _offset;
17 | const F *__restrict__ const v = _v + _offset;
18 | F *__restrict__ const y = _y + _offset;
19 |
20 | F p = 0, q = 0, o = -65500;
21 | // p and q are running sums divided by exp(o) (to avoid overflows)
22 | for (int i = 0; i < T; i++)
23 | {
24 | const int ii = i * C;
25 |
26 | F no = max(o, u + k[ii]);
27 | F A = exp(o - no);
28 | F B = exp(u + k[ii] - no);
29 | y[ii] = (A * p + B * v[ii]) / (A * q + B);
30 |
31 | no = max(w + o, k[ii]);
32 | A = exp(w + o - no);
33 | B = exp(k[ii] - no);
34 | p = A * p + B * v[ii];
35 | q = A * q + B;
36 | o = no;
37 | }
38 | }
39 |
40 | template <typename F>
41 | __global__ void kernel_backward(const int B, const int T, const int C,
42 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
43 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv)
44 | {
45 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
46 | const int _b = idx / C;
47 | const int _c = idx % C;
48 | const int _offset = _b * T * C + _c;
49 |
50 | F u = _u[_c];
51 | F w = _w[_c];
52 | const F *__restrict__ const k = _k + _offset;
53 | const F *__restrict__ const v = _v + _offset;
54 | const F *__restrict__ const gy = _gy + _offset;
55 |
56 | F *__restrict__ const gk = _gk + _offset;
57 | F *__restrict__ const gv = _gv + _offset;
58 |
59 | F y[4096], z[4096], zexp[4096]; // per-thread buffers: requires T <= 4096
60 |
61 | F gw = 0, gu = 0;
62 | F p = 0, q = 0;
63 | F dpdw = 0, dqdw = 0;
64 | F o = -65500;
65 | for (int i = 0; i < T; i++)
66 | {
67 | const int ii = i * C;
68 | F no = max(o, k[ii] + u);
69 | F A = exp(o - no);
70 | F B = exp(k[ii] + u - no);
71 |
72 | F num = A * p + B * v[ii];
73 | F iden = 1 / (A * q + B);
74 |
75 | y[i] = num * iden;
76 | z[i] = iden;
77 | zexp[i] = k[ii] + u - no;
78 |
79 | gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
80 | gu += gy[ii] * (v[ii] - y[i]) * B * iden;
81 |
82 | no = max(w + o, k[ii]);
83 | A = exp(w + o - no);
84 | B = exp(k[ii] - no);
85 | dpdw = A * (p + dpdw);
86 | dqdw = A * (q + dqdw);
87 | p = A * p + B * v[ii];
88 | q = A * q + B;
89 | o = no;
90 | }
91 |
92 | F gp = 0, gq = 0;
93 | o = -65500;
94 | for (int i = T - 1; i >= 0; i--)
95 | {
96 | const int ii = i * C;
97 | F A = gy[ii] * z[i] * exp(zexp[i]);
98 | F B = exp(k[ii] + o);
99 | gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
100 | gv[ii] = A + B * gp;
101 |
102 | F no = max(w + o, zexp[i] - k[ii] - u);
103 | A = exp(w + o - no);
104 | B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
105 | gp = A * gp + B;
106 | gq = A * gq - B * y[i];
107 | o = no;
108 | }
109 |
110 | // Multiply by w because half of the w -> -exp(w) preprocessing lives in the backward pass, even though it's not in the forward pass
111 | const int _offsetBC = _b * C + _c;
112 | _gw[_offsetBC] += gw * _w[_c];
113 | _gu[_offsetBC] += gu;
114 | }
115 |
116 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y)
117 | {
118 | dim3 threadsPerBlock( min(C, 32) );
119 | assert(B * C % threadsPerBlock.x == 0);
120 | dim3 numBlocks(B * C / threadsPerBlock.x);
121 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
122 | }
123 |
124 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv)
125 | {
126 | dim3 threadsPerBlock( min(C, 32) );
127 | assert(B * C % threadsPerBlock.x == 0);
128 | dim3 numBlocks(B * C / threadsPerBlock.x);
129 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
130 | }
131 |
--------------------------------------------------------------------------------
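The core of v2 is the overflow-safe recurrence: `p` and `q` track the numerator and denominator of the weighted average, both stored divided by `exp(o)`, where `o` is a running maximum exponent (a log-sum-exp trick). A minimal single-channel host-side sketch of the same recurrence (the helper name `wkv_forward_ref` is illustrative, not part of the repo):

```
#include <algorithm>
#include <cmath>
#include <vector>

// Single-channel WKV forward with the same exp-rescaling as the kernel:
// p ~ sum exp(k)*v and q ~ sum exp(k), both kept divided by exp(o).
std::vector<float> wkv_forward_ref(float w, float u,
                                   const std::vector<float> &k,
                                   const std::vector<float> &v) {
    const int T = (int)k.size();
    std::vector<float> y(T);
    float p = 0, q = 0, o = -65500;           // o = running max exponent
    for (int i = 0; i < T; i++) {
        float no = std::max(o, u + k[i]);     // output applies the u bonus to token i
        float A = std::exp(o - no);
        float B = std::exp(u + k[i] - no);
        y[i] = (A * p + B * v[i]) / (A * q + B);
        no = std::max(w + o, k[i]);           // state update: decay past terms by exp(w)
        A = std::exp(w + o - no);
        B = std::exp(k[i] - no);
        p = A * p + B * v[i];
        q = A * q + B;
        o = no;
    }
    return y;
}
```

Rescaling by the running maximum keeps every `exp()` argument non-positive, so neither `p` nor `q` can overflow even for long sequences.
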
/wkv/cuda/wkv_cuda_v0.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | template <typename F>
4 | __global__ void kernel_forward(const int B, const int T, const int C,
5 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
6 | F *__restrict__ const _y) {
7 | const int _b = blockIdx.x;
8 | const int _c = threadIdx.x;
9 | const int _offset = _b*T*C + _c;
10 |
11 | F u = _u[_c];
12 | F w = _w[_c];
13 | const F *__restrict__ const k = _k + _offset;
14 | const F *__restrict__ const v = _v + _offset;
15 | F *__restrict__ const y = _y + _offset;
16 |
17 | y[0] = v[0];
18 | F a = v[0];
19 | F b = 1;
20 | F p = k[0];
21 | for (int i = 1; i < T; i++)
22 | {
23 | const int ii = i*C;
24 | F kk = k[ii];
25 | F vv = v[ii];
26 |
27 | F q = max(p, u+kk);
28 | F e1 = exp(p - q);
29 | F e2 = exp(u+kk - q);
30 | y[ii] = (e1 * a + e2 * vv) / (e1 * b + e2);
31 |
32 | q = max(p+w, kk);
33 | e1 = exp(p+w - q);
34 | e2 = exp(kk - q);
35 | a = e1 * a + e2 * vv;
36 | b = e1 * b + e2;
37 | p = q;
38 | }
39 | }
40 |
41 | template <typename F>
42 | __global__ void kernel_backward(const int B, const int T, const int C,
43 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
44 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
45 | const int _b = blockIdx.x;
46 | const int _c = threadIdx.x;
47 |
48 | F u = _u[_c];
49 | F w = _w[_c];
50 | const int _offset = _b*T*C + _c;
51 | const F *__restrict__ const gy = _gy + _offset;
52 | F k[1024]; // per-thread copies of k and v: requires T <= 1024
53 | F v[1024];
54 | for (int i = 0; i < T; i++)
55 | {
56 | const int ii = _offset + i*C;
57 | k[i] = _k[ii];
58 | v[i] = _v[ii];
59 | }
60 |
61 | F gw = 0;
62 | F gu = 0;
63 | F gk[1024] = {0};
64 | F gv[1024] = {0};
65 |
66 | F a = 0;
67 | F b = 0;
68 | F p = -65500;
69 | F qq = 0;
70 | F r = 0;
71 | F rr = 0;
72 | F s = 0;
73 | F ss = 0;
74 | F ee = 0;
75 | for (int i = 0; i < T; i++)
76 | {
77 | F kk = k[i];
78 | F vv = v[i];
79 | F gg = gy[i*C];
80 |
81 | F q = max(p, u+kk);
82 | F e1 = exp(p - q);
83 | F e2 = exp(u+kk - q);
84 |
85 | F c = e1 * a + e2 * vv;
86 | F d = e1 * b + e2;
87 |
88 | for (int j = 0; j < i; j++)
89 | {
90 | ee = exp((i-j-1)*w + k[j] - q) * gg / d;
91 | gv[j] += ee;
92 | gk[j] += ee * (v[j] - c / d);
93 | }
94 | ee = e2 * gg / d;
95 | gv[i] += ee;
96 | ee *= (vv - c / d);
97 | gk[i] += ee;
98 | gu += ee;
99 |
100 | if (i > 2)
101 | {
102 | e1 = exp(w + qq - q);
103 | e2 = exp(w + k[i-2] - q);
104 | ss = e1 * ss + e2;
105 | s = e1 * s + ss;
106 | rr = e1 * rr + e2 * v[i-2];
107 | r = e1 * r + rr;
108 | }
109 | if (i == 2)
110 | {
111 | ss = exp(w + k[0] - q);
112 | s = ss;
113 | rr = ss * v[0];
114 | r = rr;
115 | }
116 | gw += (r / d - c * s / (d * d)) * gg * w;
117 | qq = q;
118 |
119 | q = max(p+w, kk);
120 | e1 = exp(p+w - q);
121 | e2 = exp(kk - q);
122 | a = e1 * a + e2 * vv;
123 | b = e1 * b + e2;
124 | p = q;
125 | }
126 |
127 | const int _offsetBC = _b*C + _c;
128 | _gw[_offsetBC] += gw;
129 | _gu[_offsetBC] += gu;
130 | F *__restrict__ const __gk = _gk + _offset;
131 | F *__restrict__ const __gv = _gv + _offset;
132 | for (int i = 0; i < T; i++)
133 | {
134 | const int ii = i*C;
135 | __gk[ii] = gk[i];
136 | __gv[ii] = gv[i];
137 | }
138 | }
139 |
140 | // note: test B,C & 1,BC & BC,1 combinations
141 |
142 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
143 | dim3 numBlocks(B);
144 | dim3 threadsPerBlock(C);
145 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
146 | }
147 |
148 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv)
149 | {
150 | dim3 numBlocks(B);
151 | dim3 threadsPerBlock(C);
152 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
153 | }
154 |
--------------------------------------------------------------------------------
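In the naive O(T^2) inner loop of the v0 backward above, `q` is the running maximum exponent and `d = D_i * exp(-q)`, so the rescaling cancels and the quantities being accumulated are the plain gradients of the softmax-style average (sketch, inferred from the loop):

```
D_i = e^{u + k_i} + \sum_{j<i} e^{(i-j-1)w + k_j}
\frac{\partial y_i}{\partial v_j} = \frac{e^{(i-j-1)w + k_j}}{D_i},
\qquad
\frac{\partial y_i}{\partial k_j} = \frac{e^{(i-j-1)w + k_j}\,(v_j - y_i)}{D_i}
\qquad (j < i)
```

The `r/rr/s/ss` recurrences build the corresponding w-derivative sums incrementally, so the `gw` contribution at each step costs O(1) instead of another O(T) loop.
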
/depthwise_conv1d/cuda/timex_cuda_v3.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | // require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax, BF and BB are passed by the compiler)
4 |
5 | #define F4(A, B) ((float4 *)(A))[(B) >> 2]
6 |
7 | template <typename F>
8 | __global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
9 | const F eps, const int B, const int C, const int T) {
10 | const int i = blockIdx.y;
11 | const int t = threadIdx.x << 2;
12 | const int ti = t + T * i;
13 | const int tj = T * (B * C) / BF;
14 |
15 | __shared__ F ww[Tmax];
16 | __shared__ F kk[Tmax * BF];
17 | F4(ww, t) = F4(__w, t + T * (i % C));
18 |
19 | #pragma unroll
20 | for (int j = 0; j < BF; j++) {
21 | F4(kk, t + Tmax * j) = F4(__k, ti + tj * j);
22 | }
23 | __syncthreads();
24 |
25 | float4 ss[BF];
26 | #pragma unroll
27 | for (int j = 0; j < BF; j++) {
28 | ss[j] = {eps, eps, eps, eps};
29 | }
30 | for (int u = 0; u <= t; u++) {
31 | const F *__restrict__ const w = ww + T - t + u - 4;
32 | #pragma unroll
33 | for (int j = 0; j < BF; j++) {
34 | float4 *__restrict__ const s = ss + j;
35 | const F k = kk[u + Tmax * j];
36 | s->x += w[3] * k;
37 | s->y += w[2] * k;
38 | s->z += w[1] * k;
39 | s->w += w[0] * k;
40 | }
41 | }
42 | #pragma unroll
43 | for (int j = 0; j < BF; j++) {
44 | float4 *__restrict__ const s = ss + j;
45 | const F *__restrict__ const w = ww + T - 3;
46 | const F *__restrict__ const k = kk + Tmax * j + t + 1;
47 | s->y += w[2] * k[0];
48 | s->z += w[1] * k[0];
49 | s->z += w[2] * k[1];
50 | s->w += w[0] * k[0];
51 | s->w += w[1] * k[1];
52 | s->w += w[2] * k[2];
53 | F4(x, ti + tj * j) = *s;
54 | }
55 | }
56 |
57 | template <typename F>
58 | __global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
59 | F *__restrict__ const gw, F *__restrict__ const gk,
60 | const int B, const int C, const int T) {
61 | const int i = blockIdx.y;
62 | const int t = threadIdx.x << 2;
63 | const int ti = t + T * i;
64 | const int tj = T * (B * C) / BB;
65 |
66 | __shared__ F ww[Tmax];
67 | __shared__ F kk[Tmax * BB];
68 | __shared__ F gg[Tmax * BB];
69 | F4(ww, t) = F4(__w, t + T * (i % C));
70 |
71 | #pragma unroll
72 | for (int j = 0; j < BB; j++) {
73 | F4(kk, t + Tmax * j) = F4(__k, ti + tj * j);
74 | F4(gg, t + Tmax * j) = F4(__gwk, ti + tj * j);
75 | }
76 | __syncthreads();
77 |
78 | float4 ss[BB];
79 | #pragma unroll
80 | for (int j = 0; j < BB; j++) {
81 | ss[j] = {0, 0, 0, 0};
82 | }
83 | for (int u = 0; u <= t; u++) {
84 | #pragma unroll
85 | for (int j = 0; j < BB; j++) {
86 | float4 *__restrict__ const s = ss + j;
87 | const F *__restrict__ const g = gg + Tmax * j + T - t + u - 4;
88 | const F k = kk[u + Tmax * j];
89 | s->x += g[3] * k;
90 | s->y += g[2] * k;
91 | s->z += g[1] * k;
92 | s->w += g[0] * k;
93 | }
94 | }
95 | #pragma unroll
96 | for (int j = 0; j < BB; j++) {
97 | float4 *__restrict__ const s = ss + j;
98 | const F *__restrict__ const k = kk + Tmax * j + t + 1;
99 | const F *__restrict__ const g = gg + Tmax * j + T - 3;
100 | s->y += g[2] * k[0];
101 | s->z += g[1] * k[0];
102 | s->z += g[2] * k[1];
103 | s->w += g[0] * k[0];
104 | s->w += g[1] * k[1];
105 | s->w += g[2] * k[2];
106 | F4(gw, ti + tj * j) = *s;
107 | }
108 |
109 | #pragma unroll
110 | for (int j = 0; j < BB; j++) {
111 | ss[j] = {0, 0, 0, 0};
112 | }
113 | for (int u = t + 3; u < T; u++) {
114 | const F w = ww[u];
115 | #pragma unroll
116 | for (int j = 0; j < BB; j++) {
117 | float4 *__restrict__ const s = ss + j;
118 | const F *__restrict__ const g = gg + Tmax * j + T + t - u - 1;
119 | s->x += g[0] * w;
120 | s->y += g[1] * w;
121 | s->z += g[2] * w;
122 | s->w += g[3] * w;
123 | }
124 | }
125 | #pragma unroll
126 | for (int j = 0; j < BB; j++) {
127 | float4 *__restrict__ const s = ss + j;
128 | const F *__restrict__ const g = gg + Tmax * j + T - 3;
129 | const F *__restrict__ const w = ww + t;
130 | s->x += g[2] * w[0];
131 | s->x += g[1] * w[1];
132 | s->x += g[0] * w[2];
133 | s->y += g[2] * w[1];
134 | s->y += g[1] * w[2];
135 | s->z += g[2] * w[2];
136 | F4(gk, ti + tj * j) = *s;
137 | }
138 | }
139 |
140 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
141 | dim3 gridDim(1, B * C / BF);
142 | dim3 blockDim(T >> 2);
143 | kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
144 | }
145 |
146 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
147 | dim3 gridDim(1, B * C / BB);
148 | dim3 blockDim(T >> 2);
149 | kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
150 | }
151 |
--------------------------------------------------------------------------------
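As a correctness reference for the tiled v3 kernel: each thread produces four adjacent outputs (one `float4`) for `BF` (forward) or `BB` (backward) batch slices at once, reusing the `w` and `k` tiles staged in shared memory. The value it computes is the causal depthwise correlation from the README plus the `eps` initialization; a minimal CPU sketch (the name `depthwise_ref` is illustrative):

```
// CPU reference: x[b][c][t] = eps + sum_{u<=t} w[c][(T-1)-(t-u)] * k[b][c][u]
void depthwise_ref(const float *w, const float *k, float *x,
                   float eps, int B, int C, int T) {
    for (int b = 0; b < B; b++)
        for (int c = 0; c < C; c++)
            for (int t = 0; t < T; t++) {
                float s = eps;
                for (int u = 0; u <= t; u++)
                    s += w[c * T + (T - 1) - (t - u)] * k[(b * C + c) * T + u];
                x[(b * C + c) * T + t] = s;
            }
}
```
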
/wkv5_bf16/cuda/wkv5_cuda_v1a.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 | typedef at::BFloat16 bf16;
5 |
6 | template <typename F>
7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
9 | F *__restrict__ const _y)
10 | {
11 | const int b = blockIdx.x / H;
12 | const int h = blockIdx.x % H;
13 | const int i = threadIdx.x;
14 | _w += h*_N_;
15 | _u += h*_N_;
16 |
17 | __shared__ float r[_N_], k[_N_], u[_N_];
18 | float state[_N_] = {0};
19 |
20 | __syncthreads();
21 | u[i] = float(_u[i]);
22 | __syncthreads();
23 |
24 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
25 | {
26 | __syncthreads();
27 | r[i] = float(_r[t]);
28 | k[i] = float(_k[t]);
29 | __syncthreads();
30 |
31 | const float v = float(_v[t]);
32 | float y = 0;
33 |
34 | #pragma unroll
35 | for (int j = 0; j < _N_; j++)
36 | {
37 | float x = k[j] * v;
38 | float& s = state[j];
39 |
40 | y += r[j] * (u[j] * x + s);
41 | s = s * _w[j] + x;
42 | }
43 | _y[t] = F(y);
44 | }
45 | }
46 |
47 | template <typename F>
48 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
49 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
50 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
51 | {
52 | const int b = blockIdx.x / H;
53 | const int h = blockIdx.x % H;
54 | const int i = threadIdx.x;
55 | _w += h*_N_;
56 | _u += h*_N_;
57 | __w += h*_N_;
58 | const float w = _w[i];
59 | const float u = float(_u[i]);
60 | const float ww = __w[i];
61 |
62 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_], w_[_N_], u_[_N_];
63 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0};
64 |
65 | float gw = 0, gu = 0;
66 | const int t000 = b*T*C + h*_N_ + i;
67 | const int t111 = (b+1)*T*C + h*_N_ + i;
68 | const int t222 = t111 - 2*C;
69 |
70 | for (int _t = t000; _t < t111; _t += C)
71 | {
72 | __syncthreads();
73 | v[i] = float(_v[_t]);
74 | gy[i] = float(_gy[_t]);
75 | __syncthreads();
76 |
77 | const float k = float(_k[_t]);
78 | const float r = float(_r[_t]);
79 |
80 | float gr = 0;
81 |
82 | #pragma unroll
83 | for (int j = 0; j < _N_; j++)
84 | {
85 | float x = v[j] * k;
86 | float& s = state[j];
87 |
88 | gr += gy[j] * (u * x + s);
89 | gu += r * x * gy[j];
90 | s = s * w + x;
91 | }
92 | _gr[_t] = F(gr);
93 | }
94 | _gu[b*C + h*_N_ + i] = F(gu);
95 |
96 | for (int _t = t000; _t < t222; _t += C)
97 | {
98 | __syncthreads();
99 | v[i] = float(_v[_t]);
100 | gy[i] = float(_gy[_t + 2*C]);
101 | __syncthreads();
102 |
103 | const float k = float(_k[_t]);
104 | const float r = float(_r[_t + 2*C]);
105 |
106 | #pragma unroll
107 | for (int j = 0; j < _N_; j++)
108 | {
109 | float x = v[j] * k;
110 | saaaa[j] = w * (saaaa[j] + sbbbb[j] + x);
111 | sbbbb[j] = w * (sbbbb[j] + x);
112 |
113 | gw += r * ww * saaaa[j] * gy[j];
114 | }
115 | }
116 | _gw[b*C + h*_N_ + i] = F(gw);
117 |
118 | #pragma unroll
119 | for (int j = 0; j < _N_; ++j) {
120 | saaaa[j] = 0;
121 | sbbbb[j] = 0;
122 | }
123 |
124 | __syncthreads();
125 | w_[i] = _w[i];
126 | u_[i] = float(_u[i]);
127 | __syncthreads();
128 |
129 | for (int _t = t111 - C; _t >= t000; _t -= C)
130 | {
131 | __syncthreads();
132 | r[i] = float(_r[_t]);
133 | k[i] = float(_k[_t]);
134 | v[i] = float(_v[_t]);
135 | gy[i] = float(_gy[_t]);
136 | __syncthreads();
137 |
138 | float gk = 0, gv = 0;
139 |
140 | #pragma unroll
141 | for (int j = 0; j < _N_; j++)
142 | {
143 | float x = gy[j] * r[i];
144 | float& s = saaaa[j];
145 | gk += v[j] * (u * x + s);
146 | s = s * w + x;
147 |
148 | float x2 = gy[i] * r[j];
149 | float& s2 = sbbbb[j];
150 | gv += k[j] * (u_[j] * x2 + s2);
151 | s2 = s2 * w_[j] + x2;
152 | }
153 | _gk[_t] = F(gk);
154 | _gv[_t] = F(gv);
155 | }
156 | }
157 |
158 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
159 | {
160 | assert(H*_N_ == C);
161 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
162 | }
163 |
164 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
165 | {
166 | assert(H*_N_ == C);
167 | kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
168 | }
169 |
--------------------------------------------------------------------------------
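The pair of accumulators in the `gw` pass above implements a standard trick for differentiating through a geometric decay. Reading the update rules, with x_s = k_s * v_j at step s and decay w = _w[i] (a sketch inferred from the code):

```
sbbbb_t = \sum_{s \le t} w^{t-s+1}\, x_s
saaaa_t = \sum_{s \le t} (t-s+1)\, w^{t-s+1}\, x_s = w\, \partial_w\, sbbbb_t
```

So `saaaa` is the log-w derivative of the decayed state that enters the output two steps later, which is why the loop reads `r` and `gy` at `_t + 2*C`. The factor `ww = __w[i]` then carries the chain rule from the preprocessing of `w`.
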
/wkv5_bf16/cuda/wkv5_cuda_v1.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 | typedef at::BFloat16 bf16;
5 |
6 | template <typename F>
7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
9 | F *__restrict__ const _y)
10 | {
11 | const int b = blockIdx.x / H;
12 | const int h = blockIdx.x % H;
13 | const int i = threadIdx.x;
14 | _w += h*_N_;
15 | _u += h*_N_;
16 |
17 | __shared__ float r[_N_], k[_N_];
18 |
19 | float state[_N_] = {0};
20 |
21 | for (int _t = b*T*C + h*_N_ + i; _t < (b+1)*T*C + h*_N_ + i; _t += C)
22 | {
23 | __syncthreads();
24 | r[i] = float(_r[_t]);
25 | k[i] = float(_k[_t]);
26 | __syncthreads();
27 |
28 | const float v = float(_v[_t]);
29 | float y = 0;
30 |
31 | for (int j = 0; j < _N_; j++)
32 | {
33 | float x = k[j] * v;
34 |
35 | float s = state[j];
36 | state[j] = s * _w[j] + x;
37 |
38 | y += r[j] * (float(_u[j]) * x + s);
39 | }
40 | _y[_t] = F(y);
41 | }
42 | }
43 |
44 | template <typename F>
45 | __global__ void kernel_backward(const int B, const int T, const int C, const int H,
46 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
47 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
48 | {
49 | const int b = blockIdx.x / H;
50 | const int h = blockIdx.x % H;
51 | const int i = threadIdx.x;
52 | _w += h*_N_;
53 | _u += h*_N_;
54 | __w += h*_N_;
55 |
56 | __shared__ float v[_N_], r[_N_], k[_N_], gy[_N_];
57 |
58 | const float w = _w[i];
59 | const float u = float(_u[i]);
60 | const float ww = __w[i];
61 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0};
62 | float gw = 0, gu = 0;
63 |
64 | for (int _t = b*T*C + h*_N_ + i, _tend = (b+1)*T*C + h*_N_ + i; _t < _tend; _t += C)
65 | {
66 | __syncthreads();
67 | v[i] = float(_v[_t]);
68 | gy[i] = float(_gy[_t]);
69 | __syncthreads();
70 |
71 | const float k = float(_k[_t]);
72 | const float r = float(_r[_t]);
73 | float gr = 0;
74 |
75 | #pragma unroll
76 | for (int j = 0; j < _N_; j++)
77 | {
78 | float x = v[j] * k;
79 | float s = state[j];
80 | state[j] = s * w + x;
81 |
82 | gr += gy[j] * (u * x + s);
83 | gu += r * x * gy[j];
84 | }
85 |
86 | _gr[_t] = F(gr);
87 |
88 | if (_t < _tend - 2*C)
89 | {
90 | __syncthreads();
91 | gy[i] = float(_gy[_t + 2*C]);
92 | __syncthreads();
93 |
94 | const float r = float(_r[_t + 2*C]);
95 |
96 | #pragma unroll
97 | for (int j = 0; j < _N_; j++)
98 | {
99 | float x = v[j] * k;
100 | saaaa[j] = w * (saaaa[j] + sbbbb[j] + x);
101 | sbbbb[j] = w * (sbbbb[j] + x);
102 |
103 | gw += r * ww * saaaa[j] * gy[j];
104 | }
105 | }
106 | }
107 | _gu[b*C + h*_N_ + i] = F(gu);
108 | _gw[b*C + h*_N_ + i] = F(gw);
109 |
110 | #pragma unroll
111 | for (int j = 0; j < _N_; ++j)
112 | state[j] = 0;
113 |
114 | for (int _t = (b+1)*T*C + h*_N_ + i - C, _tend = b*T*C + h*_N_ + i; _t >= _tend; _t -= C)
115 | {
116 | __syncthreads();
117 | v[i] = float(_v[_t]);
118 | gy[i] = float(_gy[_t]);
119 | __syncthreads();
120 |
121 | const float r = float(_r[_t]);
122 | float gk = 0;
123 |
124 | #pragma unroll
125 | for (int j = 0; j < _N_; j++)
126 | {
127 | float x = gy[j] * r;
128 | float s = state[j];
129 | state[j] = s * w + x;
130 |
131 | gk += v[j] * (u * x + s);
132 | }
133 | _gk[_t] = F(gk);
134 | }
135 |
136 | #pragma unroll
137 | for (int j = 0; j < _N_; ++j)
138 | state[j] = 0;
139 |
140 | for (int _t = (b+1)*T*C + h*_N_ + i - C, _tend = b*T*C + h*_N_ + i; _t >= _tend; _t -= C)
141 | {
142 | __syncthreads();
143 | k[i] = float(_k[_t]);
144 | r[i] = float(_r[_t]);
145 | __syncthreads();
146 |
147 | const float gy = float(_gy[_t]);
148 | float gv = 0;
149 |
150 | #pragma unroll
151 | for (int j = 0; j < _N_; j++)
152 | {
153 | float x = gy * r[j];
154 | float s = state[j];
155 | state[j] = s * float(_w[j]) + x;
156 |
157 | gv += k[j] * (float(_u[j]) * x + s);
158 | }
159 | _gv[_t] = F(gv);
160 | }
161 | }
162 |
163 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
164 | {
165 | assert(H*_N_ == C);
166 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
167 | }
168 |
169 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
170 | {
171 | assert(H*_N_ == C);
172 | kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
173 | }
174 |
--------------------------------------------------------------------------------
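For reference, the reversed-time sweeps above accumulate closed-form gradients of the wkv5 recurrence (per batch and head, head size `_N_`; sketch inferred from the loops):

```
g^k_{t,i} = \sum_j v_{t,j} \Big( u_i\, r_{t,i}\, g^y_{t,j} + \sum_{s>t} w_i^{s-t-1}\, r_{s,i}\, g^y_{s,j} \Big)
g^v_{t,j} = \sum_i k_{t,i} \Big( u_i\, r_{t,i}\, g^y_{t,j} + \sum_{s>t} w_i^{s-t-1}\, r_{s,i}\, g^y_{s,j} \Big)
```

Each inner sum over future timesteps is materialized as a state vector that is decayed by `w` once per step while `t` sweeps from T-1 down to 0; v1 runs two separate reversed sweeps for `gk` and `gv`, while v1a fuses them into one loop with two state vectors.
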
/wkv5/cuda/wkv5_cuda_v1.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 |
4 | template <typename F>
5 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
6 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
7 | F *__restrict__ const _y)
8 | {
9 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
10 | const int _b = idx / C;
11 | const int _h = (idx / N) % H;
12 | const int _i = idx % N;
13 |
14 | const int _o0 = _b*T*C + _h*N;
15 | const int _o1 = _h*N;
16 | const F *__restrict__ const k = _k + _o0;
17 | const F *__restrict__ const v = _v + _o0 + _i;
18 | const F *__restrict__ const r = _r + _o0;
19 | F *__restrict__ const y = _y + _o0 + _i;
20 |
21 | float state[N] = {0};
22 |
23 | for (int __t = 0; __t < T; __t++)
24 | {
25 | const int _t = __t*C;
26 | const F vv = v[_t];
27 |
28 | for (int _j = 0; _j < N; _j++)
29 | {
30 | const int j = _t + _j;
31 | const int m = _o1 + _j;
32 |
33 | const float x = k[j] * vv;
34 | const float s = state[_j];
35 |
36 | atomicAdd(y + _t, r[j] * (_u[m] * x + s));
37 | state[_j] = s * _w[m] + x;
38 | }
39 | }
40 | }
41 |
42 | template <typename F>
43 | __global__ void kernel_backward (const int B, const int T, const int C, const int H,
44 | const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const F *__restrict__ w, const F *__restrict__ wwww, const F *__restrict__ u, const F *__restrict__ const gy,
45 | F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ const gv, F *__restrict__ gw, F *__restrict__ gu)
46 | {
47 | const int b = blockIdx.x / H;
48 | const int h = blockIdx.x % H;
49 | const int i = threadIdx.x;
50 | w += h*N;
51 | u += h*N;
52 | gu += h*N;
53 | gw += h*N;
54 | wwww += h*N;
55 |
56 | __shared__ float state[N * N], vv[N], rr[N], kk[N], gyy[N];
57 |
58 | #pragma unroll
59 | for (int j = 0; j < N; ++j){
60 | state[j * N + i] = 0;
61 | }
62 |
63 | const float ww = w[i];
64 | const float uu = u[i];
65 | const float wwwww = wwww[i];
66 | float saaaa[N] = {0.0f}, sbbbb[N] = {0.0f};
67 |
68 | for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C)
69 | {
70 | const F kk = k[_t];
71 | const F rr = r[_t];
72 | F grr = 0;
73 | F guu = 0;
74 |
75 | vv[i] = v[_t];
76 | gyy[i] = gy[_t];
77 |
78 | __syncthreads();
79 |
80 | #pragma unroll
81 | for (int j = 0; j < N; j++)
82 | {
83 |
84 | float x = vv[j] * kk;
85 | float s = state[j * N + i];
86 |
87 | grr += gyy[j] * (uu * x + s);
88 | state[j * N + i] = s * ww + x;
89 | guu += rr * x * gyy[j];
90 |
91 | }
92 | gr[_t] = grr;
93 | atomicAdd(gu + i, guu);
94 |
95 | __syncthreads();
96 | if (_t < _tend - 2 * C){
97 | const F rr_value = r[_t+2*C];
98 | gyy[i] = gy[_t+2*C];
99 | __syncthreads();
100 |
101 | #pragma unroll
102 | for (int j = 0; j < N; j++){
103 | float x = vv[j] * kk;
104 | saaaa[j] = ww * (saaaa[j] + sbbbb[j] + x);
105 | sbbbb[j] = ww * (sbbbb[j] + x);
106 | atomicAdd(gw+i, rr_value * wwwww * saaaa[j] * gyy[j]);
107 | }
108 |
109 | __syncthreads();
110 | }
111 | }
112 |
113 | #pragma unroll
114 | for (int j = 0; j < N; ++j)
115 | state[j * N + i] = 0;
116 |
117 | for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C)
118 | {
119 | const F rr = r[_t];
120 | F gkk = 0;
121 |
122 | vv[i] = v[_t];
123 | gyy[i] = gy[_t];
124 |
125 | __syncthreads();
126 |
127 | #pragma unroll
128 | for (int j = 0; j < N; j++)
129 | {
130 |
131 | float x = gyy[j] * rr;
132 | float s = state[j * N + i];
133 |
134 | gkk += vv[j] * (uu * x + s);
135 | state[j * N + i] = s * ww + x;
136 | }
137 | gk[_t] = gkk;
138 | __syncthreads();
139 | }
140 |
141 | #pragma unroll
142 | for (int j = 0; j < N; ++j)
143 | state[j * N + i] = 0;
144 |
145 | for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C)
146 | {
147 | const F gy_value = gy[_t];
148 | F gvv = 0;
149 |
150 | kk[i] = k[_t];
151 | rr[i] = r[_t];
152 |
153 | __syncthreads();
154 |
155 | #pragma unroll
156 | for (int j = 0; j < N; j++)
157 | {
158 |
159 | float x = gy_value * rr[j];
160 | float s = state[j * N + i];
161 |
162 | gvv += kk[j] * (u[j] * x + s);
163 | state[j * N + i] = s * w[j] + x;
164 | }
165 | gv[_t] = gvv;
166 | __syncthreads();
167 | }
168 | }
169 |
170 | void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
171 | {
172 | dim3 threadsPerBlock( min(B*C, 32) );
173 | assert(B * C % threadsPerBlock.x == 0);
174 | dim3 numBlocks(B * C / threadsPerBlock.x);
175 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, y);
176 | }
177 |
178 | void cuda_backward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *ww, float *u, float *gy, float *gr, float *gk, float *gv, float *gw, float *gu)
179 | {
180 | assert(H*N == C);
181 | kernel_backward<<<dim3(B * H), dim3(N)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
182 | }
183 |
--------------------------------------------------------------------------------
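One note on the v1 forward above: each thread owns output column `_i`, so the `atomicAdd` into `y + _t` never actually races with another thread. An equivalent non-atomic formulation (a sketch using the kernel's own variable names) accumulates locally and stores once per time step:

```
// Sketch: same inner loop as kernel_forward, without the atomic.
float yy = 0;
for (int _j = 0; _j < N; _j++) {
    const int j = _t + _j;            // k/r index at this time step
    const int m = _o1 + _j;           // per-head parameter index
    const float x = k[j] * vv;
    const float s = state[_j];
    yy += r[j] * (_u[m] * x + s);     // u bonus on the current token + decayed state
    state[_j] = s * _w[m] + x;        // decay the state, add the current contribution
}
y[_t] += yy;                          // single writer per column: no atomicAdd needed
```
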
/wkv5/cuda/wkv5_cuda_v2.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | template <typename F>