├── .gitignore
├── README.md
├── include
│   ├── formula_23.h
│   └── formula_4.h
├── kernel
│   ├── formula_23_kernel.cu
│   └── formula_4_kernel.cu
├── pytorch
│   ├── formula_23.cpp
│   ├── formula_4.cpp
│   ├── sparse_vec-1651_11008.npy
│   ├── sparse_vec-254_11008.npy
│   └── sparse_vec.npy
├── run_test-for-all-23.py
├── run_test-for-all-4.py
├── run_test_23.py
├── run_test_4.py
├── setup.py
└── t-whereis-ffn_4.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | nsys-reps/
3 | nsys-reps/*
4 | *.sqlite
5 | *.nsys-rep
6 | **/build/
7 | **/build/*
8 | **/dist/
9 | **/dist/*
10 | **/*.egg-info/
11 | .idea/
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient GPU Operators for ReLU-Activated LLMs
2 |
3 | This repository contains the source code for the two sparse GPU operators presented in our paper *ProSparse: Introducing and Enhancing Intrinsic Activation Sparsity within Large Language Models* ([link](https://arxiv.org/pdf/2402.13516.pdf)). The operators are tailored for the ReLU-activated FFNs in LLMs.
4 |
5 | ### Background
6 |
7 | Exploiting activation sparsity, namely the presence of considerable weakly-contributing elements among activation outputs, is a promising method for accelerating the inference of large language models (LLMs). Concretely, acceleration methods based on activation sparsity usually achieve higher inference speed through smarter resource allocation and computation policies that avoid wasting resources on these weakly-contributing parameters. However, most existing acceleration frameworks are approximate algorithms, which risk inference inaccuracies caused by invalid predictions from activation predictors (e.g., [Deja Vu](https://proceedings.mlr.press/v202/liu23am/liu23am.pdf) and [PowerInfer](https://arxiv.org/pdf/2312.12456.pdf)).
8 |
9 | Therefore, to achieve acceleration without inference inaccuracies and to test the practical speedup of ReLU-activated LLMs with higher sparsity, we implement two hardware-efficient sparse GPU operators that exploit input-side and output-side sparsity, using system-level optimizations such as operator fusion, coalesced memory access, and vectorization.
10 |
11 | ### Methodology
12 |
13 | Given the hidden dimension $`d_{model}`$ and the FFN intermediate dimension $`d_{ff}`$, the computation process of a gated FFN can be formalized as:
14 | ```math
15 | \mathbf{s} = \sigma(\mathbf{x} \mathbf{W}_s^T), \quad \mathbf{x}_1 = \mathbf{s} \odot (\mathbf{x} \mathbf{W}_1^T),\quad
16 | \text{FFN}(\mathbf{x}) = \mathbf{x}_1 \mathbf{W}_2^T,
17 | ```
18 | where $`\mathbf{x}\in\mathbb{R}^{d_{model}}`$, $`\mathbf{s}, \mathbf{x}_1\in\mathbb{R}^{d_{ff}}`$, $`\sigma`$, and $`\odot`$ denote the input hidden states, the gating scores, the intermediate outputs, the activation function, and the element-wise multiplication respectively. $`\mathbf{W}_s,\mathbf{W}_1\in\mathbb{R}^{d_{ff} \times d_{model}}`$ and $`\mathbf{W}_2\in\mathbb{R}^{d_{model} \times d_{ff}}`$ are learnable weights.
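To make the matrix shapes concrete, the sketch below computes how many bytes one gated FFN's weights occupy, assuming bfloat16 storage (2 bytes per value) and the LLaMA-7B and LLaMA-13B dimensions used later in this README ($`d_{model}`$ = 4096 / 5120, $`d_{ff}`$ = 11008 / 13824). The helper names are hypothetical, and `sparsity` is a fraction rather than the percentage used in the results table.

```python
# Illustrative arithmetic only; helper names are hypothetical.
BYTES_PER_BF16 = 2  # assumption: weights stored in bfloat16


def ffn_weight_bytes(d_model: int, d_ff: int) -> dict:
    """Bytes held by W_s, W_1 (d_ff x d_model) and W_2 (d_model x d_ff)."""
    per_matrix = d_model * d_ff * BYTES_PER_BF16
    return {"W_s": per_matrix, "W_1": per_matrix, "W_2": per_matrix,
            "total": 3 * per_matrix}


def w2_bytes_read(d_model: int, d_ff: int, sparsity: float) -> float:
    """Bytes of W_2 a step (3) kernel reads if it skips every weight paired
    with a zero entry of x_1 (sparsity is a fraction in [0, 1])."""
    return (1.0 - sparsity) * d_model * d_ff * BYTES_PER_BF16


for name, d_model, d_ff in [("7B", 4096, 11008), ("13B", 5120, 13824)]:
    mib = ffn_weight_bytes(d_model, d_ff)["W_2"] / 2**20
    print(f"{name}: {mib:.1f} MiB per weight matrix")
```

Under these assumptions each weight matrix is 86.0 MiB for 7B and 135.0 MiB for 13B, so a step (3) kernel that skips the weights of zero activations at 90% sparsity would read only about a tenth of $`\mathbf{W}_2`$.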
19 |
20 | We reorganize a ReLU-activated gated FFN into three major steps. Our two operators, **Operator Step (2)** `ffn_23` and **Operator Step (3)** `ffn_4`, are responsible for steps (2) and (3), respectively:
21 |
22 | (1) A dense matrix-vector multiplication operator $`\mathbf{x} \mathbf{W}_s^T`$ which can be directly supported by vendor libraries such as cuBLAS;
23 | (2) A fused operator of ReLU and $`\mathbf{s} \odot (\mathbf{x} \mathbf{W}_1^T)`$ with output-side sparsity;
24 | (3) A sparse matrix-vector multiplication operator $`\mathbf{x}_1 \mathbf{W}_2^T`$ with input-side sparsity.
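The decomposition can be checked on tiny inputs with the pure-Python sketch below. It restates the math rather than the CUDA kernels: step (2) never reads a row of $`\mathbf{W}_1`$ whose gate is inactive, step (3) never reads a row of $`\mathbf{W}_2^T`$ paired with a zero entry of $`\mathbf{x}_1`$, and the result still matches the unfused dense computation. The helper names and the list-of-rows layout (with $`\mathbf{W}_2`$ passed transposed, as the row-major matrix that `ffn_4` consumes in the tests) are assumptions for illustration.

```python
# Pure-Python sketch of the three-step decomposition (math only, not the kernels).
# Assumed layout: W_s, W_1 and W2_t (= W_2^T) are lists of d_ff rows of length d_model.


def matvec_rows(x, W):
    """x W^T: one dot product per row of W."""
    return [sum(a * b for a, b in zip(x, row)) for row in W]


def step2_fused(gate_pre, x, W_1, threshold=0.0):
    """Step (2): (FAT)ReLU fused with s * (x W_1^T); rows of W_1 whose gate is
    inactive are never read (output-side sparsity)."""
    out = []
    for g, row in zip(gate_pre, W_1):
        if g <= threshold:  # same comparison as ffn_fuse_23
            out.append(0.0)
        else:
            out.append(g * sum(a * b for a, b in zip(x, row)))
    return out


def step3_sparse(x1, W2_t):
    """Step (3): x_1 W_2^T; a zero entry of x_1 skips its whole row of W_2^T
    (input-side sparsity)."""
    res = [0.0] * len(W2_t[0])
    for v, row in zip(x1, W2_t):
        if v == 0.0:
            continue
        for j, w in enumerate(row):
            res[j] += v * w
    return res


def sparse_ffn(x, W_s, W_1, W2_t, threshold=0.0):
    gate_pre = matvec_rows(x, W_s)  # step (1): dense, e.g. cuBLAS
    x1 = step2_fused(gate_pre, x, W_1, threshold)
    return step3_sparse(x1, W2_t)


def dense_ffn(x, W_s, W_1, W2_t):
    """Unfused reference: s = ReLU(x W_s^T), x_1 = s * (x W_1^T), out = x_1 W_2^T."""
    s = [max(v, 0.0) for v in matvec_rows(x, W_s)]
    x1 = [a * b for a, b in zip(s, matvec_rows(x, W_1))]
    return [sum(x1[i] * W2_t[i][j] for i in range(len(x1)))
            for j in range(len(W2_t[0]))]
```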
25 |
26 | The code for Operator Step (2) and Operator Step (3) is in `kernel/formula_23_kernel.cu` and `kernel/formula_4_kernel.cu`, respectively. For more implementation details, refer to Appendix C of the [paper](https://arxiv.org/pdf/2402.13516.pdf).
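In `kernel/formula_23_kernel.cu`, each output element is handled by one warp: every one of its 32 lanes owns a contiguous chunk of `mat_row / 32` inputs, reads it 8 bf16 values at a time through a `float4`, and the partial sums are combined with `__shfl_down_sync` at offsets 16, 8, 4, 2, 1 before lane 0 writes the result. The plain-Python emulation below (hypothetical helper names, ordinary floats instead of bf16) reproduces that work split so the reduction can be checked against a direct dot product.

```python
# Plain-Python emulation of the per-output work split in ffn_fuse_23 (no bf16 rounding).
WARP_SIZE = 32
VALUES_PER_FLOAT4 = 8  # a 16-byte float4 load carries 8 bf16 values


def shfl_down(vals, delta):
    """Emulate __shfl_down_sync(0xffffffff, v, delta) over a full warp:
    lane i reads lane i + delta; lanes whose source is out of range keep their own value."""
    return [vals[i + delta] if i + delta < WARP_SIZE else vals[i]
            for i in range(WARP_SIZE)]


def warp_dot(x, w):
    n = len(x)
    assert n == len(w) and n % (WARP_SIZE * VALUES_PER_FLOAT4) == 0
    per_lane = n // WARP_SIZE  # num_per_threadx in the kernel
    partial = []
    for lane in range(WARP_SIZE):  # threadIdx.x
        start = lane * per_lane  # row_id in the kernel
        acc = 0.0
        for i in range(start, start + per_lane, VALUES_PER_FLOAT4):  # one float4 load
            for k in range(VALUES_PER_FLOAT4):  # the 8 products unrolled in the kernel
                acc += x[i + k] * w[i + k]
        partial.append(acc)
    for delta in (16, 8, 4, 2, 1):  # tree reduction, as in the kernel
        shifted = shfl_down(partial, delta)
        partial = [a + b for a, b in zip(partial, shifted)]
    return partial[0]  # only lane 0 (threadIdx.x == 0) holds the full sum
```

Only lane 0 is guaranteed to hold the complete sum after the five steps, which is why the kernel writes the result from `threadIdx.x == 0` alone.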
27 |
28 | ### Results
29 |
30 | To test the practical acceleration of ReLU-activated LLMs with the above operators applied, we measure the average single-step wall-clock time spent by our two sparse GPU operators, which are responsible for step (2) and step (3), respectively. The major results are shown below; refer to Section 4.3 of the [paper](https://arxiv.org/pdf/2402.13516.pdf) for more details. The ProSparse LLaMA2 models, which have high ReLU-based activation sparsity and performance comparable to the original Swish-activated LLaMA2 versions, are available at the following links: [7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b) and [13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b).
31 |
32 | | Setting | Average Sparsity (%) | Step (2) Time (us) | Step (2) Speedup | Step (3) Time (us) | Step (3) Speedup |
33 | |:-------------------------:|:-------------------:|:----------------:|:-------------------:|:----------------:|:----------------:|
34 | | ReluLLaMA-7B | 66.98 | 67.12 | 1.35 | 63.00 | 1.32 |
35 | | Vanilla ReLU-7B | 66.04 | 67.85 | 1.33 | 63.28 | 1.31 |
36 | | Fixed $`L_1`$-7B | 91.46 | 40.99 | 2.21 | 54.19 | 1.53 |
37 | | **ProSparse-7B**$`^*`$ | 88.11 | 46.66 | 1.94 | 55.56 | 1.49 |
38 | | **ProSparse-7B** | 89.32 | 45.38 | 2.00 | 55.05 | 1.51 |
39 | | ReluLLaMA-13B | 71.56 | 69.92 | 1.88 | 75.47 | 1.51 |
40 | | **ProSparse-13B**$`^*`$ | 87.97 | 55.29 | 2.38 | 67.50 | 1.68 |
41 | | **ProSparse-13B** | 88.80 | 53.78 | 2.44 | 66.73 | 1.70 |
42 |
43 | `Time` is the average wall-clock time (us) spent on each step with our sparse GPU operators, and `Speedup` is the speedup ratio relative to the setting without these operators. Without the sparse GPU operators, the average time for steps (2) and (3) is about **90.55 and 82.92 us for 7B, and 131.36 and 113.68 us for 13B**, respectively, regardless of sparsity.
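The `Speedup` columns follow directly from these baselines: speedup = time without the operator / time with it. For ProSparse-7B, for instance, 90.55 / 45.38 ≈ 2.00 and 82.92 / 55.05 ≈ 1.51, matching the table. A minimal check in Python (the dictionary layout and helper name are illustrative):

```python
# Baseline step times (us) without the sparse operators, as quoted above.
BASELINE_US = {"7B": {"step2": 90.55, "step3": 82.92},
               "13B": {"step2": 131.36, "step3": 113.68}}


def speedup(baseline_us: float, operator_us: float) -> float:
    """Speedup ratio = time without the operator / time with it."""
    return baseline_us / operator_us


# ProSparse-7B row of the table: 45.38 us for step (2), 55.05 us for step (3)
print(round(speedup(BASELINE_US["7B"]["step2"], 45.38), 2))  # 2.0
print(round(speedup(BASELINE_US["7B"]["step3"], 55.05), 2))  # 1.51
```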
44 |
45 | As the above results demonstrate, higher activation sparsity can make accurate acceleration algorithms based on GPU operators more efficient. Moreover, our two sparse GPU operators achieve speedup ratios of up to 2.44 and 1.70, respectively, with better acceleration for larger models.
46 |
47 | ### Install
48 |
49 | Use the following command to install both `ffn_23` (Operator Step (2)) and `ffn_4` (Operator Step (3)):
50 |
51 | ```bash
52 | python setup.py install
53 | ```
54 |
55 | **Note**: In some environments, the above command may not work. In that case, go to the repository root and install the two operators one at a time: comment out one of the two `setup` calls in `setup.py` and run `pip install .`, then restore it, comment out the other call, and run `pip install .` again.
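After installing, one way to confirm that both extensions are on the Python path (similar in spirit to `t-whereis-ffn_4.py`) is to look them up without importing them. This sketch uses the standard-library `importlib.util.find_spec`; the module names come from `setup.py`, and the helper name is hypothetical.

```python
import importlib.util

# Extension module names declared in setup.py
OPERATOR_MODULES = ("ffn_23", "ffn_4")


def missing_operators(names=OPERATOR_MODULES):
    """Return the names that importlib cannot locate (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_operators()
    print("all operators installed" if not missing else f"missing: {missing}")
```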
56 |
57 | ### Usage
58 |
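59 | See `run_test_23.py` and `run_test_4.py` for example calls and correctness checks against PyTorch; `run_test-for-all-23.py` and `run_test-for-all-4.py` additionally benchmark the operators on the sparse vectors stored in `pytorch/sparse_vec.npy`.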
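The test scripts fix the argument conventions: for `ffn_23` they use `mat_row = 4096` ($`d_{model}`$) and `mat_col = 11008` ($`d_{ff}`$) and pass the weight as `mat.t().contiguous()` (column-major), while for `ffn_4` they use `mat_row = 11008`, `mat_col = 4096` and a row-major matrix. The pure-Python sketch below mimics each kernel's index arithmetic on flat lists so the expected layouts can be checked on small inputs. The helper names are hypothetical, and the sketch ignores bf16 rounding and the kernels' size restrictions.

```python
# Hypothetical helpers emulating the kernels' index arithmetic on flat lists.


def to_col_major(mat):
    """Flatten a (rows x cols) list of rows column by column, i.e. the element
    order of torch's mat.t().contiguous()."""
    rows, cols = len(mat), len(mat[0])
    return [mat[r][c] for c in range(cols) for r in range(rows)]


def to_row_major(mat):
    """Flatten a (rows x cols) list of rows row by row (torch's default order)."""
    return [v for row in mat for v in row]


def emulate_ffn_23(vec_sparse, vec_input, mat_up_flat, mat_row, mat_col, threshold=0.0):
    """res[c] = s[c] * sum_k x[k] * mat_up[c * mat_row + k], or 0 when s[c] <= T."""
    res = []
    for c in range(mat_col):
        s = vec_sparse[c]
        if s <= threshold:
            res.append(0.0)
            continue
        column = mat_up_flat[c * mat_row:(c + 1) * mat_row]
        res.append(s * sum(a * b for a, b in zip(vec_input, column)))
    return res


def emulate_ffn_4(mat_flat, vec, mat_row, mat_col):
    """res[c] = sum_r vec[r] * mat[r * mat_col + c], skipping rows with vec[r] == 0."""
    res = [0.0] * mat_col
    for r in range(mat_row):
        if vec[r] == 0.0:
            continue
        for c in range(mat_col):
            res[c] += vec[r] * mat_flat[r * mat_col + c]
    return res


# Usage on a 3 x 2 matrix (mat_row = 3, mat_col = 2):
mat = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(emulate_ffn_23([0.5, -1.0], [1.0, 1.0, 1.0], to_col_major(mat), 3, 2))  # [4.5, 0.0]
print(emulate_ffn_4(to_row_major(mat), [1.0, 0.0, 2.0], 3, 2))  # [11.0, 14.0]
```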
60 |
61 | ### Attention: FATReLU support
62 |
63 | Note that our Operator Step (2) supports FATReLU, a non-zero threshold ReLU variant:
64 | ```math
65 | \sigma(x)=
66 | \begin{cases}
67 | x \quad \mathrm{when}\ x \geq T, \\
68 | 0 \quad \mathrm{otherwise},
69 | \end{cases}
70 | ```
71 | where $`T>0`$ is a positive threshold. Pass $`T`$ as the last argument of `ffn_23.torch_launch_ffn_fuse_23`; use 0 for vanilla ReLU.
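A minimal Python sketch of the activation above, with illustrative function names. Note that `ffn_fuse_23` zeroes an entry when it is `<= threshold`, so an entry exactly equal to $`T`$ is dropped by the kernel but kept by the formula; with $`T = 0`$ both reduce to ReLU.

```python
def fatrelu(xs, threshold):
    """FATReLU as written above: keep x when x >= T, otherwise output 0."""
    return [x if x >= threshold else 0.0 for x in xs]


def ffn_23_gate(xs, threshold):
    """The test ffn_fuse_23 applies: an entry survives only when x > T."""
    return [x if x > threshold else 0.0 for x in xs]


print(fatrelu([-1.0, 0.05, 0.1, 0.2], 0.1))      # [0.0, 0.0, 0.1, 0.2]
print(ffn_23_gate([-1.0, 0.05, 0.1, 0.2], 0.1))  # [0.0, 0.0, 0.0, 0.2]
```

In the test scripts the threshold is the seventh argument, e.g. `ffn_23.torch_launch_ffn_fuse_23(vec_sparse, vec, mat, cuda_res, mat_row, mat_col, threshold)`.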
72 |
73 | ### Attention: Data Types
74 |
75 | The default data type used in this code is **bfloat16**. Other data types can be supported by substituting the data type, together with its conversion intrinsics such as `__bfloat162float`, throughout the source code.
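bfloat16 keeps the sign bit, the 8 exponent bits and the top 7 mantissa bits of an IEEE-754 float32, and the kernels move 8 such values per 16-byte `float4` load. The sketch below (hypothetical helper names, NaN handling omitted) converts with round-to-nearest-even, the rounding mode documented for CUDA's `__float2bfloat16`, and packs 8 values the way a single `float4` holds them.

```python
import struct


def float_to_bf16_bits(x: float) -> int:
    """float32 -> bfloat16 bit pattern with round-to-nearest-even (NaN handling omitted)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bf16 value
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """bfloat16 bit pattern -> float (exact: the low 16 mantissa bits become zero)."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def pack_float4(values):
    """Pack 8 bf16 values into the 16 bytes moved by one float4 load."""
    assert len(values) == 8
    return struct.pack("<8H", *(float_to_bf16_bits(v) for v in values))


def unpack_float4(buf: bytes):
    return [bf16_bits_to_float(b) for b in struct.unpack("<8H", buf)]
```

The 8-values-per-`float4` packing, and hence the `num_per_threadx / 8` loop bound and the 8 unrolled products in `ffn_fuse_23`, holds for any 16-bit type; a 32-bit type would fit only 4 values per `float4`, so those parts of the kernel would need to change as well.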
76 |
77 | ### Attention: Dimensions
78 |
79 | We observed a significant performance improvement when the dimensions in Operator Step (3) are pre-defined as fixed values. By default, they are fixed to the LLaMA-7B setting. To support other dimensions (e.g., LLaMA-13B), edit the macros in `kernel/formula_4_kernel.cu`:
80 |
81 | ```c++
82 | // Default setting for LLaMA-7B
83 | ......
84 | #define ROW_OPT 11008
85 | #define COL_OPT 4096
86 | ......
87 |
88 |
89 | // Example: change to the setting of LLaMA-13B
90 | ......
91 | #define ROW_OPT 13824
92 | #define COL_OPT 5120
93 | ......
94 |
95 | ```
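Before editing the macros for a new model, it may help to check the shapes against assumptions built into the launch code: both launchers use 32 x 32 thread blocks and a grid of `mat_col / 32` blocks, `ffn_fuse_23` splits $`d_{model}`$ across 32 lanes that each read whole `float4` chunks of 8 bf16 values, and `ffn_4` walks the $`d_{ff}`$ rows with a stride of 32. The hypothetical helper below encodes those divisibility conditions, which are derived from reading the kernels rather than from a documented requirement, and builds the matching macro lines.

```python
def check_dims(d_model: int, d_ff: int) -> list:
    """Return problems with (d_model, d_ff) for the kernels as written;
    an empty list means the shapes satisfy the divisibility assumptions."""
    problems = []
    if d_model % (32 * 8) != 0:
        # ffn_fuse_23: 32 lanes per output, each reading float4 chunks of 8 bf16
        problems.append("d_model must be a multiple of 256 for ffn_23")
    if d_ff % 32 != 0:
        # ffn_fuse_23 grid covers d_ff / 32 blocks; ffn_4 strides d_ff rows by 32
        problems.append("d_ff must be a multiple of 32")
    if d_model % 32 != 0:
        # ffn_4 grid covers d_model / 32 blocks (implied by the first check)
        problems.append("d_model must be a multiple of 32 for ffn_4")
    return problems


def formula_4_macros(d_model: int, d_ff: int) -> str:
    """Macro lines for kernel/formula_4_kernel.cu: ROW_OPT = d_ff, COL_OPT = d_model."""
    return f"#define ROW_OPT {d_ff}\n#define COL_OPT {d_model}"


print(check_dims(5120, 13824))  # []
print(formula_4_macros(5120, 13824))
```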
96 |
97 | To treat the dimensions as runtime variables instead (i.e., to stop using the macros), remove the `define_macros=[('USE_CONSTANT', None)],` line from `setup.py`.
98 |
99 | ### Citation
100 |
101 | Please kindly cite using the following BibTeX:
102 |
103 | ```bibtex
104 | @article{song2024prosparse,
105 | title={{ProSparse}: Introducing and Enhancing Intrinsic Activation Sparsity within Large Language Models},
106 | author={Song, Chenyang and Han, Xu and Zhang, Zhengyan and Hu, Shengding and Shi, Xiyu and Li, Kuai and Chen, Chen and Liu, Zhiyuan and Li, Guangli and Yang, Tao and Sun, Maosong},
107 | year={2024},
108 | journal={arXiv preprint arXiv:2402.13516},
109 | url={https://arxiv.org/pdf/2402.13516.pdf}
110 | }
111 | ```
112 |
--------------------------------------------------------------------------------
/include/formula_23.h:
--------------------------------------------------------------------------------
1 | // #include <cuda_fp16.h>
2 | #include <cuda_bf16.h>
3 | void launch_ffn_fuse_23(nv_bfloat16 *vec_sparse, nv_bfloat16 *vec_input,
4 |                         nv_bfloat16 *mat_up, nv_bfloat16 *res, unsigned int mat_row,
5 |                         unsigned int mat_col, float threshold = 0);
--------------------------------------------------------------------------------
/include/formula_4.h:
--------------------------------------------------------------------------------
1 | #include <cuda_bf16.h>
2 | void launch_ffn_4(nv_bfloat16 *mat, nv_bfloat16 *vec, nv_bfloat16 *res,
3 |                   unsigned int mat_row, unsigned int mat_col);
--------------------------------------------------------------------------------
/kernel/formula_23_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_runtime.h>
2 | // #include <cuda_fp16.h>
3 | #include <cuda_bf16.h>
4 |
5 |
6 |
7 |
8 | // Col major: column col_id of mat_up is stored contiguously from mat_up[col_id * mat_row]
9 | __global__ void ffn_fuse_23(nv_bfloat16 *vec_sparse, nv_bfloat16 *vec_input,
10 |                             nv_bfloat16 *mat_up, nv_bfloat16 *res, unsigned int mat_row,
11 |                             unsigned int mat_col, float threshold)
12 | {
13 |     int col_id = blockIdx.y * 32 + threadIdx.y;
14 |     int num_per_threadx = mat_row / 32;
15 |     int row_chunk_id = threadIdx.x;
16 |     int row_id = row_chunk_id * num_per_threadx;
17 |
18 |     nv_bfloat16 *vec_sparse_p = &vec_sparse[col_id];            // per thread
19 |     nv_bfloat16 *vec_input_p = &vec_input[row_id];              // per thread
20 |     nv_bfloat16 *mat_up_p = &mat_up[col_id * mat_row + row_id]; // per thread, col-major
21 |     nv_bfloat16 *res_p = &res[col_id];                          // per thread
22 |
23 |     float4 *vec_input_f4 = reinterpret_cast<float4 *>(vec_input_p);
24 |     float4 vec_input_f_val;
25 |     float4 *mat_up_f4 = reinterpret_cast<float4 *>(mat_up_p);
26 |     float4 mat_up_f_val;
27 |
28 |     float sum = 0;
29 |     nv_bfloat16 vec_sparse_val = *vec_sparse_p;
30 |     if (__bfloat162float(vec_sparse_val) <= threshold)
31 |     {
32 |         if (threadIdx.x == 0)
33 |         {
34 |             *res_p = __float2bfloat16(0.f);
35 |         }
36 |     }
37 |     else
38 |     {
39 | #pragma unroll
40 |         for (int i = 0; i < (num_per_threadx / 8) /* 8 bf16 values per float4 */; i++)
41 |         {
42 |             vec_input_f_val = vec_input_f4[i];
43 |             const nv_bfloat162 *vec_input_h1 = (nv_bfloat162 *)&vec_input_f_val.x;
44 |             const nv_bfloat162 *vec_input_h2 = (nv_bfloat162 *)&vec_input_f_val.y;
45 |             const nv_bfloat162 *vec_input_h3 = (nv_bfloat162 *)&vec_input_f_val.z;
46 |             const nv_bfloat162 *vec_input_h4 = (nv_bfloat162 *)&vec_input_f_val.w;
47 |
48 |             mat_up_f_val = mat_up_f4[i];
49 |             const nv_bfloat162 *mat_up_h1 = (nv_bfloat162 *)&mat_up_f_val.x;
50 |             const nv_bfloat162 *mat_up_h2 = (nv_bfloat162 *)&mat_up_f_val.y;
51 |             const nv_bfloat162 *mat_up_h3 = (nv_bfloat162 *)&mat_up_f_val.z;
52 |             const nv_bfloat162 *mat_up_h4 = (nv_bfloat162 *)&mat_up_f_val.w;
53 |
54 |             sum += __bfloat162float(vec_input_h1->x) * __bfloat162float(mat_up_h1->x);
55 |             sum += __bfloat162float(vec_input_h1->y) * __bfloat162float(mat_up_h1->y);
56 |             sum += __bfloat162float(vec_input_h2->x) * __bfloat162float(mat_up_h2->x);
57 |             sum += __bfloat162float(vec_input_h2->y) * __bfloat162float(mat_up_h2->y);
58 |             sum += __bfloat162float(vec_input_h3->x) * __bfloat162float(mat_up_h3->x);
59 |             sum += __bfloat162float(vec_input_h3->y) * __bfloat162float(mat_up_h3->y);
60 |             sum += __bfloat162float(vec_input_h4->x) * __bfloat162float(mat_up_h4->x);
61 |             sum += __bfloat162float(vec_input_h4->y) * __bfloat162float(mat_up_h4->y);
62 |         }
63 |         sum += __shfl_down_sync(0xffffffff, sum, 16); // warp tree reduction: lane 0 ends with the full dot product
64 |         sum += __shfl_down_sync(0xffffffff, sum, 8);
65 |         sum += __shfl_down_sync(0xffffffff, sum, 4);
66 |         sum += __shfl_down_sync(0xffffffff, sum, 2);
67 |         sum += __shfl_down_sync(0xffffffff, sum, 1);
68 |
69 |         if (threadIdx.x == 0)
70 |         {
71 |             float sum_res = sum;
72 |             sum_res = sum_res * __bfloat162float(vec_sparse_val);
73 |             *res_p = __float2bfloat16(sum_res);
74 |         }
75 |     }
76 | }
77 |
78 | void launch_ffn_fuse_23(nv_bfloat16 *vec_sparse, nv_bfloat16 *vec_input,
79 |                         nv_bfloat16 *mat_up, nv_bfloat16 *res, unsigned int mat_row,
80 |                         unsigned int mat_col, float threshold)
81 | {
82 |     dim3 grid_dim(1, mat_col / 32); // one block per 32 output elements
83 |     dim3 block_dim(32, 32, 1);      // one warp (32 lanes) per output element
84 |
85 |     ffn_fuse_23<<<grid_dim, block_dim>>>(vec_sparse, vec_input, mat_up, res,
86 |                                          mat_row, mat_col, threshold);
87 | }
--------------------------------------------------------------------------------
/kernel/formula_4_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_runtime.h>
2 | #include <cuda_bf16.h>
3 |
4 |
5 |
6 |
7 | #define ROW_OPT 11008
8 | #define COL_OPT 4096
9 |
10 | // Row Major
11 | // block (32, 32, 1), grid (1, mat_col / 32)
12 | __global__ void ffn_4(nv_bfloat16 *mat, nv_bfloat16 *vec, nv_bfloat16 *res,
13 |                       unsigned int mat_row, unsigned int mat_col)
14 | {
15 |
16 | #ifdef USE_CONSTANT
17 |     mat_row = ROW_OPT;
18 |     mat_col = COL_OPT;
19 | #endif
20 |
21 |     float sum = 0;
22 |     // nv_bfloat16 sum = __float2bfloat16(0.0f);
23 |     __shared__ float warp_sum[32]; // one partial sum per output column of this block
24 |     if (threadIdx.y == 0) warp_sum[threadIdx.x] = 0.0f; // a single warp zeroes the accumulator
25 |     __syncthreads(); // all warps must see the zeroed accumulator before any atomicAdd
26 |     unsigned int col_id = blockIdx.y * 32 + threadIdx.x; // blockIdx.y in [0, mat_col / 32), threadIdx.x in [0, 32)
27 |     nv_bfloat16 *res_p = &res[col_id];
28 |     unsigned int warp_id = threadIdx.y; // (0,32)
29 |     unsigned int row_id = warp_id;
30 |     nv_bfloat16 *vec_p = &vec[row_id];
31 |     nv_bfloat16 *mat_p = &mat[row_id * mat_col + col_id];
32 |     nv_bfloat16 mat_val = __float2bfloat16(0.0f);
33 | #pragma unroll 32
34 |     for (int iter = 0; iter < mat_row; iter = iter + 32)
35 |     {
36 |         nv_bfloat16 vec_val = vec_p[iter];
37 |         if (__bfloat162float(vec_val) == 0.0f)
38 |             continue; // input-side sparsity: skip the weight row of a zero activation
39 |         else
40 |             mat_val = mat_p[iter * mat_col];
41 |         sum += __bfloat162float(vec_val) * __bfloat162float(mat_val);
42 |     }
43 |     atomicAdd(&warp_sum[threadIdx.x], sum);
44 |
45 |     __syncthreads();
46 |     if (warp_id == 0)
47 |     {
48 |         // Write final result
49 |         float sum = warp_sum[threadIdx.x];
50 |         *res_p = __float2bfloat16(sum);
51 |     }
52 | }
53 |
54 | void launch_ffn_4(nv_bfloat16 *mat, nv_bfloat16 *vec, nv_bfloat16 *res,
55 |                   unsigned int mat_row, unsigned int mat_col)
56 | {
57 | #ifdef USE_CONSTANT
58 |     mat_row = ROW_OPT;
59 |     mat_col = COL_OPT;
60 | #endif
61 |
62 |     dim3 grid_dim(1, mat_col / 32);
63 |     dim3 block_dim(32, 32, 1);
64 |
65 |     ffn_4<<<grid_dim, block_dim>>>(mat, vec, res, mat_row, mat_col);
66 | }
--------------------------------------------------------------------------------
/pytorch/formula_23.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "formula_23.h"
3 |
4 | void torch_launch_ffn_fuse_23(torch::Tensor &vec_sparse,
5 |                               torch::Tensor &vec_input,
6 |                               torch::Tensor &mat_up,
7 |                               torch::Tensor &res,
8 |                               int mat_row, int mat_col, float threshold = 0.)
9 | {
10 |     launch_ffn_fuse_23((nv_bfloat16 *)vec_sparse.data_ptr(),
11 |                        (nv_bfloat16 *)vec_input.data_ptr(),
12 |                        (nv_bfloat16 *)mat_up.data_ptr(),
13 |                        (nv_bfloat16 *)res.data_ptr(), mat_row, mat_col, threshold);
14 | }
15 |
16 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
17 | {
18 |     m.def("torch_launch_ffn_fuse_23",
19 |           &torch_launch_ffn_fuse_23,
20 |           "ffn_fuse_23 kernel wrapper");
21 | }
22 |
--------------------------------------------------------------------------------
/pytorch/formula_4.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "formula_4.h"
3 |
4 | void torch_launch_ffn_4(torch::Tensor &mat,
5 |                         torch::Tensor &vec,
6 |                         torch::Tensor &res,
7 |                         int mat_row, int mat_col)
8 | {
9 |     launch_ffn_4((nv_bfloat16 *)mat.data_ptr(),
10 |                  (nv_bfloat16 *)vec.data_ptr(),
11 |                  (nv_bfloat16 *)res.data_ptr(), mat_row, mat_col);
12 | }
13 |
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
15 | {
16 |     m.def("torch_launch_ffn_4",
17 |           &torch_launch_ffn_4,
18 |           "ffn_4 kernel wrapper");
19 | }
20 |
--------------------------------------------------------------------------------
/pytorch/sparse_vec-1651_11008.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Raincleared-Song/sparse_gpu_operator/1f0b9af8167929881d5ecd01acf09e5a3c06a1bc/pytorch/sparse_vec-1651_11008.npy
--------------------------------------------------------------------------------
/pytorch/sparse_vec-254_11008.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Raincleared-Song/sparse_gpu_operator/1f0b9af8167929881d5ecd01acf09e5a3c06a1bc/pytorch/sparse_vec-254_11008.npy
--------------------------------------------------------------------------------
/pytorch/sparse_vec.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Raincleared-Song/sparse_gpu_operator/1f0b9af8167929881d5ecd01acf09e5a3c06a1bc/pytorch/sparse_vec.npy
--------------------------------------------------------------------------------
/run_test-for-all-23.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import ffn_23
5 |
6 | nwarmup = 10
7 | ntest = 100
8 |
9 | def show_time(func):
10 |     times = list()
11 |     res = None
12 |     # GPU warm up
13 |     for _ in range(nwarmup):
14 |         res = func()
15 |     for _ in range(ntest):
16 |         # sync the threads to get accurate cuda running time
17 |         torch.cuda.synchronize(device="cuda:0")
18 |         start_time = time.time()
19 |         res = func()
20 |         torch.cuda.synchronize(device="cuda:0")
21 |         end_time = time.time()
22 |         times.append((end_time-start_time)*1e6)
23 |     return times, res
24 |
25 | def compare_tensors(res_cuda, res_torch, tolerance):
26 |     if res_cuda.shape != res_torch.shape:
27 |         print("Tensor shapes are different.")
28 |         return False
29 |     res_cuda_list = res_cuda.tolist()
30 |     res_torch_list = res_torch.tolist()
31 |
32 |     for index, (a, b) in enumerate(zip(res_cuda_list, res_torch_list)):
33 |         if (abs(b) == 0 and abs(a) > tolerance or abs(b) > 0 and abs(a - b) / abs(b) > tolerance):
34 |             print(f"Index {index}: diff = {a-b}")
35 |             return False
36 |     return True
37 |
38 | mat_row = 4096
39 | mat_col = 11008
40 | threshold = 0.
41 |
42 | file_path = 'pytorch/sparse_vec.npy'
43 | data = np.load(file_path)
44 |
45 |
46 | # first_row = data[18, :]
47 | # vec_sparse = torch.tensor(first_row, device="cuda:0")
48 |
49 | # assert vec_sparse.shape == (mat_col,), f"Expected shape (mat_col,), but got {first_row.shape}"
50 | vec_sparse = torch.zeros(mat_col, device="cuda:0", dtype=torch.bfloat16)
51 | vec = torch.rand(mat_row, device="cuda:0", dtype=torch.bfloat16)
52 | # mat = torch.zeros(mat_row, mat_col, device="cuda:0", dtype=torch.bfloat16)
53 | cuda_res = torch.zeros(mat_col, device="cuda:0", dtype=torch.bfloat16)
54 |
55 |
56 | def run_torch():
57 |     res = torch.matmul(vec, mat)
58 |     res = res * vec_sparse
59 |     # res[vec_sparse == 0] = 0
60 |     return res
61 |
62 | def run_cuda():
63 |     ffn_23.torch_launch_ffn_fuse_23(vec_sparse, vec, mat, cuda_res, mat_row, mat_col, threshold)
64 |     return cuda_res
65 |
66 |
67 | print(f"index,num_nonzero_elements,%,cuda_time,torch_time")
68 | for i in range(100):
69 |     row = data[i, :]
70 |     vec_sparse = torch.tensor(row, device="cuda:0")
71 |     vec_sparse = vec_sparse.to(dtype=torch.bfloat16)
72 |     assert vec_sparse.shape == (mat_col,), f"Expected shape (mat_col,), but got {row.shape}"
73 |
74 |     nonzero_indices = torch.nonzero(vec_sparse)
75 |     num_nonzero_elements = nonzero_indices.size(0)
76 |     mat = torch.rand(mat_row, mat_col, device="cuda:0", dtype=torch.bfloat16)
77 |     torch_time, torch_res = show_time(run_torch)
78 |     mat = mat.t().contiguous()
79 |     cuda_time, cuda_res = show_time(run_cuda)
80 |
81 |     print(f"{i},{num_nonzero_elements},{round(num_nonzero_elements/mat_col*100, 3)},{np.mean(cuda_time)},{np.mean(torch_time)}")
82 |
83 |     tolerance = 0.01
84 |     if not compare_tensors(cuda_res, torch_res, tolerance):
85 |         from IPython import embed
86 |         embed()
87 |         exit()
88 |
--------------------------------------------------------------------------------
/run_test-for-all-4.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import ffn_4
5 |
6 | nwarmup = 10
7 | ntest = 100
8 |
9 | def show_time(func):
10 |     times = list()
11 |     res = None
12 |     # GPU warm up
13 |     for _ in range(nwarmup):
14 |         res = func()
15 |     for _ in range(ntest):
16 |         # sync the threads to get accurate cuda running time
17 |         torch.cuda.synchronize(device="cuda:0")
18 |         start_time = time.time()
19 |         res = func()
20 |         torch.cuda.synchronize(device="cuda:0")
21 |         end_time = time.time()
22 |         times.append((end_time-start_time)*1e6)
23 |     return times, res
24 |
25 | def compare_tensors(res_cuda, res_torch, tolerance):
26 |     if res_cuda.shape != res_torch.shape:
27 |         print("Tensor shapes are different.")
28 |         return False
29 |     res_cuda_list = res_cuda.tolist()
30 |     res_torch_list = res_torch.tolist()
31 |
32 |     for index, (a, b) in enumerate(zip(res_cuda_list, res_torch_list)):
33 |         if (abs(b) == 0 and abs(a) > tolerance or abs(b) > 0 and abs(a - b) / abs(b) > tolerance):
34 |             print(f"Index {index}: diff = {a-b}")
35 |             return False
36 |     return True
37 |
38 | mat_row = 11008
39 | mat_col = 4096
40 |
41 | # mat_row = 256;
42 | # mat_col = 512;
43 | # mat = torch.rand(mat_row, mat_col, device="cuda:0", dtype=torch.float32)
44 | # vec = torch.rand(mat_row, device="cuda:0", dtype=torch.float32)
45 | # ffn_4.torch_launch_ffn_4(mat.to(dtype=torch.bfloat16), vec.to(dtype=torch.bfloat16), res_cuda, mat_row, mat_col)
46 |
47 | file_path = 'pytorch/sparse_vec.npy'
48 | data = np.load(file_path)
49 |
50 | mat = torch.rand(mat_row, mat_col, device="cuda:0", dtype=torch.bfloat16)
51 | vec = torch.zeros(mat_row, device="cuda:0", dtype=torch.bfloat16)
52 | cuda_res = torch.zeros(mat_col, device="cuda:0", dtype=torch.bfloat16)
53 |
54 |
55 | def run_cuda():
56 |     ffn_4.torch_launch_ffn_4(mat, vec, cuda_res, mat_row, mat_col)
57 |     return cuda_res
58 |
59 | def run_torch():
60 |     res = torch.matmul(vec, mat)
61 |     return res
62 |
63 |
64 | for i in range(100):
65 |     row = data[i, :]
66 |     vec = torch.tensor(row, device="cuda:0")
67 |     vec = vec.to(dtype=torch.bfloat16)
68 |     assert vec.shape == (mat_row,), f"Expected shape (mat_row,), but got {row.shape}"
69 |
70 |     nonzero_indices = torch.nonzero(vec)
71 |     num_nonzero_elements = nonzero_indices.size(0)
72 |     cuda_time, cuda_res = show_time(run_cuda)
73 |     torch_time, torch_res = show_time(run_torch)
74 |
75 |     print(f"{i},{num_nonzero_elements},{round(num_nonzero_elements/mat_row*100, 3)},{np.mean(cuda_time)},{np.mean(torch_time)}")
76 |
77 |     tolerance = 0.01
78 |     if not compare_tensors(cuda_res, torch_res, tolerance):
79 |         from IPython import embed
80 |         embed()
81 |         exit()
82 |
--------------------------------------------------------------------------------
/run_test_23.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import ffn_23
5 |
6 | nwarmup = 10
7 | ntest = 100
8 |
9 | def show_time(func):
10 |     times = list()
11 |     res = None
12 |     # GPU warm up
13 |     for _ in range(nwarmup):
14 |         res = func()
15 |     for _ in range(ntest):
16 |         # sync the threads to get accurate cuda running time
17 |         torch.cuda.synchronize(device="cuda:0")
18 |         start_time = time.time()
19 |         res = func()
20 |         torch.cuda.synchronize(device="cuda:0")
21 |         end_time = time.time()
22 |         times.append((end_time-start_time) * 1e6)
23 |     return times, res
24 |
25 | def compare_tensors(res_cuda, res_torch, tolerance):
26 |     if res_cuda.shape != res_torch.shape:
27 |         print("Tensor shapes are different.")
28 |         return False
29 |     res_cuda_list = res_cuda.tolist()
30 |     res_torch_list = res_torch.tolist()
31 |
32 |     for index, (a, b) in enumerate(zip(res_cuda_list, res_torch_list)):
33 |         if (abs(b) == 0 and abs(a) > tolerance or abs(b) > 0 and abs(a - b) / abs(b) > tolerance):
34 |             print(f"Index {index}: diff = {a-b}")
35 |             return False
36 |     return True
37 |
38 | mat_row = 4096
39 | mat_col = 11008
40 | fatrelu_threshold = 0.
41 |
42 | # file_path = 'pytorch/sparse_vec.npy'
43 | # data = np.load(file_path)
44 | # first_row = data[9, :]
45 | # vec_sparse = torch.tensor(first_row, device="cuda:0", dtype=torch.bfloat16)
46 |
47 | for idx in range(10):
48 |     vec_sparse = torch.rand(mat_col, device="cuda:0", dtype=torch.bfloat16)
49 |     vec_sparse = torch.relu(vec_sparse - idx / 10)
50 |     print(">>> act_rate:", round(torch.sum(vec_sparse > 0).item() * 100 / vec_sparse.numel(), 2))
51 |     # assert vec_sparse.shape == (mat_col,), f"Expected shape (mat_col,), but got {first_row.shape}"
52 |     vec = torch.rand(mat_row, device="cuda:0", dtype=torch.bfloat16)
53 |     mat = torch.rand(mat_row, mat_col, device="cuda:0", dtype=torch.bfloat16)
54 |     cuda_res = torch.zeros(mat_col, device="cuda:0", dtype=torch.bfloat16)
55 |
56 |     def run_torch():
57 |         res = torch.matmul(vec, mat)
58 |         res = res * vec_sparse
59 |         return res
60 |
61 |     def run_cuda():
62 |         ffn_23.torch_launch_ffn_fuse_23(vec_sparse, vec, mat, cuda_res, mat_row, mat_col, fatrelu_threshold)
63 |         return cuda_res
64 |
65 |     # Flush the GPU cache with a burst of unrelated computation
66 |     for _ in range(100):
67 |         x = torch.rand(1000, 2000, device="cuda:0", dtype=torch.bfloat16)
68 |         y = torch.rand(2000, 1000, device="cuda:0", dtype=torch.bfloat16)
69 |         x = x ** 2
70 |         y = y ** 0.5
71 |         z = x @ y
72 |     print("Running torch...")
73 |     torch_time, torch_res = show_time(run_torch)
74 |     print("Torch time: {:.4f} us".format(np.mean(torch_time)))
75 |
76 |     mat = mat.t().contiguous()  # convert mat to column-major storage before passing it to the kernel
77 |     # Flush the GPU cache with a burst of unrelated computation
78 |     for _ in range(100):
79 |         x = torch.rand(1000, 2000, device="cuda:0", dtype=torch.bfloat16)
80 |         y = torch.rand(2000, 1000, device="cuda:0", dtype=torch.bfloat16)
81 |         x = x ** 2
82 |         y = y ** 0.5
83 |         z = x @ y
84 |     print("Running cuda...")
85 |     cuda_time, cuda_res = show_time(run_cuda)
86 |     print("Cuda time: {:.4f} us".format(np.mean(cuda_time)))
87 |
88 |     tolerance = 0.01
89 |     if not compare_tensors(cuda_res, torch_res, tolerance):
90 |         from IPython import embed
91 |         embed()
92 |         exit()
93 |
--------------------------------------------------------------------------------
/run_test_4.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import ffn_4
5 |
6 | nwarmup = 10
7 | ntest = 100
8 |
9 | def show_time(func):
10 |     times = list()
11 |     res = None
12 |     # GPU warm up
13 |     for _ in range(nwarmup):
14 |         res = func()
15 |     for _ in range(ntest):
16 |         # sync the threads to get accurate cuda running time
17 |         torch.cuda.synchronize(device="cuda:0")
18 |         start_time = time.time()
19 |         res = func()
20 |         torch.cuda.synchronize(device="cuda:0")
21 |         end_time = time.time()
22 |         times.append((end_time-start_time)*1e6)
23 |     return times, res
24 |
25 | def compare_tensors(res_cuda, res_torch, tolerance):
26 |     if res_cuda.shape != res_torch.shape:
27 |         print("Tensor shapes are different.")
28 |         return False
29 |     res_cuda_list = res_cuda.tolist()
30 |     res_torch_list = res_torch.tolist()
31 |
32 |     for index, (a, b) in enumerate(zip(res_cuda_list, res_torch_list)):
33 |         if (abs(b) == 0 and abs(a) > tolerance or abs(b) > 0 and abs(a - b) / abs(b) > tolerance):
34 |             print(f"Index {index}: diff = {a-b}")
35 |             return False
36 |     return True
37 |
38 | mat_row = 11008
39 | mat_col = 4096
40 |
41 | # mat_row = 256;
42 | # mat_col = 512;
43 | # mat = torch.rand(mat_row, mat_col, device="cuda:0", dtype=torch.float32)
44 | # vec = torch.rand(mat_row, device="cuda:0", dtype=torch.float32)
45 | # ffn_4.torch_launch_ffn_4(mat.to(dtype=torch.float16), vec.to(dtype=torch.float16), res_cuda, mat_row, mat_col)
46 |
47 |
48 | # file_path = 'pytorch/sparse_vec.npy'
49 | # data = np.load(file_path)
50 | # first_row = data[6, :]
51 |
52 | # file_path = 'pytorch/sparse_vec-1651_11008.npy'
53 | # data = np.load(file_path)
54 | # first_row = data
55 |
56 | # vec = torch.tensor(first_row, device="cuda:0").to(dtype=torch.bfloat16)
57 | # assert vec.shape == (mat_row,), f"Expected shape (mat_row,), but got {first_row.shape}"
58 | # vec = torch.rand(mat_row, device="cuda:0", dtype=torch.float16)
59 | # vec = vec.to(dtype=torch.bfloat16)
60 |
61 | for idx in range(10):
62 |     vec = torch.rand(mat_row, device="cuda:0", dtype=torch.bfloat16)
63 |     vec = torch.relu(vec - idx / 10)
64 |     print(">>> act_rate:", round(torch.sum(vec > 0).item() * 100 / vec.numel(), 2))
65 |
66 |     mat = torch.rand(mat_row, mat_col, device="cuda:0", dtype=torch.bfloat16)
67 |     cuda_res = torch.zeros(mat_col, device="cuda:0", dtype=torch.bfloat16)
68 |
69 |     def run_cuda():
70 |         ffn_4.torch_launch_ffn_4(mat, vec, cuda_res, mat_row, mat_col)
71 |         return cuda_res
72 |
73 |     def run_torch():
74 |         res = torch.matmul(vec, mat)
75 |         return res
76 |
77 |     print("Running torch...")
78 |     torch_time, torch_res = show_time(run_torch)
79 |     print("Torch time: {:.4f} us".format(np.mean(torch_time)))
80 |
81 |
82 |     print("Running cuda...")
83 |     cuda_time, cuda_res = show_time(run_cuda)
84 |     print("Cuda time: {:.4f} us".format(np.mean(cuda_time)))
85 |
86 |     tolerance = 0.01
87 |     if not compare_tensors(cuda_res, torch_res, tolerance):
88 |         from IPython import embed
89 |         embed()
90 |         exit()
91 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name="ffn_4",
6 | include_dirs=["include"],
7 | ext_modules=[
8 | CUDAExtension(
9 | "ffn_4",
10 | ["pytorch/formula_4.cpp", "kernel/formula_4_kernel.cu"],
11 | define_macros=[('USE_CONSTANT', None)],
12 | ),
13 | ],
14 | cmdclass={
15 | "build_ext": BuildExtension
16 | }
17 | )
18 |
19 | setup(
20 | name="ffn_23",
21 | include_dirs=["include"],
22 | ext_modules=[
23 | CUDAExtension(
24 | "ffn_23",
25 | ["pytorch/formula_23.cpp", "kernel/formula_23_kernel.cu"],
26 | )
27 | ],
28 | cmdclass={
29 | "build_ext": BuildExtension
30 | }
31 | )
32 |
--------------------------------------------------------------------------------
/t-whereis-ffn_4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import ffn_4
3 | print(ffn_4.__file__)
4 |
5 | # /home/xxx/download/anaconda3/envs/pt2/lib/python3.8/site-packages/ffn_4-0.0.0-py3.8-linux-x86_64.egg/ffn_4.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------