├── .gitignore ├── README.md ├── cuda ├── .gitignore ├── Makefile ├── cuda_common.cuh ├── matrix_add.cu ├── vector_add.cu └── vector_add_simple.cu ├── notes ├── 0000 - Table of Contents.pdf ├── 0001 - Multi Head Attention.pdf ├── 0002 - (Safe) Softmax.pdf ├── 0003 - Online Softmax.pdf ├── 0004 - Block Matrix Multiplication.pdf ├── 0005 - Intro to GPU & CUDA.pdf ├── 0006 - Tensor Layouts.pdf ├── 0007 - Software Pipelining.pdf ├── 0008 - Autograd & Gradients.pdf ├── 009 - Gradient of MatMul.pdf └── 010 - Gradient of Softmax.pdf └── triton ├── flash_attention.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flash Attention implemented with Triton 2 | 3 | Implements the Flash Attention 2 algorithm, based on the code published by the OpenAI Triton team in the [Fused Attention](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html) tutorial. 4 | 5 | It also includes some CUDA examples, as shown in the video. 6 | 7 | Install the requirements at `triton/requirements.txt` to launch the Python file. Adjust `BATCH_SIZE`, `NUM_HEADS`, `SEQ_LEN` and `HEAD_DIM` to make sure your computer doesn't explode; a minimal way to call the kernel directly is sketched at the bottom of this README. 8 | 9 | The *naive* reference implementation materializes a full `SEQ_LEN x SEQ_LEN` attention matrix, so it is likely to be the bottleneck when running this code. Just disable it and push the `SEQ_LEN` of the Flash Attention kernel to the limit supported by your hardware. 10 | 11 | Not tested on AMD GPUs, so let me know if it works there! 12 | 13 | ## Exercise 1: autotuning the backward pass 14 | 15 | Can you apply autotuning configs to the backward pass, as is done for the forward pass? A possible starting point is sketched at the bottom of this README. 16 | 17 | ## Exercise 2: how to make Flash Attention faster 18 | 19 | As you can see, during the backward pass we iterate over the entire `SEQ_LEN` even when the attention calculation is `causal`. Can you avoid iterating over the tokens that cannot contribute any change to `dK`, `dQ` and `dV` when the attention calculation is causal?
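
## Quick usage sketch

For reference, this is roughly what running the kernel yourself boils down to. A minimal sketch: the import assumes you execute it from inside the `triton/` folder, the shape values are arbitrary examples, and the kernels as written expect `SEQ_LEN` to be a multiple of 128 and `HEAD_DIM` to be a power of two.

```python
import torch

from flash_attention import TritonAttention  # assumes you run from the triton/ folder

# Hypothetical small shapes: tune these to what your GPU can handle.
BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 2, 4, 1024, 64

def rand_tensor():
    # fp16 tensors on the GPU, with gradients enabled so the backward pass runs too
    return torch.randn(
        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM,
        dtype=torch.float16, device="cuda", requires_grad=True,
    )

Q, K, V = rand_tensor(), rand_tensor(), rand_tensor()
softmax_scale = 1 / (HEAD_DIM**0.5)

O = TritonAttention.apply(Q, K, V, True, softmax_scale)  # causal=True
O.sum().backward()
print(O.shape, Q.grad.shape, K.grad.shape, V.grad.shape)
```

`test_op` in `triton/flash_attention.py` does the same thing and additionally checks the output and the gradients against the naive PyTorch implementation.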
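
## A possible starting point for Exercise 1

An untested sketch: the config values below are guesses and `_attn_bwd_dq_autotuned` is a hypothetical name. You would also need to stop passing `BLOCK_Q`, `BLOCK_KV`, `num_warps` and `num_stages` explicitly at the launch site, and compute the grid with a lambda over the tuned block size just like the forward pass does (e.g. `grid = lambda args: (SEQ_LEN // args["BLOCK_Q"], 1, BATCH_SIZE * NUM_HEADS)` for the `dQ` kernel).

```python
import triton
import triton.language as tl

# Candidate configs for the backward kernels, mirroring the list used for _attn_fwd.
# The values are guesses; SEQ_LEN must stay divisible by whichever block sizes get selected.
bwd_configs = [
    triton.Config({"BLOCK_Q": bq, "BLOCK_KV": bkv}, num_stages=stages, num_warps=warps)
    for bq in [32, 64]
    for bkv in [64, 128]
    for stages in [3, 4]
    for warps in [4, 8]
]

@triton.autotune(bwd_configs, key=["SEQ_LEN", "HEAD_DIM"])
@triton.jit
def _attn_bwd_dq_autotuned(
    Q, K, V, softmax_scale, dO, dQ, dK, dV, M, D,
    stride_batch, stride_head, stride_seq, stride_dim,
    NUM_HEADS, SEQ_LEN,
    BLOCK_Q: tl.constexpr, BLOCK_KV: tl.constexpr,
    HEAD_DIM: tl.constexpr, STAGE: tl.constexpr,
):
    # Same body as _attn_bwd_dq in flash_attention.py; BLOCK_Q and BLOCK_KV now come
    # from the selected triton.Config instead of the hard-coded BLOCK_SIZE_MICRO /
    # BLOCK_SIZE_MACRO constants in TritonAttention.backward().
    ...
```

The same decorator (with the `BLOCK_Q`/`BLOCK_KV` roles swapped) can be applied to `_attn_bwd_dk_dv`.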
-------------------------------------------------------------------------------- /cuda/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | -------------------------------------------------------------------------------- /cuda/Makefile: -------------------------------------------------------------------------------- 1 | NVCCFLAGS = -g -G 2 | 3 | all: build 4 | 5 | build: matrix_add.out vector_add.out vector_add_simple.out 6 | 7 | matrix_add.out: matrix_add.cu 8 | nvcc $(NVCCFLAGS) -o matrix_add.out matrix_add.cu 9 | 10 | vector_add.out: vector_add.cu 11 | nvcc $(NVCCFLAGS) -o vector_add.out vector_add.cu 12 | 13 | vector_add_simple.out: vector_add_simple.cu 14 | nvcc $(NVCCFLAGS) -o vector_add_simple.out vector_add_simple.cu 15 | 16 | clean: 17 | rm -f vector_add.out 18 | rm -f matrix_add.out 19 | rm -f vector_add_simple.out 20 | 21 | run: build 22 | ./vector_add.out 23 | ./matrix_add.out 24 | ./vector_add_simple.out -------------------------------------------------------------------------------- /cuda/cuda_common.cuh: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | 3 | #define CUDA_CHECK(err) do { cuda_check((err), __FILE__, __LINE__); } while(false) 4 | inline void cuda_check(cudaError_t error_code, const char *file, int line) 5 | { 6 | if (error_code != cudaSuccess) 7 | { 8 | fprintf(stderr, "CUDA Error %d: %s. 
In file '%s' on line %d\n", error_code, cudaGetErrorString(error_code), file, line); 9 | fflush(stderr); 10 | exit(error_code); 11 | } 12 | } -------------------------------------------------------------------------------- /cuda/matrix_add.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <time.h> 4 | #include <math.h> 5 | #include <assert.h> 6 | #include <cuda_runtime.h> 7 | #include <sys/time.h> 8 | #include "cuda_common.cuh" 9 | 10 | typedef float EL_TYPE; 11 | 12 | __global__ void cuda_matrix_add(EL_TYPE *OUT, EL_TYPE *A, EL_TYPE *B, int NUM_ROWS, int NUM_COLS) 13 | { 14 | int row_index = blockIdx.y * blockDim.y + threadIdx.y; 15 | int col_index = blockIdx.x * blockDim.x + threadIdx.x; 16 | 17 | if (row_index < NUM_ROWS && col_index < NUM_COLS) 18 | { 19 | size_t index = static_cast<size_t>(row_index) * NUM_COLS + col_index; // A[row_index][col_index] 20 | OUT[index] = A[index] + B[index]; 21 | } 22 | } 23 | 24 | void test_matrix_add(int NUM_ROWS, int NUM_COLS, int ROWS_block_size, int COLS_block_size) 25 | { 26 | EL_TYPE *A, *B, *OUT; 27 | EL_TYPE *d_A, *d_B, *d_OUT; 28 | 29 | // Allocate the matrices on the host device 30 | A = (EL_TYPE *)malloc(sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS); 31 | B = (EL_TYPE *)malloc(sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS); 32 | OUT = (EL_TYPE *)malloc(sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS); 33 | 34 | // Initialize the matrices with random values 35 | for (int i = 0; i < NUM_ROWS; i++) 36 | { 37 | for (int j = 0; j < NUM_COLS; j++) 38 | { 39 | size_t index = static_cast<size_t>(i) * NUM_COLS + j; 40 | A[index] = rand() % 100; 41 | B[index] = rand() % 100; 42 | } 43 | } 44 | 45 | // Allocate device memory for a 46 | CUDA_CHECK(cudaMalloc((void **)&d_A, sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS)); 47 | CUDA_CHECK(cudaMalloc((void **)&d_B, sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS)); 48 | CUDA_CHECK(cudaMalloc((void **)&d_OUT, sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS)); 49 | 50 | // Transfer the matrices to the device 51 | CUDA_CHECK(cudaMemcpy(d_A, A, sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS, cudaMemcpyHostToDevice)); 52 | CUDA_CHECK(cudaMemcpy(d_B, B, sizeof(EL_TYPE) * NUM_ROWS * NUM_COLS, cudaMemcpyHostToDevice)); 53 | 54 | cudaEvent_t start_kernel, stop_kernel; 55 | CUDA_CHECK(cudaEventCreate(&start_kernel)); 56 | CUDA_CHECK(cudaEventCreate(&stop_kernel)); 57 | 58 | CUDA_CHECK(cudaEventRecord(start_kernel)); 59 | 60 | // Define the launch grid 61 | int num_blocks_ROWS = (NUM_ROWS + ROWS_block_size - 1) / ROWS_block_size; // ceil(NUM_ROWS / ROWS_block_size) 62 | int num_blocks_COLS = (NUM_COLS + COLS_block_size - 1) / COLS_block_size; // ceil(NUM_COLS / COLS_block_size) 63 | printf("Matrix Add - M: %d, N: %d will be processed by (%d x %d) blocks of size (%d x %d)\n", NUM_ROWS, NUM_COLS, num_blocks_ROWS, num_blocks_COLS, ROWS_block_size, COLS_block_size); 64 | dim3 grid(num_blocks_COLS, num_blocks_ROWS, 1); 65 | dim3 block(COLS_block_size, ROWS_block_size, 1); 66 | // Run the kernel 67 | cuda_matrix_add<<<grid, block>>>(d_OUT, d_A, d_B, NUM_ROWS, NUM_COLS); 68 | 69 | // Check for launch errors 70 | CUDA_CHECK(cudaPeekAtLastError()); 71 | CUDA_CHECK(cudaEventRecord(stop_kernel)); 72 | CUDA_CHECK(cudaEventSynchronize(stop_kernel)); 73 | 74 | // Calculate elapsed milliseconds 75 | float milliseconds_kernel = 0; 76 | CUDA_CHECK(cudaEventElapsedTime(&milliseconds_kernel, start_kernel, stop_kernel)); 77 | printf("Matrix Add - Elapsed time: %f ms\n", milliseconds_kernel); 78 | 79 | // Copy back the result from the device to the host 80 | CUDA_CHECK(cudaMemcpy(OUT, d_OUT, sizeof(EL_TYPE) * 
NUM_ROWS * NUM_COLS, cudaMemcpyDeviceToHost)); 81 | 82 | // Free the memory on the device 83 | CUDA_CHECK(cudaFree(d_A)); 84 | CUDA_CHECK(cudaFree(d_B)); 85 | CUDA_CHECK(cudaFree(d_OUT)); 86 | 87 | // Time the operation 88 | struct timeval start_check, end_check; 89 | gettimeofday(&start_check, NULL); 90 | 91 | for (int i = 0; i < NUM_ROWS; i++) 92 | { 93 | for (int j = 0; j < NUM_COLS; j++) 94 | { 95 | size_t index = static_cast<size_t>(i) * NUM_COLS + j; 96 | if (OUT[index] != A[index] + B[index]) 97 | { 98 | printf("Error at index (%d, %d): %.2f != %.2f + %.2f\n", i, j, OUT[index], A[index], B[index]); 99 | exit(1); 100 | } 101 | } 102 | } 103 | 104 | // Calculate elapsed time 105 | gettimeofday(&end_check, NULL); 106 | float elapsed = (end_check.tv_sec - start_check.tv_sec) * 1000.0 + (end_check.tv_usec - start_check.tv_usec) / 1000.0; 107 | printf("Matrix Add - Check elapsed time: %f ms\n", elapsed); 108 | 109 | printf("Matrix Add - Result OK\n"); 110 | 111 | // Free the memory on the host 112 | free(A); 113 | free(B); 114 | free(OUT); 115 | } 116 | 117 | int main() 118 | { 119 | // set your seed 120 | srand(0); 121 | 122 | test_matrix_add(10000, 10000, 16, 16); 123 | 124 | } -------------------------------------------------------------------------------- /cuda/vector_add.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <time.h> 4 | #include <math.h> 5 | #include <assert.h> 6 | #include <cuda_runtime.h> 7 | #include <sys/time.h> 8 | #include "cuda_common.cuh" 9 | 10 | typedef int EL_TYPE; 11 | 12 | __global__ void cuda_vector_add(EL_TYPE *OUT, EL_TYPE *A, EL_TYPE *B, int N) 13 | { 14 | int i = blockIdx.x * blockDim.x + threadIdx.x; 15 | if (i < N) 16 | { 17 | OUT[i] = A[i] + B[i]; 18 | } 19 | } 20 | 21 | void test_vector_add(int N, int block_size) 22 | { 23 | EL_TYPE *A, *B, *OUT; 24 | EL_TYPE *d_A, *d_B, *d_OUT; 25 | 26 | // Allocate the vectors on the host device 27 | A = (EL_TYPE *)malloc(sizeof(EL_TYPE) * N); 28 | B = (EL_TYPE *)malloc(sizeof(EL_TYPE) * N); 29 | OUT = (EL_TYPE *)malloc(sizeof(EL_TYPE) * N); 30 | 31 | // Initialize the vectors with random values 32 | for (int i = 0; i < N; i++) 33 | { 34 | A[i] = rand() % 100; 35 | B[i] = rand() % 100; 36 | } 37 | 38 | // Allocate device memory for a 39 | CUDA_CHECK(cudaMalloc((void **)&d_A, sizeof(EL_TYPE) * N)); 40 | CUDA_CHECK(cudaMalloc((void **)&d_B, sizeof(EL_TYPE) * N)); 41 | CUDA_CHECK(cudaMalloc((void **)&d_OUT, sizeof(EL_TYPE) * N)); 42 | 43 | // Transfer the vectors to the device 44 | CUDA_CHECK(cudaMemcpy(d_A, A, sizeof(EL_TYPE) * N, cudaMemcpyHostToDevice)); 45 | CUDA_CHECK(cudaMemcpy(d_B, B, sizeof(EL_TYPE) * N, cudaMemcpyHostToDevice)); 46 | 47 | // Define the launch grid 48 | int num_blocks = ceil((float)N / block_size); 49 | printf("Vector Add - N: %d will be processed by %d blocks of size %d\n", N, num_blocks, block_size); 50 | dim3 grid(num_blocks, 1, 1); 51 | dim3 block(block_size, 1, 1); 52 | 53 | cudaEvent_t start_kernel, stop_kernel; 54 | CUDA_CHECK(cudaEventCreate(&start_kernel)); 55 | CUDA_CHECK(cudaEventCreate(&stop_kernel)); 56 | 57 | CUDA_CHECK(cudaEventRecord(start_kernel)); 58 | // Run the kernel 59 | cuda_vector_add<<<grid, block>>>(d_OUT, d_A, d_B, N); 60 | CUDA_CHECK(cudaEventRecord(stop_kernel)); 61 | // Check for launch errors 62 | CUDA_CHECK(cudaPeekAtLastError()); 63 | CUDA_CHECK(cudaEventSynchronize(stop_kernel)); 64 | 65 | // Calculate elapsed milliseconds 66 | float milliseconds_kernel = 0; 67 | CUDA_CHECK(cudaEventElapsedTime(&milliseconds_kernel, start_kernel, stop_kernel)); 68 | printf("Vector 
Add - elapsed time: %f ms\n", milliseconds_kernel); 69 | 70 | // Copy back the result from the device to the host 71 | CUDA_CHECK(cudaMemcpy(OUT, d_OUT, sizeof(EL_TYPE) * N, cudaMemcpyDeviceToHost)); 72 | 73 | // Free the memory on the device 74 | CUDA_CHECK(cudaFree(d_A)); 75 | CUDA_CHECK(cudaFree(d_B)); 76 | CUDA_CHECK(cudaFree(d_OUT)); 77 | 78 | // Time the operation 79 | struct timeval start_check, end_check; 80 | gettimeofday(&start_check, NULL); 81 | 82 | for (int i = 0; i < N; i++) 83 | { 84 | // Check if the result is correct 85 | if (OUT[i] != A[i] + B[i]) 86 | { 87 | printf("Error at index %d: %d != %d + %d\n", i, OUT[i], A[i], B[i]); 88 | exit(1); 89 | } 90 | } 91 | 92 | // Calculate elapsed time 93 | gettimeofday(&end_check, NULL); 94 | float elapsed = (end_check.tv_sec - start_check.tv_sec) * 1000.0 + (end_check.tv_usec - start_check.tv_usec) / 1000.0; 95 | printf("Vector Add - Check elapsed time: %f ms\n", elapsed); 96 | printf("Vector Add - result OK\n"); 97 | 98 | // Free the memory on the host 99 | free(A); 100 | free(B); 101 | free(OUT); 102 | } 103 | 104 | int main() 105 | { 106 | // set your seed 107 | srand(0); 108 | 109 | test_vector_add(1000000, 128); 110 | 111 | } -------------------------------------------------------------------------------- /cuda/vector_add_simple.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <time.h> 4 | #include <math.h> 5 | #include <assert.h> 6 | #include <cuda_runtime.h> 7 | #include <sys/time.h> 8 | #include "cuda_common.cuh" 9 | 10 | typedef int EL_TYPE; 11 | 12 | __global__ void cuda_vector_add_simple(EL_TYPE *OUT, EL_TYPE *A, EL_TYPE *B, int N) 13 | { 14 | int i = threadIdx.x; 15 | if (i < N) 16 | { 17 | OUT[i] = A[i] + B[i]; 18 | } 19 | } 20 | 21 | void test_vector_add(int N) 22 | { 23 | EL_TYPE *A, *B, *OUT; 24 | EL_TYPE *d_A, *d_B, *d_OUT; 25 | 26 | // Allocate the vectors on the host device 27 | A = (EL_TYPE *)malloc(sizeof(EL_TYPE) * N); 28 | B = (EL_TYPE *)malloc(sizeof(EL_TYPE) * N); 29 | OUT = (EL_TYPE *)malloc(sizeof(EL_TYPE) * N); 30 | 31 | // Initialize the vectors with random values 32 | for (int i = 0; i < N; i++) 33 | { 34 | A[i] = rand() % 100; 35 | B[i] = rand() % 100; 36 | } 37 | 38 | // Allocate device memory for a 39 | CUDA_CHECK(cudaMalloc((void **)&d_A, sizeof(EL_TYPE) * N)); 40 | CUDA_CHECK(cudaMalloc((void **)&d_B, sizeof(EL_TYPE) * N)); 41 | CUDA_CHECK(cudaMalloc((void **)&d_OUT, sizeof(EL_TYPE) * N)); 42 | 43 | // Transfer the vectors to the device 44 | CUDA_CHECK(cudaMemcpy(d_A, A, sizeof(EL_TYPE) * N, cudaMemcpyHostToDevice)); 45 | CUDA_CHECK(cudaMemcpy(d_B, B, sizeof(EL_TYPE) * N, cudaMemcpyHostToDevice)); 46 | 47 | cudaEvent_t start_kernel, stop_kernel; 48 | CUDA_CHECK(cudaEventCreate(&start_kernel)); 49 | CUDA_CHECK(cudaEventCreate(&stop_kernel)); 50 | 51 | CUDA_CHECK(cudaEventRecord(start_kernel)); 52 | // Run the kernel 53 | cuda_vector_add_simple<<<1, N>>>(d_OUT, d_A, d_B, N); 54 | CUDA_CHECK(cudaEventRecord(stop_kernel)); 55 | // Check for launch errors 56 | CUDA_CHECK(cudaPeekAtLastError()); 57 | CUDA_CHECK(cudaEventSynchronize(stop_kernel)); 58 | 59 | // Calculate elapsed milliseconds 60 | float milliseconds_kernel = 0; 61 | CUDA_CHECK(cudaEventElapsedTime(&milliseconds_kernel, start_kernel, stop_kernel)); 62 | printf("Vector Add - elapsed time: %f ms\n", milliseconds_kernel); 63 | 64 | // Copy back the result from the device to the host 65 | CUDA_CHECK(cudaMemcpy(OUT, d_OUT, sizeof(EL_TYPE) * N, cudaMemcpyDeviceToHost)); 66 | 67 | // Free the memory on the device 68 | 
CUDA_CHECK(cudaFree(d_A)); 69 | CUDA_CHECK(cudaFree(d_B)); 70 | CUDA_CHECK(cudaFree(d_OUT)); 71 | 72 | // Time the operation 73 | struct timeval start_check, end_check; 74 | gettimeofday(&start_check, NULL); 75 | 76 | for (int i = 0; i < N; i++) 77 | { 78 | // Check if the result is correct 79 | if (OUT[i] != A[i] + B[i]) 80 | { 81 | printf("Error at index %d: %d != %d + %d\n", i, OUT[i], A[i], B[i]); 82 | exit(1); 83 | } 84 | } 85 | 86 | // Calculate elapsed time 87 | gettimeofday(&end_check, NULL); 88 | float elapsed = (end_check.tv_sec - start_check.tv_sec) * 1000.0 + (end_check.tv_usec - start_check.tv_usec) / 1000.0; 89 | printf("Vector Add - Check elapsed time: %f ms\n", elapsed); 90 | printf("Vector Add - result OK\n"); 91 | 92 | // Free the memory on the host 93 | free(A); 94 | free(B); 95 | free(OUT); 96 | } 97 | 98 | int main() 99 | { 100 | // set your seed 101 | srand(0); 102 | 103 | test_vector_add(1024); 104 | 105 | } -------------------------------------------------------------------------------- /notes/0000 - Table of Contents.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0000 - Table of Contents.pdf -------------------------------------------------------------------------------- /notes/0001 - Multi Head Attention.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0001 - Multi Head Attention.pdf -------------------------------------------------------------------------------- /notes/0002 - (Safe) Softmax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0002 - (Safe) Softmax.pdf -------------------------------------------------------------------------------- /notes/0003 - Online Softmax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0003 - Online Softmax.pdf -------------------------------------------------------------------------------- /notes/0004 - Block Matrix Multiplication.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0004 - Block Matrix Multiplication.pdf -------------------------------------------------------------------------------- /notes/0005 - Intro to GPU & CUDA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0005 - Intro to GPU & CUDA.pdf -------------------------------------------------------------------------------- /notes/0006 - Tensor Layouts.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0006 - Tensor Layouts.pdf -------------------------------------------------------------------------------- /notes/0007 - Software Pipelining.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0007 - Software Pipelining.pdf -------------------------------------------------------------------------------- /notes/0008 - Autograd & Gradients.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/0008 - Autograd & Gradients.pdf -------------------------------------------------------------------------------- /notes/009 - Gradient of MatMul.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/009 - Gradient of MatMul.pdf -------------------------------------------------------------------------------- /notes/010 - Gradient of Softmax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/triton-flash-attention/296ee44c8a238cd2192d13e22e9082251f1c1289/notes/010 - Gradient of Softmax.pdf -------------------------------------------------------------------------------- /triton/flash_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | 6 | 7 | @triton.jit 8 | def _attn_fwd_inner( 9 | O_block, 10 | l_i, 11 | m_i, 12 | Q_block, 13 | K_block_ptr, 14 | V_block_ptr, 15 | block_index_q, 16 | softmax_scale, 17 | BLOCK_SIZE_Q: tl.constexpr, 18 | BLOCK_SIZE_KV: tl.constexpr, 19 | STAGE: tl.constexpr, 20 | offs_q: tl.constexpr, 21 | offs_kv: tl.constexpr, 22 | SEQ_LEN: tl.constexpr, 23 | ): 24 | # range of values handled by this stage 25 | if STAGE == 1: 26 | # From 0 to the left of the diagonal 27 | lo, hi = 0, block_index_q * BLOCK_SIZE_Q 28 | elif STAGE == 2: 29 | # Used only for the block in which there is transition between non-masked and masked keys 30 | lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q 31 | lo = tl.multiple_of(lo, BLOCK_SIZE_Q) 32 | else: 33 | # Only used for non-causal attention 34 | lo, hi = 0, SEQ_LEN 35 | 36 | K_block_ptr = tl.advance(K_block_ptr, (0, lo)) 37 | V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) 38 | 39 | # loop over k, v and update accumulator 40 | for start_kv in range(lo, hi, BLOCK_SIZE_KV): 41 | # Just let the compiler know that start_n is a multiple of BLOCK_N, so the compiler can do optimizations 42 | start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV) 43 | 44 | # -- compute qk ---- 45 | K_block = tl.load(K_block_ptr) 46 | QK_block = tl.dot(Q_block, K_block) 47 | 48 | if STAGE == 2: 49 | mask = offs_q[:, None] >= (start_kv + offs_kv[None, :]) 50 | QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6) 51 | m_ij = tl.maximum(m_i, tl.max(QK_block, 1)) 52 | QK_block -= m_ij[:, None] 53 | else: 54 | # Compute the maximum value of qk or keep the old max value 55 | m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale) 56 | QK_block = QK_block * softmax_scale - m_ij[:, None] 57 | 58 | # Compute the exponential of each dot product, so now we are computing exp(qk_ij - m_ij) 59 | P_block = tl.math.exp(QK_block) 60 | # Compute the sum by rows of the attention scores 61 | l_ij = tl.sum(P_block, 1) 62 | 63 | # This is the correction factor for the previous l_i 64 | alpha = tl.math.exp(m_i - m_ij) 65 | # Apply the correction factor to the previous l_i and 
add the new l_ij 66 | l_i = l_i * alpha + l_ij 67 | 68 | V_block = tl.load(V_block_ptr) 69 | P_block = P_block.to(tl.float16) 70 | # This computes the following: O_new = P x V + O_old * alpha 71 | O_block = O_block * alpha[:, None] 72 | O_block = tl.dot(P_block, V_block, O_block) 73 | 74 | m_i = m_ij 75 | 76 | # Move to the next block of K and V 77 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0)) 78 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV)) 79 | return O_block, l_i, m_i 80 | 81 | 82 | @triton.autotune( 83 | [ 84 | triton.Config( 85 | {"BLOCK_SIZE_Q": BLOCK_SIZE_Q, "BLOCK_SIZE_KV": BLOCK_SIZE_KV}, 86 | num_stages=num_stages, 87 | num_warps=num_warps, 88 | ) 89 | for BLOCK_SIZE_Q in [64, 128] 90 | for BLOCK_SIZE_KV in [32, 64] 91 | for num_stages in ([3, 4, 7]) 92 | for num_warps in [2, 4] 93 | ], 94 | key=["SEQ_LEN", "HEAD_DIM"], 95 | ) 96 | @triton.jit 97 | def _attn_fwd( 98 | Q, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 99 | K, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 100 | V, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 101 | softmax_scale, 102 | M, # BATCH_SIZE, NUM_HEADS, SEQ_LEN 103 | O, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 104 | stride_Q_batch, 105 | stride_Q_head, 106 | stride_Q_seq, 107 | stride_Q_dim, 108 | stride_K_batch, 109 | stride_K_head, 110 | stride_K_seq, 111 | stride_K_dim, 112 | stride_V_batch, 113 | stride_V_head, 114 | stride_V_seq, 115 | stride_V_dim, 116 | stride_O_batch, 117 | stride_O_head, 118 | stride_O_seq, 119 | stride_O_dim, 120 | BATCH_SIZE, 121 | NUM_HEADS: tl.constexpr, 122 | SEQ_LEN: tl.constexpr, 123 | HEAD_DIM: tl.constexpr, 124 | BLOCK_SIZE_Q: tl.constexpr, 125 | BLOCK_SIZE_KV: tl.constexpr, 126 | STAGE: tl.constexpr, 127 | ): 128 | tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM) 129 | 130 | # This indicate which block in the sequence length to process 131 | block_index_q = tl.program_id(0) 132 | 133 | # This indicates which head and batch to process. 
Each program is associated with a single head of a single batch 134 | index_batch_head = tl.program_id(1) 135 | # This indicates which batch this program is associated with (each batch has NUM_HEADS heads) 136 | index_batch = index_batch_head // NUM_HEADS 137 | # This indicates the position of the head within the batch 138 | index_head = index_batch_head % NUM_HEADS 139 | 140 | # This allows us to get the (SEQ_LEN, HEAD_DIM) block in Q, K and V by indexing them by batch and head 141 | qvk_offset = ( 142 | index_batch.to(tl.int64) * stride_Q_batch 143 | + index_head.to(tl.int64) * stride_Q_head 144 | ) 145 | 146 | Q_block_ptr = tl.make_block_ptr( 147 | base=Q + qvk_offset, 148 | shape=(SEQ_LEN, HEAD_DIM), 149 | strides=(stride_Q_seq, stride_Q_dim), 150 | offsets=(block_index_q * BLOCK_SIZE_Q, 0), 151 | block_shape=(BLOCK_SIZE_Q, HEAD_DIM), 152 | order=(1, 0), 153 | ) 154 | 155 | V_block_ptr = tl.make_block_ptr( 156 | base=V + qvk_offset, 157 | shape=(SEQ_LEN, HEAD_DIM), 158 | strides=(stride_V_seq, stride_V_dim), 159 | offsets=(0, 0), 160 | block_shape=(BLOCK_SIZE_KV, HEAD_DIM), 161 | order=(1, 0), 162 | ) 163 | 164 | K_block_ptr = tl.make_block_ptr( 165 | base=K + qvk_offset, 166 | shape=(HEAD_DIM, SEQ_LEN), 167 | strides=( 168 | stride_K_dim, 169 | stride_K_seq, 170 | ), # We invert the strides w.r.t Q, so we transpose the matrix 171 | offsets=(0, 0), 172 | block_shape=(HEAD_DIM, BLOCK_SIZE_KV), 173 | order=(0, 1), 174 | ) 175 | 176 | O_block_ptr = tl.make_block_ptr( 177 | base=O + qvk_offset, 178 | shape=(SEQ_LEN, HEAD_DIM), 179 | strides=(stride_O_seq, stride_O_dim), 180 | offsets=(block_index_q * BLOCK_SIZE_Q, 0), 181 | block_shape=(BLOCK_SIZE_Q, HEAD_DIM), 182 | order=(1, 0), 183 | ) 184 | 185 | # offs_q: the offsets for the tokens in the Q to process 186 | offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) 187 | # offs_kv: the offsets for the tokens in the K and V sequence to process 188 | offs_kv = tl.arange(0, BLOCK_SIZE_KV) 189 | 190 | # m_i: the running maximum. We have one for each query 191 | m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf") 192 | # l_i: the running sum. 
We have one for each query (as we sum the attention scores by rows) 193 | l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0 194 | # acc: the accumulator for the output, which is a group of rows of the O matrix 195 | O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32) 196 | 197 | # load the blocks of Q: it will stay in SRAM throughout 198 | Q_block = tl.load(Q_block_ptr) 199 | 200 | # Stage: 3 if causal, else 1 201 | 202 | if STAGE == 1 or STAGE == 3: 203 | # This step runs for non-causal attention or for the blocks to the left of the diagonal in the causal attention 204 | O_block, l_i, m_i = _attn_fwd_inner( 205 | O_block, 206 | l_i, 207 | m_i, 208 | Q_block, 209 | K_block_ptr, 210 | V_block_ptr, 211 | block_index_q, 212 | softmax_scale, 213 | BLOCK_SIZE_Q, 214 | BLOCK_SIZE_KV, 215 | 4 - STAGE, 216 | offs_q, 217 | offs_kv, 218 | SEQ_LEN, 219 | ) 220 | 221 | if STAGE == 3: 222 | # This step runs for the blocks to the right of the diagonal in the causal attention 223 | O_block, l_i, m_i = _attn_fwd_inner( 224 | O_block, 225 | l_i, 226 | m_i, 227 | Q_block, 228 | K_block_ptr, 229 | V_block_ptr, 230 | block_index_q, 231 | softmax_scale, 232 | BLOCK_SIZE_Q, 233 | BLOCK_SIZE_KV, 234 | 2, 235 | offs_q, 236 | offs_kv, 237 | SEQ_LEN, 238 | ) 239 | # epilogue 240 | m_i += tl.math.log( 241 | l_i 242 | ) # This is needed to compute the logsumexp for the backwards pass 243 | O_block = O_block / l_i[:, None] 244 | m_ptrs = M + index_batch_head * SEQ_LEN + offs_q 245 | tl.store(m_ptrs, m_i) 246 | tl.store(O_block_ptr, O_block.to(O.type.element_ty)) 247 | 248 | 249 | @triton.jit 250 | def _attn_bwd_preprocess( 251 | O, 252 | dO, 253 | D, 254 | SEQ_LEN, 255 | BLOCK_SIZE_Q: tl.constexpr, 256 | HEAD_DIM: tl.constexpr, 257 | ): 258 | block_index_q = tl.program_id(0) 259 | offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) 260 | index_batch_head = tl.program_id(1) 261 | offs_dim = tl.arange(0, HEAD_DIM) 262 | # Load a single block of BLOCK_SIZE_Q rows of O 263 | O_block = tl.load( 264 | O 265 | + index_batch_head * HEAD_DIM * SEQ_LEN 266 | + offs_q[:, None] * HEAD_DIM 267 | + offs_dim[None, :] 268 | ) 269 | # Load a single block of BLOCK_SIZE_Q rows of dO 270 | dO_block = tl.load( 271 | dO 272 | + index_batch_head * HEAD_DIM * SEQ_LEN 273 | + offs_q[:, None] * HEAD_DIM 274 | + offs_dim[None, :] 275 | ).to(tl.float32) 276 | # Compute the D block 277 | D_block = tl.sum(dO_block * O_block, axis=1) # Shape: (BLOCK_SIZE_Q,) 278 | # Store the D block 279 | D_block_ptrs = D + index_batch_head * SEQ_LEN + offs_q 280 | tl.store(D_block_ptrs, D_block) 281 | 282 | 283 | @triton.jit 284 | def _attn_bwd_dq( 285 | Q, 286 | K, 287 | V, 288 | softmax_scale, 289 | dO, 290 | dQ, 291 | dK, 292 | dV, 293 | M, 294 | D, 295 | stride_batch, 296 | stride_head, 297 | stride_seq, 298 | stride_dim, 299 | NUM_HEADS, 300 | SEQ_LEN, 301 | BLOCK_Q: tl.constexpr, 302 | BLOCK_KV: tl.constexpr, 303 | HEAD_DIM: tl.constexpr, 304 | STAGE: tl.constexpr, 305 | ): 306 | index_batch_head = tl.program_id(2) 307 | index_batch = index_batch_head // NUM_HEADS 308 | index_head = index_batch_head % NUM_HEADS 309 | offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to( 310 | tl.int64 311 | ) 312 | # This is the offset that allows us to select the right sequence given the batch and head. 
313 | offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64) 314 | 315 | # Make sure the pointers are in the right place w.r.t batch and head 316 | # The reason we don't access the blocks through make_block_ptr is because we need to use the range of offsets to apply the masking 317 | Q += offset_batch_head 318 | K += offset_batch_head 319 | V += offset_batch_head 320 | dO += offset_batch_head 321 | dQ += offset_batch_head 322 | dK += offset_batch_head 323 | dV += offset_batch_head 324 | 325 | # Make sure the pointers are in the right place w.r.t batch, head and sequence 326 | M += offset_batch_head_seq 327 | D += offset_batch_head_seq 328 | 329 | # load scales 330 | offs_dim = tl.arange(0, HEAD_DIM) 331 | 332 | index_block_kv = tl.program_id(0) 333 | 334 | start_q = index_block_kv * BLOCK_Q 335 | offs_q = start_q + tl.arange(0, BLOCK_Q) 336 | 337 | Q_block = tl.load(Q + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim) 338 | dQ_block = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32) 339 | dO_block = tl.load( 340 | dO + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 341 | ) 342 | 343 | M_block = tl.load(M + offs_q) 344 | M_block = M_block[:, None] 345 | 346 | offs_kv = tl.arange(0, BLOCK_KV) 347 | 348 | # We access the K and V as transposed blocks 349 | kT_ptrs = K + offs_kv[None, :] * stride_seq + offs_dim[:, None] * stride_dim 350 | vT_ptrs = V + offs_kv[None, :] * stride_seq + offs_dim[:, None] * stride_dim 351 | 352 | Di = tl.load(D + offs_q) 353 | 354 | curr_kv = 0 355 | num_steps = SEQ_LEN // BLOCK_KV 356 | for blk_idx in range(num_steps): 357 | K_T_block = tl.load(kT_ptrs) 358 | V_T_block = tl.load(vT_ptrs) 359 | QK_block = softmax_scale * tl.dot(Q_block, K_T_block) 360 | P_block = tl.math.exp(QK_block - M_block) 361 | 362 | if STAGE == 3: 363 | # Autoregressive masking. 364 | offs_kv = curr_kv + tl.arange(0, BLOCK_KV) 365 | mask_block = offs_q[:, None] >= offs_kv[None, :] 366 | P_block = tl.where(mask_block, P_block, 0.0) 367 | 368 | # Compute dP and dS. 369 | dP_block = tl.dot(dO_block, V_T_block).to(tl.float32) 370 | dS_block = P_block * (dP_block - Di[:, None]) 371 | dS_block = dS_block.to(tl.float16) 372 | # Compute dQ. 373 | # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 374 | dQ_block += softmax_scale * tl.dot(dS_block, tl.trans(K_T_block)) 375 | # Increment pointers. 376 | curr_kv += BLOCK_KV 377 | kT_ptrs += BLOCK_KV * stride_seq 378 | vT_ptrs += BLOCK_KV * stride_seq 379 | 380 | dQ_block_ptrs = dQ + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 381 | tl.store(dQ_block_ptrs, dQ_block) 382 | 383 | 384 | @triton.jit 385 | def _attn_bwd_dk_dv( 386 | Q, 387 | K, 388 | V, 389 | softmax_scale, 390 | dO, 391 | dQ, 392 | dK, 393 | dV, 394 | M, 395 | D, 396 | stride_batch, 397 | stride_head, 398 | stride_seq, 399 | stride_dim, 400 | NUM_HEADS, 401 | SEQ_LEN, 402 | BLOCK_Q: tl.constexpr, 403 | BLOCK_KV: tl.constexpr, 404 | HEAD_DIM: tl.constexpr, 405 | STAGE: tl.constexpr, 406 | ): 407 | index_batch_head = tl.program_id(2) 408 | index_batch = index_batch_head // NUM_HEADS 409 | index_head = index_batch_head % NUM_HEADS 410 | offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to( 411 | tl.int64 412 | ) 413 | # This is the offset that allows us to select the right sequence given the batch and head. 
414 | offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64) 415 | 416 | # Make sure the pointers are in the right place w.r.t batch and head 417 | # The reason we don't access the blocks through make_block_ptr is because we need to use the range of offsets to apply the masking 418 | Q += offset_batch_head 419 | K += offset_batch_head 420 | V += offset_batch_head 421 | dO += offset_batch_head 422 | dQ += offset_batch_head 423 | dK += offset_batch_head 424 | dV += offset_batch_head 425 | 426 | # Make sure the pointers are in the right place w.r.t batch, head and sequence 427 | M += offset_batch_head_seq 428 | D += offset_batch_head_seq 429 | 430 | # load scales 431 | offs_dim = tl.arange(0, HEAD_DIM) 432 | 433 | index_block_kv = tl.program_id(0) 434 | start_kv = index_block_kv * BLOCK_KV 435 | 436 | offs_kv = start_kv + tl.arange(0, BLOCK_KV) 437 | 438 | dV_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32) 439 | dK_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32) 440 | 441 | # load K and V: they stay in SRAM throughout the inner loop. 442 | K_block = tl.load( 443 | K + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 444 | ) # Shape: (BLOCK_KV1, HEAD_DIM) 445 | V_block = tl.load( 446 | V + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 447 | ) # Shape: (BLOCK_KV1, HEAD_DIM) 448 | 449 | offs_q = tl.arange(0, BLOCK_Q) 450 | 451 | # We access the Q as a transposed array, so that's why we treat offs_q as a column vector ans offs_dim as a row vector 452 | # This is equivalent to doing: 453 | # q_ptrs = Q + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 454 | # qT_ptrs = tl.trans(q_ptrs) 455 | # We point to the first BLOCK_Q rows of Q for both the qT and dO pointers, inside the for loop we will move forward by BLOCK_Q rows at each iteration. 456 | qT_ptrs = Q + offs_q[None, :] * stride_seq + offs_dim[:, None] * stride_dim 457 | dO_ptrs = dO + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 458 | 459 | # Iterates over the sequence dimension of the query 460 | curr_q = 0 461 | num_steps = SEQ_LEN // BLOCK_Q 462 | for blk_idx in range(num_steps): 463 | # Load a block of Q 464 | qT_block = tl.load(qT_ptrs) 465 | # Load the logsumexp values for the queries in the current block 466 | offs_q = curr_q + tl.arange(0, BLOCK_Q) 467 | m = tl.load(M + offs_q) 468 | 469 | # This gives us (QK^T)^T = (K^T)^T(Q^T) = K(Q^T) = P^T 470 | QK_T_block = softmax_scale * tl.dot(K_block, qT_block) 471 | # We apply the softmax by using the logsumexp trick 472 | P_T_block = tl.math.exp(QK_T_block - m[None, :]) 473 | 474 | if STAGE == 3: 475 | # Autoregressive masking. 476 | # mask is True for all values that DO NOT NEED TO BE MASKED 477 | mask_block = ( 478 | offs_q[None, :] >= offs_kv[:, None] 479 | ) # Shape: (BLOCK_KV1, BLOCK_Q1) 480 | # Replace all the masked values with 0. 
481 | # In this case we do not need to mask with -Inf before applying the softmax since we already computed the normalization factors (stored in "m") 482 | P_T_block = tl.where(mask_block, P_T_block, 0.0) 483 | 484 | dO_block = tl.load(dO_ptrs) 485 | # According to the formula: dV_new = dV_old + P^T x dO, where x is the matrix multiplication 486 | dV_block += tl.dot(P_T_block.to(tl.float16), dO_block) 487 | 488 | # Delta = rowsum(O * dO) where * is the element-wise product 489 | Di = tl.load(D + offs_q) 490 | 491 | # dP = dO x V^T, so dP^T = V x dO^T 492 | # Where x is the matrix multiplication 493 | dpT_block = tl.dot(V_block, tl.trans(dO_block)).to(tl.float32) 494 | 495 | # We know that dS = P * (dP - Delta), so dS^T = P^T * (dP^T - Delta^T) 496 | 497 | dS_T_block = P_T_block * (dpT_block - Di[None, :]) 498 | dS_T_block = dS_T_block.to(tl.float16) 499 | 500 | # According to the formula on the paper: dK_new = dK_old + dS^T x Q 501 | dK_block += softmax_scale * tl.dot(dS_T_block, tl.trans(qT_block)) 502 | # Increment pointers. 503 | curr_q += BLOCK_Q 504 | qT_ptrs += BLOCK_Q * stride_seq 505 | dO_ptrs += BLOCK_Q * stride_seq 506 | 507 | # Write back dV. 508 | dV_block_ptrs = dV + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 509 | tl.store(dV_block_ptrs, dV_block) 510 | 511 | # Write back dK. 512 | dK_block_ptrs = dK + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 513 | tl.store(dK_block_ptrs, dK_block) 514 | 515 | 516 | class TritonAttention(torch.autograd.Function): 517 | 518 | @staticmethod 519 | def forward(ctx, Q, K, V, causal, softmax_scale): 520 | HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1] 521 | HEAD_DIM_V = V.shape[-1] 522 | 523 | BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape 524 | 525 | assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V 526 | 527 | O = torch.empty_like(Q) 528 | stage = 3 if causal else 1 529 | 530 | grid = lambda args: ( 531 | triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]), 532 | BATCH_SIZE * NUM_HEADS, 533 | 1, 534 | ) 535 | 536 | # M is the logsumexp for the backward pass, one for each query 537 | M = torch.empty( 538 | (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32 539 | ) 540 | 541 | _attn_fwd[grid]( 542 | Q=Q, 543 | K=K, 544 | V=V, 545 | softmax_scale=softmax_scale, 546 | M=M, 547 | O=O, 548 | stride_Q_batch=Q.stride(0), 549 | stride_Q_head=Q.stride(1), 550 | stride_Q_seq=Q.stride(2), 551 | stride_Q_dim=Q.stride(3), 552 | stride_K_batch=K.stride(0), 553 | stride_K_head=K.stride(1), 554 | stride_K_seq=K.stride(2), 555 | stride_K_dim=K.stride(3), 556 | stride_V_batch=V.stride(0), 557 | stride_V_head=V.stride(1), 558 | stride_V_seq=V.stride(2), 559 | stride_V_dim=V.stride(3), 560 | stride_O_batch=O.stride(0), 561 | stride_O_head=O.stride(1), 562 | stride_O_seq=O.stride(2), 563 | stride_O_dim=O.stride(3), 564 | BATCH_SIZE=Q.shape[0], 565 | NUM_HEADS=Q.shape[1], 566 | SEQ_LEN=Q.shape[2], 567 | HEAD_DIM=HEAD_DIM_K, 568 | STAGE=stage, 569 | ) 570 | 571 | ctx.save_for_backward(Q, K, V, O, M) 572 | ctx.grid = grid 573 | ctx.softmax_scale = softmax_scale 574 | ctx.HEAD_DIM = HEAD_DIM_K 575 | ctx.causal = causal 576 | return O 577 | 578 | @staticmethod 579 | def backward(ctx, dO): 580 | Q, K, V, O, M = ctx.saved_tensors 581 | 582 | assert dO.is_contiguous() 583 | assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride() 584 | dQ = torch.empty_like(Q) 585 | dK = torch.empty_like(K) 586 | dV = torch.empty_like(V) 587 | 588 | BATCH_SIZE, NUM_HEADS, SEQ_LEN = Q.shape[:3] 589 | 
NUM_WARPS, NUM_STAGES = 4, 3 590 | BLOCK_SIZE_MICRO, BLOCK_SIZE_MACRO = 32, 128 591 | 592 | preprocess_grid = (SEQ_LEN // BLOCK_SIZE_MACRO, BATCH_SIZE * NUM_HEADS) 593 | D = torch.empty_like(M) # Shape: (BATCH_SIZE, NUM_HEADS, SEQ_LEN) 594 | 595 | # Compute all the elements Di 596 | _attn_bwd_preprocess[preprocess_grid]( 597 | O=O, 598 | dO=dO, 599 | D=D, 600 | SEQ_LEN=SEQ_LEN, 601 | BLOCK_SIZE_Q=BLOCK_SIZE_MACRO, 602 | HEAD_DIM=ctx.HEAD_DIM, 603 | ) 604 | 605 | grid = (SEQ_LEN // BLOCK_SIZE_MACRO, 1, BATCH_SIZE * NUM_HEADS) 606 | 607 | stage = 3 if ctx.causal else 1 608 | 609 | # Fix KV and iterate through all the Q blocks 610 | _attn_bwd_dk_dv[grid]( 611 | Q=Q, 612 | K=K, 613 | V=V, 614 | softmax_scale=ctx.softmax_scale, 615 | dO=dO, 616 | dQ=dQ, 617 | dK=dK, 618 | dV=dV, 619 | M=M, 620 | D=D, 621 | stride_batch=Q.stride(0), 622 | stride_head=Q.stride(1), 623 | stride_seq=Q.stride(2), 624 | stride_dim=Q.stride(3), 625 | NUM_HEADS=NUM_HEADS, 626 | SEQ_LEN=SEQ_LEN, 627 | BLOCK_Q=BLOCK_SIZE_MICRO, 628 | BLOCK_KV=BLOCK_SIZE_MACRO, 629 | HEAD_DIM=ctx.HEAD_DIM, 630 | STAGE=stage, 631 | num_warps=NUM_WARPS, 632 | num_stages=NUM_STAGES, 633 | ) 634 | 635 | # Fix Q and iterate through all the KV block 636 | _attn_bwd_dq[grid]( 637 | Q=Q, 638 | K=K, 639 | V=V, 640 | softmax_scale=ctx.softmax_scale, 641 | dO=dO, 642 | dQ=dQ, 643 | dK=dK, 644 | dV=dV, 645 | M=M, 646 | D=D, 647 | stride_batch=Q.stride(0), 648 | stride_head=Q.stride(1), 649 | stride_seq=Q.stride(2), 650 | stride_dim=Q.stride(3), 651 | NUM_HEADS=NUM_HEADS, 652 | SEQ_LEN=SEQ_LEN, 653 | BLOCK_Q=BLOCK_SIZE_MACRO, 654 | BLOCK_KV=BLOCK_SIZE_MICRO, 655 | HEAD_DIM=ctx.HEAD_DIM, 656 | STAGE=stage, 657 | num_warps=NUM_WARPS, 658 | num_stages=NUM_STAGES, 659 | ) 660 | 661 | return dQ, dK, dV, None, None 662 | 663 | 664 | def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, causal, dtype=torch.float16): 665 | Q = ( 666 | torch.empty( 667 | (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda" 668 | ) 669 | .normal_(mean=0.0, std=0.5) 670 | .requires_grad_() 671 | ) 672 | K = ( 673 | torch.empty( 674 | (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda" 675 | ) 676 | .normal_(mean=0.0, std=0.5) 677 | .requires_grad_() 678 | ) 679 | V = ( 680 | torch.empty( 681 | (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda" 682 | ) 683 | .normal_(mean=0.0, std=0.5) 684 | .requires_grad_() 685 | ) 686 | 687 | softmax_scale = 1 / (HEAD_DIM**0.5) 688 | dO = torch.randn_like(Q) 689 | 690 | # reference implementation 691 | MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device="cuda")) 692 | P = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale 693 | if causal: 694 | P[:, :, MASK == 0] = float("-inf") 695 | P = torch.softmax(P.float(), dim=-1).half() 696 | ref_O = torch.matmul(P, V) 697 | ref_O.backward(dO) 698 | ref_dV, V.grad = V.grad.clone(), None 699 | ref_dK, K.grad = K.grad.clone(), None 700 | ref_dQ, Q.grad = Q.grad.clone(), None 701 | 702 | # triton implementation 703 | tri_out = TritonAttention.apply(Q, K, V, causal, softmax_scale).half() 704 | tri_out.backward(dO) 705 | tri_dV, V.grad = V.grad.clone(), None 706 | tri_dK, K.grad = K.grad.clone(), None 707 | tri_dQ, Q.grad = Q.grad.clone(), None 708 | 709 | # compare 710 | rtol = 0.0 711 | atol = 1e-2 712 | assert torch.allclose(ref_O, tri_out, atol=atol, rtol=rtol) 713 | assert torch.allclose(ref_dK, tri_dK, atol=atol, rtol=rtol) 714 | assert torch.allclose(ref_dV, tri_dV, atol=atol, rtol=rtol) 715 | assert torch.allclose(ref_dQ, tri_dQ, 
atol=atol, rtol=rtol) 716 | 717 | 718 | if __name__ == "__main__": 719 | test_op(BATCH_SIZE=8, NUM_HEADS=16, SEQ_LEN=4096, HEAD_DIM=64, causal=True) 720 | test_op(BATCH_SIZE=8, NUM_HEADS=16, SEQ_LEN=4096, HEAD_DIM=64, causal=False) 721 | print("PASSED") 722 | -------------------------------------------------------------------------------- /triton/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | triton==3.0.0 --------------------------------------------------------------------------------