├── .DS_Store ├── .gitignore ├── 01_syllabus_day ├── .DS_Store ├── README.md ├── pics │ ├── gpus_go_brrr-2460456025.png │ ├── nv256n5b1-nvidia-logo-nvidia-logo-free-icon-of-vector-logo-1502684034.jpeg │ └── triton-logo.png ├── syllabus_day.pdf └── syllabus_day.tex ├── 02_GPU_architecture_basics ├── .DS_Store ├── GPU_architecture_basics.pdf ├── GPU_architecture_basics.tex ├── README.md └── pics │ ├── SM.png │ ├── architecture_comparison-2717927587.jpg │ ├── mem_hierarchy.png │ └── triton-logo.png ├── 03_cloud_GPU_setup └── 03_cloud_GPU_setup.md ├── 04_vector_addition ├── README.md ├── __init__.py ├── vector-add-performance.png └── vector_addition.py ├── 05_fused_softmax ├── README.md ├── __init__.py ├── fused_softmax.py └── softmax-performance.png ├── 06_matmul ├── README.md ├── __init__.py ├── block_wise_matmul.jpeg ├── grouped_vs_row_major_ordering_annotated.jpg ├── matmul-performance.png └── matmul.py ├── 07_dropout ├── README.md └── dropout.py ├── 08_layernorm ├── README.md ├── layer-norm-backward.png └── layernorm.py ├── 09_flash_attention ├── .DS_Store ├── Note Jan 20, 2026.pdf ├── README.md ├── attention-performance-bwd.png ├── attention-performance-fwd.png └── flash_attention.py ├── 10_CEloss_project ├── README.md └── celoss.py ├── LICENSE ├── README.md ├── __init__.py └── requirements.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /01_syllabus_day/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/01_syllabus_day/.DS_Store -------------------------------------------------------------------------------- /01_syllabus_day/README.md: -------------------------------------------------------------------------------- 1 | jic anyone wanted the original pdf 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/TUQAyCNxFe4/0.jpg)](https://www.youtube.com/watch?v=TUQAyCNxFe4) 4 | -------------------------------------------------------------------------------- /01_syllabus_day/pics/gpus_go_brrr-2460456025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/01_syllabus_day/pics/gpus_go_brrr-2460456025.png -------------------------------------------------------------------------------- /01_syllabus_day/pics/nv256n5b1-nvidia-logo-nvidia-logo-free-icon-of-vector-logo-1502684034.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/01_syllabus_day/pics/nv256n5b1-nvidia-logo-nvidia-logo-free-icon-of-vector-logo-1502684034.jpeg -------------------------------------------------------------------------------- /01_syllabus_day/pics/triton-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/01_syllabus_day/pics/triton-logo.png -------------------------------------------------------------------------------- /01_syllabus_day/syllabus_day.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/01_syllabus_day/syllabus_day.pdf -------------------------------------------------------------------------------- /01_syllabus_day/syllabus_day.tex: -------------------------------------------------------------------------------- 1 | \documentclass[aspectratio=169]{beamer} 2 | % 3 | % Choose how your presentation looks. 4 | % 5 | % For more themes, color themes and font themes, see: 6 | % http://deic.uab.es/~iblanes/beamer_gallery/index_by_theme.html 7 | % 8 | \mode 9 | { 10 | \usetheme{default} % or try Darmstadt, Madrid, Warsaw, ... 11 | \usecolortheme{default} % or try albatross, beaver, crane, ... 12 | \usefonttheme{default} % or try serif, structurebold, ... 13 | \setbeamertemplate{navigation symbols}{} 14 | \setbeamertemplate{caption}[numbered] 15 | } 16 | 17 | % Set background to black and text to white 18 | \setbeamercolor{background canvas}{bg=black} 19 | \setbeamercolor{normal text}{fg=white} 20 | \setbeamercolor{frametitle}{fg=white} 21 | \setbeamercolor{title}{fg=white} 22 | 23 | % You can continue to set other colors as needed: 24 | \setbeamercolor{item}{fg=magenta} % Color of bullets 25 | \setbeamercolor{subitem}{fg=yellow} 26 | \setbeamercolor{subsubitem}{fg=cyan} 27 | % ... 
28 | \setbeamertemplate{frametitle}[default][center] 29 | 30 | \usepackage[english]{babel} 31 | \usepackage[utf8]{inputenc} 32 | \usepackage[T1]{fontenc} 33 | \usepackage{graphicx} 34 | 35 | % Set larger font for frame title 36 | \setbeamerfont{frametitle}{size=\huge} 37 | 38 | \usepackage{emoji} 39 | 40 | \begin{document} 41 | 42 | \begin{frame}{Syllabus day topics} 43 | \begin{columns}[T] 44 | \begin{column}[T]{0.5\textwidth} 45 | \begin{enumerate} 46 | \item what is a GPU kernel? 47 | \item why Triton over CUDA? 48 | \item prerequisites 49 | \item why use this guide over others? 50 | \item things to keep in mind during these lectures 51 | \end{enumerate} 52 | \end{column} 53 | \begin{column}{0.5\textwidth} 54 | \includegraphics[height=0.8\textheight]{pics/triton-logo.png} 55 | \end{column} 56 | \end{columns} 57 | \end{frame} 58 | 59 | \begin{frame}{what is a GPU kernel?} 60 | \textbf{GPU:} the types of computer processors that we use to make AI, videogames, scientific computing, etc. go BRRRRRR \\ 61 | \textbf{GPU kernel:} the function that defines exactly how to do a desired mathematical calculation using an awareness of how GPUs are structured in order to best take advantage of that structure to facilitate going BRRRRRR 62 | \begin{center} 63 | \includegraphics[width=0.7\textwidth]{pics/gpus_go_brrr-2460456025.png} 64 | \end{center} 65 | \end{frame} 66 | 67 | \begin{frame}{why Triton over CUDA?} 68 | \begin{columns}[T] 69 | \begin{column}[T]{0.4\textwidth} 70 | \includegraphics[height=0.2\textheight]{pics/triton-logo.png} 71 | \begin{itemize} 72 | \item less popular 73 | \item open-source 74 | \item both Nvidia \& AMD 75 | \item Python 76 | \item 90\% as fast 77 | \item linux only 78 | \item less to learn 79 | \end{itemize} 80 | \end{column} 81 | \begin{column}[T]{0.4\textwidth} 82 | \includegraphics[height=0.2\textheight]{pics/nv256n5b1-nvidia-logo-nvidia-logo-free-icon-of-vector-logo-1502684034.jpeg} 83 | \begin{itemize} 84 | \item more popular 85 | \item closed-source 86 | \item Nvidia GPUs only 87 | \item C 88 | \item gold standard for speed 89 | \item linux or windows 90 | \item more to learn 91 | \end{itemize} 92 | \end{column} 93 | \end{columns} 94 | \end{frame} 95 | 96 | \begin{frame}{prerequisites} 97 | required 98 | \begin{itemize} 99 | \item Python 100 | \item basic computer hardware concepts (memory, processor, bits vs bytes, floating point operations) 101 | \item linear algebra 102 | \item calculus 103 | \item common deep learning operations (matmul, softmax, attention, etc.) 
104 | \item PyTorch 105 | \end{itemize} 106 | preferred 107 | \begin{itemize} 108 | \item some basic but not-universal-among-python-programmers concepts \textit{such as} 109 | \begin{itemize} 110 | \item big O notation 111 | \item compile-time vs run-time 112 | \end{itemize} 113 | \item data-structures \& algorithms (leetcode) 114 | \end{itemize} 115 | \end{frame} 116 | 117 | \begin{frame}{why use this guide over others?} 118 | guides I used were (links in repo readme) 119 | \begin{itemize} 120 | \item official triton documentation 121 | \item Umar Jamil's flash-attention tutorial 122 | \item GPU Mode's lecture series \#14 123 | \end{itemize} 124 | they all had some number of the following issues 125 | \begin{itemize} 126 | \item bugs (kernels straight up didn't pass tests) 127 | \item slower than PyTorch 128 | \item assumed you already know CUDA 129 | \item little to no attempt to explain what was happening 130 | \item unnecessarily confusing/overcomplicated 131 | \item primarily one format (text XOR video) 132 | \end{itemize} 133 | \end{frame} 134 | 135 | \begin{frame}{things to keep in mind} 136 | \vspace{-1.0in} 137 | \begin{itemize} 138 | \item i'm not an expert (but I can beat PyTorch) 139 | \item corrections and elaborations will go in the pinned comment on each video 140 | \item if you cannot build something you do not understand it! watching/reading is not good enough. you need to go build something sufficiently complex in order to claim comprehension 141 | \end{itemize} 142 | \end{frame} 143 | 144 | \end{document} 145 | -------------------------------------------------------------------------------- /02_GPU_architecture_basics/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/02_GPU_architecture_basics/.DS_Store -------------------------------------------------------------------------------- /02_GPU_architecture_basics/GPU_architecture_basics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/02_GPU_architecture_basics/GPU_architecture_basics.pdf -------------------------------------------------------------------------------- /02_GPU_architecture_basics/GPU_architecture_basics.tex: -------------------------------------------------------------------------------- 1 | \documentclass[aspectratio=169]{beamer} 2 | % 3 | % Choose how your presentation looks. 4 | % 5 | % For more themes, color themes and font themes, see: 6 | % http://deic.uab.es/~iblanes/beamer_gallery/index_by_theme.html 7 | % 8 | \mode 9 | { 10 | \usetheme{default} % or try Darmstadt, Madrid, Warsaw, ... 11 | \usecolortheme{default} % or try albatross, beaver, crane, ... 12 | \usefonttheme{default} % or try serif, structurebold, ... 13 | \setbeamertemplate{navigation symbols}{} 14 | \setbeamertemplate{caption}[numbered] 15 | } 16 | 17 | % Set background to black and text to white 18 | \setbeamercolor{background canvas}{bg=black} 19 | \setbeamercolor{normal text}{fg=white} 20 | \setbeamercolor{frametitle}{fg=white} 21 | \setbeamercolor{title}{fg=white} 22 | 23 | % You can continue to set other colors as needed: 24 | \setbeamercolor{item}{fg=magenta} % Color of bullets 25 | \setbeamercolor{subitem}{fg=yellow} 26 | \setbeamercolor{subsubitem}{fg=cyan} 27 | % ...
28 | \setbeamertemplate{frametitle}[default][center] 29 | 30 | \usepackage[english]{babel} 31 | \usepackage[utf8]{inputenc} 32 | \usepackage[T1]{fontenc} 33 | \usepackage{graphicx} 34 | 35 | % Set larger font for frame title 36 | \setbeamerfont{frametitle}{size=\huge} 37 | 38 | \usepackage{emoji} 39 | 40 | \begin{document} 41 | 42 | \begin{frame}{GPU architecture basics} 43 | \begin{columns}[T] 44 | \begin{column}[T]{0.5\textwidth} 45 | Triton abstracts away many low-level details of how GPUs work so that you don't have to think about them \\ 46 | \vspace{0.1in} 47 | This lesson is just a rough primer; do not feel like you need to understand the specifics of each diagram. It'll make more sense when we start coding 48 | \end{column} 49 | \begin{column}{0.5\textwidth} 50 | \includegraphics[height=0.8\textheight]{pics/triton-logo.png} 51 | \end{column} 52 | \end{columns} 53 | \end{frame} 54 | 55 | \begin{frame}{CPU vs GPU} 56 | \centering 57 | \includegraphics[width=0.8\textwidth]{pics/architecture_comparison-2717927587.jpg} 58 | \end{frame} 59 | 60 | \begin{frame}{memory hierarchy} 61 | \centering 62 | \includegraphics[width=0.8\textwidth]{pics/mem_hierarchy.png} 63 | \end{frame} 64 | 65 | \begin{frame}{one streaming multi-processor (SM) per pool of SRAM} 66 | \centering 67 | \includegraphics[width=0.8\textwidth]{pics/SM.png} 68 | \end{frame} 69 | 70 | \begin{frame}{programs} 71 | \vspace{-0.5in} 72 | a program is a specific instance of our kernel code run in parallel alongside many others. It is differentiated from other programs by the chunk of data it is assigned to work on 73 | \begin{itemize} 74 | \item each program is defined by a program ID (PID) which is a tuple full of integers 75 | \begin{itemize} 76 | \item we use this PID alongside indexing logic to figure out which chunk of a tensor the program works on 77 | \end{itemize} 78 | \item at least one program is called per SM, depending on how many can fit 79 | \begin{itemize} 80 | \item number of PIDs per SM = amount of SRAM // SRAM required per PID 81 | \end{itemize} 82 | \item if your code requires more data per single program than fits within the SM's pool of SRAM, it'll error 83 | \item If different PIDs in the same SM load the same data, then they'll share rather than loading \& storing duplicates (we will take advantage of this) 84 | \end{itemize} 85 | \end{frame} 86 | 87 | \begin{frame}{cores \& warps} 88 | \vspace{-0.5in} 89 | A core is the smallest unit of compute; it performs the actual floating point operations 90 | \begin{itemize} 91 | \item unlike CPU cores, GPU cores cannot do any arbitrary operation; they're built primarily for FLOPs 92 | \item one GPU core performs one FLOP at a time 93 | \end{itemize} 94 | A warp is the smallest grouping of GPU cores 95 | \begin{itemize} 96 | \item 32 cores per warp on Nvidia; 64 on AMD 97 | \item all cores in a warp must perform the exact same operation 98 | \item there are multiple warps per PID; you don't have to worry about how many 99 | \end{itemize} 100 | \end{frame} 101 | 102 | \end{document} -------------------------------------------------------------------------------- /02_GPU_architecture_basics/README.md: -------------------------------------------------------------------------------- 1 | jic anyone wanted the original pdf 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/JD5V8qInVOs/0.jpg)](https://www.youtube.com/watch?v=JD5V8qInVOs) 4 | -------------------------------------------------------------------------------- 
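To make the occupancy arithmetic from these slides concrete, here is a minimal sketch (not one of the repo's lesson files) of how the quantities the lecture talks about can be queried from Python. The `get_device_properties` call and its dictionary keys are the same ones used later in `05_fused_softmax/fused_softmax.py`; `sram_needed_per_program` is a hypothetical placeholder value, included only to illustrate the "number of PIDs per SM = amount of SRAM // SRAM required per PID" rule from the slides.

```python
import torch
import triton

# ask Triton's driver for the specs of the currently-active GPU
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')
properties = triton.runtime.driver.active.utils.get_device_properties(DEVICE.index)

NUM_SM = properties["multiprocessor_count"]  # how many streaming multi-processors (SMs) the GPU has
SRAM_PER_SM = properties["max_shared_mem"]   # bytes of SRAM shared by all programs on one SM
WARP_SIZE = properties["warpSize"]           # 32 on Nvidia, 64 on AMD

# hypothetical SRAM requirement for a single program; in a real kernel this is determined
# by how much data each program loads (its block sizes and dtypes)
sram_needed_per_program = 16_384  # bytes, made-up number purely for illustration

# the rule from the slide: number of PIDs per SM = amount of SRAM // SRAM required per PID
programs_per_sm = SRAM_PER_SM // sram_needed_per_program
print(f"{NUM_SM} SMs, {SRAM_PER_SM} bytes of SRAM per SM, warp size {WARP_SIZE}")
print(f"-> roughly {programs_per_sm} programs could share one SM's SRAM at {sram_needed_per_program} bytes each")
```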
/02_GPU_architecture_basics/pics/SM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/02_GPU_architecture_basics/pics/SM.png -------------------------------------------------------------------------------- /02_GPU_architecture_basics/pics/architecture_comparison-2717927587.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/02_GPU_architecture_basics/pics/architecture_comparison-2717927587.jpg -------------------------------------------------------------------------------- /02_GPU_architecture_basics/pics/mem_hierarchy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/02_GPU_architecture_basics/pics/mem_hierarchy.png -------------------------------------------------------------------------------- /02_GPU_architecture_basics/pics/triton-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/02_GPU_architecture_basics/pics/triton-logo.png -------------------------------------------------------------------------------- /03_cloud_GPU_setup/03_cloud_GPU_setup.md: -------------------------------------------------------------------------------- 1 | ## how to get up and running 2 | these instructions are for setting up a cloud GPU instance on [vast.ai](https://vast.ai); other cloud providers should be similar. If you're running linux on your own pc with your own GPU then you can skip all this 3 | 4 | see also the accompanying video: 5 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/mmRlZKFLAvE/0.jpg)](https://www.youtube.com/watch?v=mmRlZKFLAvE) 6 | 7 | 1. setup an account and input your payment information on [vast.ai](https://vast.ai), [lambdalabs](https://lambdalabs.com) or similar provider. You can compare prices [here](https://cloud-gpus.com) 8 | 2. launch a GPU instance. You can choose whichever single GPU is cheapest (I can always find one for less than $1/hr on lambda and less than $0.15/hr on vast), but I'd recommend at least 16GB of RAM to ensure that all of the tests and benchmarks in this repo run without overwhelming memory. 9 | - *Note:* I'd also recommend choosing a GPU with at least the [Ampere architecture](https://en.wikipedia.org/wiki/Ampere_(microarchitecture)) or even newer. I've found the exact same Triton code can sometimes provide incorrect final values on older GPUs; not sure if this is a bug with Triton or a limitation of older GPU hardware. All the tests in this repo were run on an 4060Ti 10 | 3. once it's running, open the included remote jupyter lab environment (Vast & Lambda provide this so i presume others do too). Optionally you could instead ssh into the instance in order to be able to use your own IDE, but I'll let chatGPT help you if you want to do that 11 | 4. Open the terminal in your instance and update everything jic 12 | ``` 13 | sudo apt update 14 | ``` 15 | ``` 16 | sudo apt install build-essential 17 | ``` 18 | 5. install github CLI 19 | ``` 20 | sudo apt install gh 21 | ``` 22 | 5. 
Input the following command and follow the prompts to log in to your GitHub account 23 | ``` 24 | gh auth login 25 | ``` 26 | 6. Clone your fork of this repository 27 | ``` 28 | gh repo clone your_github_username/triton_docs_tutorials 29 | ``` 30 | 7. setup your git user email and username 31 | ``` 32 | git config --global user.email "your_github_email@email.com" 33 | ``` 34 | ``` 35 | git config --global user.name "your_github_username" 36 | ``` 37 | 8. now you can make changes and push updates as usual all through the jupyterlab environment's terminal. Note that unless you also setup a filesystem before intializing your GPU instance that everything will be deleted when you close out the instance, so don't forget to push your changes to github! 38 | 9. install all necessary packages 39 | ``` 40 | pip install numpy matplotlib pandas torch triton pytest 41 | ``` 42 | 10. and force an update to them jic; using newer Triton versions can be very important if you're experiencing bugs 43 | ``` 44 | pip install --upgrade torch 45 | ``` 46 | ``` 47 | pip install --upgrade triton 48 | ``` 49 | 11. Once you're done with all changes pushed, make sure to logout so some random GPU provider doesn't have access to your github account 50 | ``` 51 | gh auth logout 52 | ``` 53 | 54 | *note: if you're on an AMD GPU then this whole process should likely be the same, but throughout the repo you'll have to do your own research on the relatively small edits required to make your code more specifically efficient for your hardware. those edits can be found in the [original official triton docs tutorials](https://triton-lang.org/main/getting-started/tutorials/index.html); i removed them from my version* -------------------------------------------------------------------------------- /04_vector_addition/README.md: -------------------------------------------------------------------------------- 1 | check out the accompanying video: 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/fYMS4IglLgg/0.jpg)](https://www.youtube.com/watch?v=fYMS4IglLgg) 4 | -------------------------------------------------------------------------------- /04_vector_addition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/04_vector_addition/__init__.py -------------------------------------------------------------------------------- /04_vector_addition/vector-add-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/04_vector_addition/vector-add-performance.png -------------------------------------------------------------------------------- /04_vector_addition/vector_addition.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this document we'll implement basically the simplest possible Triton GPU Kernel, which does entry-wise 3 | addition for vectors. 
4 | 5 | What you'll learn: 6 | - How to build a test to ensure your Triton kernels are numerically correct 7 | - Basics of Triton kernels (syntax, pointers, launch grids, DRAM vs SRAM, etc) 8 | - How to benchmark your Triton kernels against PyTorch 9 | 10 | Recommended order to read the code in: 11 | Step 1 - unit test 12 | Step 2 - wrapper 13 | Step 3 - kernel 14 | Step 4 - benchmark 15 | 16 | watch the accompanying YouTube video: 17 | https://youtu.be/fYMS4IglLgg 18 | see original triton documentation: 19 | https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py 20 | """ 21 | import torch 22 | import triton 23 | import triton.language as tl 24 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 25 | 26 | ######### Step 3 ######### 27 | # this `triton.jit` decorator tells Triton to compile this function into GPU code 28 | @triton.jit # only a subset of python capabilities are useable within a triton kernel 29 | def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): 30 | """ 31 | This entry-wise addition kernel is relatively simple; it's designed to only 32 | take in vectors as input and does not support any kind of broadcasting 33 | 34 | Each torch.tensor object that gets passed into a Triton kernel is implicitly 35 | converted into a pointer to its first element 36 | 37 | x_ptr: pointer to first entry of input vector of shape (n_elements) 38 | y_ptr: pointer to first entry of input vector of shape (n_elements) 39 | output_ptr: pointer to first entry of output vector of shape (n_elements) 40 | n_elements: size of our vectors 41 | BLOCK_SIZE: number of elements each kernel instance should process; should be a power of 2 42 | 43 | tl.constexpr designates BLOCK_SIZE as a compile-time variable (rather than run-time), 44 | meaning that every time a different value for BLOCK_SIZE is passed in you're actually 45 | creating an entirely separate kernel. I may sometimes refer to arguments with this 46 | designation as "meta-parameters" 47 | """ 48 | # There are multiple "programs" processing data; a program is a unique instantiation of this kernel. 49 | # Programs can be defined along multiple dimensions (defined by your launch grid in the wrapper).
50 | # this op is 1D so axis=0 is the only option, but bigger operations later may define program_id as a tuple 51 | # here we identify which program we are: 52 | pid = tl.program_id(axis=0) 53 | # Each program instance gets a unique ID along the specified axis 54 | # For example, for a vector of length 256 and BLOCK_SIZE=64: 55 | # pid=0 might process elements [0:64] 56 | # pid=1 might process elements [64:128] 57 | # pid=2 might process elements [128:192] 58 | # pid=3 might process elements [192:256] 59 | # I said "might" because the specific elements that get processed depend on the code below 60 | 61 | # here we tell the program to process inputs that are offset from the initial data (^ described above) 62 | block_start = pid * BLOCK_SIZE 63 | 64 | # offsets is an array of int32 values that we add to x_ptr/y_ptr to get a block of pointers 65 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 66 | # sticking with the above example, if pid=1 then offsets=[64, 65, 66, ...., 126, 127] 67 | 68 | # create a mask to guard memory operations against out-of-bounds accesses 69 | mask = offsets < n_elements 70 | # if we didn't do this AND n_elements were not a multiple of BLOCK_SIZE then 71 | # the kernel might read entries that are actually part of some other tensor in memory 72 | # and use them in calculations 73 | 74 | # here load x and y 75 | # from (DRAM / VRAM / global GPU memory / high-bandwidth memory) which is slow to access 76 | # onto (SRAM / on-chip memory) which is much faster but very limited in size. 77 | # We store data not currently in-use on DRAM and do calculations on data that's in SRAM 78 | x = tl.load(x_ptr + offsets, mask=mask, other=None) # shape (BLOCK_SIZE) 79 | y = tl.load(y_ptr + offsets, mask=mask, other=None) # shape (BLOCK_SIZE) 80 | # The mask ensures we don't access memory beyond the vector's end. 81 | # `other` refers to what value to put in place of any masked-out values; it defaults 82 | # to None (so we didn't have to actually write it here) but depending on the operation 83 | # it may make more sense to use a value like 0.0 (we'll see this in a later tutorial) 84 | # Whenever you see a tl.load, that is a memory operation, which is expensive, so we want to 85 | # keep track of how many memory operations we do. We count them by the total number of 86 | # entries being read/written to memory, in this case BLOCK_SIZE per kernel and 87 | # therefore n_elements in total across all running kernels for EACH of these two lines 88 | 89 | # here we perform the operation on SRAM 90 | # triton has its own internal definitions of all the basic ops that you'll need 91 | output = x + y 92 | # For the masked-out entries, None + None = None (really no operation happens at all). 93 | # Similar to keeping track of memory operations, we also keep track of floating point 94 | # operations (flops) using the shape of the blocks involved.
Here this line does BLOCK_SIZE 95 | # flops for each pid, meaning n_elements flops total 96 | 97 | 98 | # write back to DRAM, being sure to mask in order to avoid out-of-bounds accesses 99 | tl.store(output_ptr + offsets, output, mask=mask) 100 | # here is a memory write operation of size BLOCK_SIZE per pid and therefore n_elements 101 | # in aggregate across all pids combined 102 | 103 | ######### Step 2 ######### 104 | def add(x: torch.Tensor, y: torch.Tensor): 105 | ''' 106 | helper/wrapper function to 107 | 1) allocate the output tensor and 108 | 2) enqueue the above kernel with appropriate grid/block sizes 109 | 110 | This wrapper function does not connect us to pytorch's graph, meaning it does not 111 | support backpropagation. That (as well as a backward pass kernel) is for a future lesson 112 | ''' 113 | # preallocating the output 114 | output = torch.empty_like(x) 115 | 116 | # Ensures all tensors are on the same GPU device 117 | # This is crucial because Triton kernels can't automatically move data between devices 118 | assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE,\ 119 | f'DEVICE: {DEVICE}, x.device: {x.device}, y.device: {y.device}, output.device: {output.device}' 120 | 121 | # getting length of the vectors 122 | n_elements = output.numel() # .numel() returns total number of entries in tensor of any shape 123 | 124 | # grid defines the number of kernel instances that run in parallel 125 | # it can be either Tuple[int] or Callable(metaparameters) -> Tuple[int] 126 | # in this case, we use a 1D grid where the size is the number of blocks: 127 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) 128 | # so 'BLOCK_SIZE' is a parameter to be passed into meta() at compile-time, not runtime 129 | # triton.cdiv = (n_elements + (BLOCK_SIZE - 1)) // BLOCK_SIZE 130 | # then meta() returns a Tuple with the number of kernel programs we want to 131 | # instantiate at once which is a compile-time constant, meaning that if it 132 | # changes Triton will actually create an entirely new kernel for that value 133 | 134 | # `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel 135 | add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) 136 | # BLOCK_SIZE of 1024 is a heuristic choice 137 | # It's a power of 2 (efficient for memory access patterns) 138 | # It's large enough to hide memory latency 139 | # It's small enough to allow multiple blocks to run concurrently on a GPU 140 | # in a later lesson we'll learn better methods than heuristics 141 | 142 | # the kernel writes to the output in-place rather than having to return anything 143 | # once all the kernel programs have finished running then the output gets returned here 144 | return output 145 | 146 | ######### Step 1 ######### 147 | def test_add_kernel(size, atol=1e-3, rtol=1e-3, device=DEVICE): 148 | """ 149 | Here is where we test the wrapper function and kernel that we wrote 150 | above to ensure all our values are correct, using pytorch as the 151 | correct answer to compare against 152 | """ 153 | # create data 154 | torch.manual_seed(0) 155 | x = torch.rand(size, device=DEVICE) 156 | y = torch.rand(size, device=DEVICE) 157 | # run kernel & pytorch reference implementation 158 | z_tri = add(x, y) 159 | z_ref = x + y 160 | # compare 161 | torch.testing.assert_close(z_tri, z_ref, atol=atol, rtol=rtol) 162 | print("PASSED") 163 | 164 | ######### Step 4 ######### 165 | # Triton has a set of built-in utilities that make it easy for us to plot
performance of custom ops. 166 | # This decorator tells Triton that the below function is a benchmark and what benchmark conditions to run 167 | @triton.testing.perf_report( 168 | triton.testing.Benchmark( 169 | x_names=['size'], # argument names to use as an x-axis for the plot 170 | x_vals=[2**i for i in range(12, 28, 1)], # different values of x_names to benchmark 171 | x_log = True, # makes x-axis logarithmic 172 | line_arg='provider', # title of the legend 173 | line_vals=['triton', 'torch'], # designators of the different entries in the legend 174 | line_names=['Triton', 'Torch'], # names to visibly go in the legend 175 | styles=[('blue', '-'), ('green', '-')], # triton will be blue; pytorch will be green 176 | ylabel='GB/s', # label name for y-axis 177 | plot_name='vector-add-performance', # also used as file name for saving plot 178 | args={}, # we'll see how this is used in a later tutorial; need it even if it's empty 179 | ) 180 | ) 181 | def benchmark(size, provider): 182 | # creating our input data 183 | x = torch.rand(size, device=DEVICE, dtype=torch.float32) 184 | y = torch.rand(size, device=DEVICE, dtype=torch.float32) 185 | # each benchmark runs multiple times and quantiles tells matplotlib what confidence intervals to plot 186 | quantiles = [0.5, 0.05, 0.95] 187 | # defining which function this benchmark instance runs 188 | if provider == 'torch': 189 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) 190 | if provider == 'triton': 191 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) 192 | # turning the raw millisecond measurement into meaninful units 193 | gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) 194 | # 3 = number of memory operations (2 reads + 1 write) 195 | # x.numel() = number of elements 196 | # x.element_size() = bytes per element (4 for float32, 2 for float16) 197 | # 1e-9 converts bytes to GB 198 | # 1e-3 converts milliseconds to seconds 199 | return gbps(ms), gbps(max_ms), gbps(min_ms) 200 | 201 | if __name__ == "__main__": 202 | # always run unit-tests 203 | test_add_kernel(size=98432) 204 | 205 | # Only run benchmark if explicitly requested 206 | import sys 207 | if len(sys.argv) > 1 and sys.argv[1] == "--benchmark": 208 | benchmark.run(save_path='.', print_data=False) 209 | -------------------------------------------------------------------------------- /05_fused_softmax/README.md: -------------------------------------------------------------------------------- 1 | check out the accompanying video: 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/ftknUZDQCPc/0.jpg)](https://www.youtube.com/watch?v=ftknUZDQCPc) 4 | -------------------------------------------------------------------------------- /05_fused_softmax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/05_fused_softmax/__init__.py -------------------------------------------------------------------------------- /05_fused_softmax/fused_softmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | This "fused softmax" kernel only works on matrices whose rows fit in the GPU's SRAM. 
3 | 4 | What you'll learn: 5 | - The importance of reducing memory reads/writes 6 | - How to fuse multiple operations into one kernel to reduce memory reads/writes 7 | - How to fetch GPU specifications 8 | - Some parts of the GPU architecture that you don't usually have to think about 9 | when writing Triton kernels 10 | - How to define meta-parameters using GPU-specific attributes and rough heuristics 11 | - Pipeline parallelism & the weird way that for-loops work within GPU kernels 12 | - How to choose the value of extra entries when masking 13 | 14 | Recommended order to read the code in: 15 | Step 1 - naive implementation 16 | Step 2 - unit test 17 | Step 3 - wrapper 18 | Step 4 - kernel 19 | Step 5 - benchmark 20 | 21 | watch the accompanying YouTube video: 22 | https://youtu.be/ftknUZDQCPc 23 | see original triton documentation: 24 | https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html 25 | """ 26 | import torch 27 | import triton 28 | import triton.language as tl 29 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 30 | 31 | ######### Step 1 ######### 32 | # first we'll look at the naive implementation jic you need a refresher 33 | def naive_softmax(x): 34 | ''' 35 | Built for input of size (M,N) 36 | Safe softmax is when we subtract the maximum element in order to avoid numerical 37 | overflows when doing .exp(); softmax is invariant to this shift 38 | ''' 39 | # read MN elements, find their max along N, and write M elements (the maxes) 40 | x_max = x.max(dim=1)[0] 41 | # pytorch actually outputs a tuple of (values, indices) so [0] grabs the values; 42 | # we ignored the indices when talking about memory writes above 43 | # read MN + M elements, subtraction is MN flops, and write MN elements 44 | z = x - x_max[:, None] 45 | # read MN elements and write MN elemnts 46 | numerator = torch.exp(z) 47 | # exp is actually a lot of flops per element but we're only worried about mem ops rn 48 | # read MN elements, do MN flops to find M sum values, and then write M elements 49 | denominator = numerator.sum(dim=1) 50 | # read MN + M elements, division is MN flops, then write MN elements 51 | out = numerator / denominator[:, None] 52 | 53 | # in total we did 8MN + 4M memory operations 54 | # (read 5MN + 2M elements; wrote 3MN + 2M elements) 55 | return out 56 | """ 57 | that's a whole lot of memory operations. we'd prefer to have a custom "fused" kernel that only 58 | reads x from DRAM once and does all the necessary computations on SRAM as opposed to repeatedly 59 | reading & writing to DRAM. 
that would give a ~4x speedup since 60 | (8MN + 4M)/2MN = 4 (ignoring the solo M term a la big O notation) 61 | 62 | torch.jit.script flag and torch.compile actually aim to do this fusion automatically but can't 63 | pull it off quite as well as we're about to 64 | 65 | our fused softmax kernel will work as follows: 66 | each program (individual call of the kernel) loads a set of rows of the input matrix X which are 67 | strided by number of programs, softmaxes it and writes back the result to the output Y 68 | 69 | note an important limitation of Triton is that each block must have a power-of-two number of 70 | elements, so we need to internally "pad" each row and guard the memory operations properly 71 | """ 72 | 73 | ######### Step 4 ######### 74 | @triton.jit 75 | def _softmax_kernel( 76 | input_ptr, output_ptr, 77 | input_row_stride, output_row_stride, # number of elements to skip when moving to next row 78 | n_rows, n_cols, # matrix dimensions 79 | BLOCK_SIZE: tl.constexpr, # lowest power-of-2 >= n_cols 80 | num_stages: tl.constexpr, 81 | ): 82 | # the row that this program starts with is defined by the pid 83 | row_start = tl.program_id(0) 84 | # then this gets the total number of parallel programs, which we'll use to know how large 85 | # of a step to make in our for loop once we finish the first row 86 | row_step = tl.num_programs(0) 87 | # Each program processes rows strided by row_step 88 | # (ex. if there are 4 programs, program 0 handles rows 0,4,8...) 89 | 90 | # whereas tl.arange() provides an array of values, tl.range() acts as an iterator 91 | for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): 92 | # rather than actually implement each iteration of the for loop sequentially, triton can use 93 | # num_stages to work on different iterations of the for loop simultaneously. Of course 94 | # only do this when the iterations don't depend on each other 95 | 96 | # the stride represents how much we need to increase the pointer to advance 1 row 97 | row_start_ptr = input_ptr + row_idx * input_row_stride 98 | # intuitively input_row_stride should just be n_cols as long as the input tensor is contiguous. 99 | # but what if a non-contiguous view of a manipulated tensor were passed in?
then 100 | # input_row_stride matters 101 | 102 | # load the row into SRAM, using a mask since BLOCK_SIZE is greater than n_cols if n_cols is not a power of 2 103 | col_offsets = tl.arange(0, BLOCK_SIZE) # we can fit each row in a single block 104 | input_ptrs = row_start_ptr + col_offsets 105 | mask = col_offsets < n_cols 106 | row = tl.load(input_ptrs, mask=mask, other=float('-inf')) 107 | # we fill in masked out indices with -inf since that's the value that won't influence softmax 108 | 109 | # subtract maximum for numerical stability 110 | row_minus_max = row - tl.max(row, axis=0) 111 | # all the invalid -inf values remain -inf when we subtract the max 112 | # note that exponentiation in Triton is fast but approximate; later we'll learn an even faster alternative 113 | numerator = tl.exp(row_minus_max) 114 | # all the -inf values get set to 0 since exp(-inf)=0 115 | denominator = tl.sum(numerator, axis=0) 116 | # all the invalid 0 values do get summed but don't matter since they're 0 117 | softmax_output = numerator / denominator 118 | # all the invalid 0's are 0/sum and therefore remain 0 119 | 120 | # write output back to DRAM 121 | output_row_start_ptr = output_ptr + row_idx * output_row_stride 122 | tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask) 123 | # using our mask we only store back the valid n_cols values 124 | 125 | ######### Step 3 ######### 126 | """ 127 | before we create the wrapper function that enqueues the kernel and its meta-parameters, we're going to 128 | fetch the specifications of our GPU to help later when defining our meta-parameters such that they're 129 | especially well suited (fast) to the specific GPU we're using 130 | """ 131 | # fetching a dictionary full of the GPU's specifications 132 | properties = triton.runtime.driver.active.utils.get_device_properties(DEVICE.index) 133 | # each Streaming Multi-processor (SM) is like a mini-processor that can run multiple programs 134 | NUM_SM = properties["multiprocessor_count"] 135 | # registers are the fastest memory on the GPU 136 | NUM_REGS = properties["max_num_regs"] 137 | # each SM has a limited number of registers; 138 | # programs share these registers, so using too many per program limits parallelism 139 | # each SM has a dedicated pool of SRAM that it can access 140 | # since there can be multiple programs per SM, those programs share the same SRAM 141 | # ^that will be very useful information later in the matmul tutorial 142 | TOTAL_SRAM_PER_SM = properties["max_shared_mem"] 143 | # a warp is a group of threads that execute together 144 | # a thread can be thought of as analogous to a single CPU core, but far more limited in the operations it can do 145 | WARP_SIZE = properties["warpSize"] # usually 32 on nvidia GPUs and 64 on AMD 146 | 147 | def softmax(x): 148 | ''' 149 | helper/wrapper function to 150 | 1) allocate the output tensor and 151 | 2) enqueue the above kernel with appropriate grid/block sizes 152 | 153 | This wrapper function does not connect us to pytorch's graph, meaning it does not 154 | support backpropagation.
That (as well as a backward pass kernel) is for a future lesson 155 | ''' 156 | # this kernel is only built to support matrices; expanding that support is simple but for a later lesson 157 | assert x.ndim == 2 158 | n_rows, n_cols = x.shape 159 | 160 | # the block size is the smallest power of 2 greater than or equal to the number of columns in x 161 | BLOCK_SIZE = triton.next_power_of_2(n_cols) 162 | 163 | # a trick we can use is to ask the compiler to use more threads per row by 164 | # increasing the number of warps (`num_warps`) over which each row is distributed. 165 | # for now these settings are just a heuristic 166 | # you will see in the next tutorial how to auto-tune this value in a more natural way 167 | # so you don't have to come up with manual heuristics yourself 168 | num_warps = 4 169 | if BLOCK_SIZE >= 2048: 170 | num_warps = 8 171 | if BLOCK_SIZE >= 4096: 172 | num_warps = 16 173 | 174 | # Rather than executing all code within a kernel sequentially, the GPU can actually do multiple things at once. 175 | # This is called the number of software pipelining stages. 176 | # For example, with 2 stages we can have one do the operation while the other is loading the next operands 177 | # from DRAM into SRAM. With 3 we can have one do current operations, one load next operands, and one saving 178 | # previous operands. 179 | # Triton just needs the number of stages and it'll handle how to use them efficiently. 180 | # Here we use a simple heuristic of "if we've got a lot of memory, use 4. otherwise use 2" 181 | num_stages = 4 if TOTAL_SRAM_PER_SM > 200_000 else 2 182 | 183 | # allocate output 184 | y = torch.empty_like(x) 185 | 186 | # .warmup() pre-compiles kernel and tells us how many registers and how much shared memory it needs 187 | kernel = _softmax_kernel.warmup(x, y, # this warmup depends on the attributes of the input and output 188 | x.stride(0), y.stride(0), # see below 189 | n_rows, n_cols, 190 | BLOCK_SIZE=BLOCK_SIZE, 191 | num_stages=num_stages, 192 | num_warps=num_warps, 193 | grid=(1,)) 194 | # x.stride() for each dimension tells us how many entries in memory a pointer needs to move forward in order 195 | # to get to the next element of the tensor along the specified dimension. 196 | # For any tensor x that is "contiguous", meaning ~cleanly/simply~ defined in memory and for a shape (M, N, K) 197 | # you can expect x.stride(0) == N*K, x.stride(1) == K, and x.stride(2) == 1, or more generally 198 | # x.stride(i) == math.prod(x.shape[i+1:]) 199 | # A tensor might be non-contiguous if, for example, it's been saved to memory using torch.view() or some similar 200 | # operation that leaves the original data in place but messes with dimensions 201 | 202 | # here's the info that warmup process gave us 203 | kernel._init_handles() 204 | n_regs = kernel.n_regs 205 | sram_needed_per_program = kernel.metadata.shared 206 | 207 | # and here's how we use that info to setup our kernel 208 | # register-based occupancy 209 | reg_occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) 210 | # each SM has NUM_REGS registers (eg 65536) 211 | # each program uses 212 | # n_regs registers per thread (eg 32) 213 | # WARP_SIZE threads per warp (32 on Nvidia, 64 on AMD) 214 | # num_warps warps per program (4, 8, or 16 in our case with the aforementioned heuristic) 215 | # so each program needs n_regs * WARP_SIZE * num_warps registers total 216 | # therefore we can fit reg_occupancy programs per SM 217 | # ex.
65536 // (32 * 32 * 8) = 8 programs per SM (assuming num_warps=8) 218 | # shared memory-based occupancy 219 | sram_occupancy = TOTAL_SRAM_PER_SM // sram_needed_per_program 220 | # determines how many programs can run per SM based on register usage and shared memory usage 221 | programs_per_sm = min(reg_occupancy, sram_occupancy) 222 | # the former is the optimal allocation assuming we have more than enough SRAM 223 | # the latter is our limit on SRAM when splitting it equally among all SMs 224 | # then given our number of SMs, we calculate how many programs to run in total 225 | num_programs = min(NUM_SM * programs_per_sm, n_rows) 226 | # ofc we have another limit since we've got no need to surpass the n_rows in the matrix 227 | 228 | # grid configuration; each row gets its own program 229 | grid = (num_programs, 1, 1) 230 | # the extra 1's are usually not necessary if they're not being used 231 | # we use them here because the .warmup() we used earlier has a weird quirk in the way 232 | # it's implemented that forces only 3D launch grids to be inputted once it's been used 233 | # in future lessons we don't use .warmup() so we'll not be required to do this again 234 | 235 | # And now we get to run the kernel with our heuristics-based launch grid 236 | kernel[grid]( 237 | x, y, 238 | x.stride(0), y.stride(0), 239 | n_rows, n_cols, 240 | ) 241 | return y 242 | 243 | ######### Step 2 ######### 244 | def test_softmax_kernel(size: tuple, atol=1e-3, rtol=1e-3, device=DEVICE): 245 | """ 246 | Here is where we test the wrapper function and kernel that we wrote 247 | above to ensure all our values are correct, using pytorch as the 248 | correct answer to compare against 249 | 250 | we'll use an irregular number of rows & cols to verify that our padding mechanism works 251 | """ 252 | # create input data 253 | torch.manual_seed(0) 254 | assert type(size) is tuple and len(size) == 2 255 | x = torch.randn(size[0], size[1], device=DEVICE) 256 | # run kernel & pytorch reference implementation 257 | z_tri = softmax(x) 258 | z_ref = torch.softmax(x, axis=1) 259 | # notice our implementation doesn't give a choice for what axis to softmax along. 
260 | # this is a common theme of custom GPU kernels; because pytorch has to write code that 261 | # is more general, it is slower than it could be 262 | # compare 263 | torch.testing.assert_close(z_tri, z_ref, atol=atol, rtol=rtol) 264 | print("PASSED") 265 | 266 | ######### Step 5 ######### 267 | @triton.testing.perf_report( 268 | triton.testing.Benchmark( 269 | x_names=['N'], 270 | x_vals=[128 * i for i in range(2, 100)], 271 | line_arg='provider', 272 | line_vals=['triton', 'torch'], 273 | line_names=["Triton", "Torch"], 274 | styles=[('blue', '-'), ('green', '-')], 275 | ylabel="GB/s", 276 | plot_name="softmax-performance", 277 | args={'M': 4096} # values for function arguments not in x_names 278 | )) 279 | def benchmark(M, N, provider): 280 | # making the input data 281 | x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) 282 | 283 | # these two lines ensure more accurate benchmarks; i usually forget to use them but it's not a big deal 284 | stream = getattr(torch, DEVICE.type).Stream() 285 | getattr(torch, DEVICE.type).set_stream(stream) 286 | 287 | if provider == 'torch': 288 | ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) 289 | if provider == 'triton': 290 | ms = triton.testing.do_bench(lambda: softmax(x)) 291 | gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) 292 | # 2 = number of memory operations (1 read + 1 write) 293 | # x.numel() = number of elements 294 | # x.element_size() = bytes per element (4 for float32) 295 | # 1e-9 converts bytes to GB 296 | # 1e-3 converts milliseconds to seconds 297 | return gbps(ms) 298 | 299 | if __name__ == "__main__": 300 | # always run unit-tests 301 | test_softmax_kernel(size=(1823, 781)) 302 | 303 | # Only run benchmark if explicitly requested 304 | import sys 305 | if len(sys.argv) > 1 and sys.argv[1] == "--benchmark": 306 | benchmark.run(save_path='.', print_data=False) -------------------------------------------------------------------------------- /05_fused_softmax/softmax-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/05_fused_softmax/softmax-performance.png -------------------------------------------------------------------------------- /06_matmul/README.md: -------------------------------------------------------------------------------- 1 | check out the accompanying video: 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/XAr_iVE8uUk/0.jpg)](https://www.youtube.com/watch?v=XAr_iVE8uUk) 4 | -------------------------------------------------------------------------------- /06_matmul/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/06_matmul/__init__.py -------------------------------------------------------------------------------- /06_matmul/block_wise_matmul.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/06_matmul/block_wise_matmul.jpeg -------------------------------------------------------------------------------- /06_matmul/grouped_vs_row_major_ordering_annotated.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/06_matmul/grouped_vs_row_major_ordering_annotated.jpg -------------------------------------------------------------------------------- /06_matmul/matmul-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/06_matmul/matmul-performance.png -------------------------------------------------------------------------------- /06_matmul/matmul.py: -------------------------------------------------------------------------------- 1 | """ 2 | This matmul kernel can be a bit confusing but is very crucial to understand 3 | 4 | What you'll learn: 5 | - Automatic performance tuning 6 | - Program re-ordering for improved SRAM hit rate 7 | - Multi-dimensional pointer arithmetic 8 | - High precision data type accumulation 9 | - using the Triton interpreter (kind of) 10 | 11 | Recommended order to read the code in: 12 | Step 1 - unit test 13 | Step 2 - wrapper 14 | Step 3 - kernel 15 | Step 4 - benchmark 16 | 17 | For matmul of A @ B = C of shapes (M, K) @ (K, N) = (M, N), the following 18 | algorithm is numerically equivalent to what our code will output, but we'll 19 | get to the answer in a different way 20 | for m in range(0, M, BLOCK_SIZE_M): # do in parallel 21 | for n in range(0, N, BLOCK_SIZE_N): # do in parallel 22 | acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) 23 | for k in range(0, K, BLOCK_SIZE_K): 24 | a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] 25 | b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] 26 | acc += dot(a,b) 27 | C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc 28 | 29 | see original 30 | https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html 31 | """ 32 | import torch 33 | import triton 34 | import triton.language as tl 35 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 36 | 37 | ######### Step 3 ######### 38 | 39 | # un-comment this to run a numpy emulation of Triton on CPU & be able to debug with print() statements 40 | #import os 41 | #os.environ["TRITON_INTERPRET"] = "1" 42 | 43 | # autotuning is just setting up a bunch of different potential meta-parameter configurations that Triton will automatically 44 | # choose from later based on which one performs best on our specific GPU. Triton will figure out for us which one to use. They're 45 | # all values chosen heuristically, but notice everything is a multiple of 32 in sticking w/ the number of threads in a warp.
46 | autotune_configs = [ 47 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE': 8}, num_stages=3, num_warps=8), 48 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=4, num_warps=4), 49 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=4, num_warps=4), 50 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=4, num_warps=4), 51 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=4, num_warps=4), 52 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=4, num_warps=4), 53 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=5, num_warps=2), 54 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE': 8}, num_stages=5, num_warps=2) 55 | ] 56 | # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator which consumes 57 | # 1) a list of `triton.Config` objects that define different configs of meta-parameters and compilation options 58 | # 2) an auto-tuning *key* whose change in values will trigger a new evaluation of all the provided configs, meaning 59 | # that any time either M, N, or K changes with a new input, Triton will check which config is best all over again 60 | @triton.autotune(configs = autotune_configs, key=['M', 'N', 'K']) 61 | @triton.jit 62 | def _matmul_kernel( 63 | a_ptr, b_ptr, c_ptr, 64 | M, N, K, 65 | stride_a_M, stride_a_K, 66 | stride_b_K, stride_b_N, 67 | stride_c_M, stride_c_N, 68 | # meta-parameters 69 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 70 | GROUP_SIZE: tl.constexpr, 71 | ): 72 | """ 73 | First we need to map each program id to the block of C it will compute. 74 | Let's look at an example where 75 | M = N = K = 8, 76 | BLOCK_SIZE_M = BLOCK_SIZE_K = BLOCK_SIZE_N = 2 77 | A naive implementation might do something like 78 | [0, 1, 2, 3] 79 | [4, 5, 6, 7] 80 | [8, 9, 10, 11] 81 | [12, 13, 14, 15] 82 | where each of those PIDs corresponds to a 2x2 block of C (which is of size 8x8). 83 | What parts of A and B do we need in order to compute one of them, say PID=0? 84 | Let's look at how matmul works, where x's will denote the blocks in A and B that we need to use to create 85 | the block of C corresponding to PID=0 86 | A @ B = C 87 | [x, x, x, x] [x, _, _, _] [0, _, _, _] 88 | [_, _, _, _] [x, _, _, _] [_, _, _, _] 89 | [_, _, _, _] [x, _, _, _] [_, _, _, _] 90 | [_, _, _, _] [x, _, _, _] [_, _, _, _] 91 | So in order to create the 2x2 block of C that corresponds to PID=0, we need four 2x2 blocks from the top row 92 | of A and four 2x2 blocks from the first column of B. 93 | Note that rather than loading all 8 blocks into SRAM at the same time, we can iterate over the columns/rows 94 | of A/B (respectively), doing a kind of mini matmul between two corresponding blocks and adding them together as we go. 95 | A @ B 96 | [--------->] [ | , _, _, _] 97 | [_, _, _, _] [ | , _, _, _] 98 | [_, _, _, _] [ | , _, _, _] 99 | [_, _, _, _] [\|/, _, _, _] 100 | If this fact isn't intuitive, check out `./block_wise_matmul.jpeg` 101 | Great, if we were to implement this algorithm as described it'd work. 102 | However, it would not be nearly as fast as PyTorch's method, which implements something far more clever. 
103 | To see why, we need to think about our SRAM usage. 104 | Notice that PIDs 0 through 3 all utilize the same row of blocks of A, and remember that we can have 105 | multiple programs per SM all sharing the same pool of SRAM. 106 | That means that rather than each of them loading that same row of blocks of A separately, leading to a bunch of 107 | duplicates, once one PID loads a block of A along that row then the other 3 could just re-use it! 108 | 109 | Luckily we do not have to ~explicitly~ tell Triton to do this; every time a PID runs tl.load() it'll first automatically 110 | check to see if that data already exists in SRAM thanks to some other PID sharing the same SM that got to it first. 111 | While we don't have to explicitly tell Triton to do this, we should think very carefully about helping Triton take 112 | best advantage of this ability by manipulating the ordering of the PIDs. 113 | Importantly, PIDs get assigned to SMs IN ORDER!! 114 | To re-state, the order of your PIDs determines which blocks of C get to share SRAM!! 115 | Let's look again at PIDs 0 through 3, specifically at which blocks of A and B they need to load: 116 | PID = 0 117 | [x, x, x, x] [x, _, _, _] 118 | [_, _, _, _] [x, _, _, _] 119 | [_, _, _, _] [x, _, _, _] 120 | [_, _, _, _] [x, _, _, _] 121 | PID = 1 122 | [x, x, x, x] [_, x, _, _] 123 | [_, _, _, _] [_, x, _, _] 124 | [_, _, _, _] [_, x, _, _] 125 | [_, _, _, _] [_, x, _, _] 126 | PID = 2 127 | [x, x, x, x] [_, _, x, _] 128 | [_, _, _, _] [_, _, x, _] 129 | [_, _, _, _] [_, _, x, _] 130 | [_, _, _, _] [_, _, x, _] 131 | PID = 3 132 | [x, x, x, x] [_, _, _, x] 133 | [_, _, _, _] [_, _, _, x] 134 | [_, _, _, _] [_, _, _, x] 135 | [_, _, _, _] [_, _, _, x] 136 | Notice that although they can all share the first row of blocks of A and therefore avoid loading the other three 137 | rows of blocks, they actually end up loading every single column of blocks of B. 138 | Can we do better? 139 | Can we get the same number of PIDs (and therefore the same number of blocks of C) to be constructed using fewer 140 | total blocks of A and B? 141 | Currently, with this method that we'll call "row-major ordering", we're loading: 142 | (1 row of blocks of A) + (4 columns of blocks of B) = 5 total rows/cols of blocks loaded to SRAM 143 | 144 | Well what if instead of putting PIDs 0 through 3 onto the same SM, we could put PIDs 0, 1, 4, and 5 on the same SM? 145 | Taking a look at what PIDs 4 and 5 need to load: 146 | PID = 4 147 | [_, _, _, _] [x, _, _, _] 148 | [x, x, x, x] [x, _, _, _] 149 | [_, _, _, _] [x, _, _, _] 150 | [_, _, _, _] [x, _, _, _] 151 | PID = 5 152 | [_, _, _, _] [_, x, _, _] 153 | [x, x, x, x] [_, x, _, _] 154 | [_, _, _, _] [_, x, _, _] 155 | [_, _, _, _] [_, x, _, _] 156 | Now suddenly with this hypothetical new setup, we would only need to load 157 | (2 rows of blocks of A) + (2 columns of blocks of B) = 4 total rows/cols of blocks loaded to SRAM 158 | And yet we're still getting the same number of blocks of C as output! 159 | This strategy is called "group-major ordering". 160 | The effect doesn't seem too huge with this tiny example, but as the number of blocks increases it becomes increasingly 161 | effective at saving us from having to do so many duplicate loads of blocks of A and B onto different SMs. 
162 | 163 | However, remember that Triton loads blocks into SMs based on the order of PIDs, meaning that even though we'd love it 164 | if PIDs 0, 1, 4, and 5 all shared the same SRAM, in reality PIDs 3 and 4 are likely going to get in the way of 165 | that happening. 166 | So how do we ensure the blocks of C corresponding to PIDs 4 and 5 get loaded onto the same SM as 0 and 1? 167 | We'll actually have to re-order our PIDs, meaning re-assign them to different blocks of C. 168 | Remember our input launch grid is 1-dimensional (like all previous launch grids we've seen), meaning it 169 | was defined by a tuple with only one entry. 170 | It's our job once inside the kernel to take that 1D list of PIDs and morph them into the shape we desire. 171 | I'll reiterate, the key thing to note here is that PIDs that are numerically closer together are more likely to 172 | end up on the same SM, meaning that even though we said earlier it'd be great if 0, 1, 4, and 5 all 173 | shared SRAM, in reality according to our earlier "naive" launch grid, 0, 1, 2, and 3 are going to be grouped together. 174 | So what we need to do instead is move 2 and 3 such that they correspond to the blocks of C that we previously had 175 | assigned to 4 and 5. 176 | Instead of explaining, check out this new visual ordering: 177 | [0, 2, 4, 6] 178 | [1, 3, 5, 7] 179 | [8, 10, 12, 14] 180 | [9, 11, 13, 15] 181 | Now, 0 through 3 correspond to group-major ordering! Notice in this example we can visualize it as splitting our 182 | PIDs into "groups" demarcated by the dashed lines 183 | [0, 2, | 4, 6] 184 | [1, 3, | 5, 7] 185 | --------|-------- 186 | [8, 10, | 12, 14] 187 | [9, 11, | 13, 15] 188 | The size of these groups is defined by our "GROUP_SIZE" meta-parameter. 189 | To get this re-ordering of our PIDs we'll need to do some technically simple but surprisingly difficult to keep 190 | track of math. 191 | """ 192 | # we start with a 1D launch grid that we will turn into a 2D grid with a complicated "group-wise" ordering 193 | PID = tl.program_id(axis=0) 194 | # defining the size of groups 195 | num_PID_along_M = tl.cdiv(M, BLOCK_SIZE_M) # the number of blocks along M dimension 196 | num_PID_along_N = tl.cdiv(N, BLOCK_SIZE_N) # the number of blocks along N dimension 197 | num_PID_in_group = GROUP_SIZE * num_PID_along_N 198 | # figuring out which group this PID is in 199 | group_id = PID // num_PID_in_group 200 | # tells us which row to start at for this group 201 | first_PID_in_group_along_M = group_id * GROUP_SIZE 202 | # this is usually equal to GROUP_SIZE; the alternative case happens when we're at the edge of the tensor and 203 | # its dimensions don't cleanly divide into GROUP_SIZE # TODO is this true? 204 | group_size_adj = min(num_PID_along_M - first_PID_in_group_along_M, GROUP_SIZE) 205 | # this is the bulk of the actual mapping of PIDs to group-major ordering 206 | PID_M = first_PID_in_group_along_M + ((PID % num_PID_in_group) % group_size_adj) 207 | # (PID % num_PID_in_group) puts the current program id into the context of a group 208 | # (first_PID_in_group_along_M + ...) shifts the PID into the correct group 209 | # (... % group_size_adj) removes the column component to get us onto the correct row 210 | PID_N = (PID % num_PID_in_group) // group_size_adj 211 | # (... // group_size_adj) removes the row component to get us onto the correct column 212 | 213 | # Now that the PID nightmare is done we can move onto the kernel code you're more used to seeing. 
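    # A tiny pure-Python sketch of the same mapping (illustration only, not part of the kernel), assuming
    # num_PID_along_M = num_PID_along_N = 4 and GROUP_SIZE = 2 to match the picture in the docstring above:
    #   for PID in range(16):
    #       num_PID_in_group = 2 * 4                                  # GROUP_SIZE * num_PID_along_N
    #       group_id = PID // num_PID_in_group
    #       first_PID_in_group_along_M = group_id * 2                 # group_id * GROUP_SIZE
    #       group_size_adj = min(4 - first_PID_in_group_along_M, 2)
    #       PID_M = first_PID_in_group_along_M + ((PID % num_PID_in_group) % group_size_adj)
    #       PID_N = (PID % num_PID_in_group) // group_size_adj
    #       print(PID, "->", (PID_M, PID_N))
    # running this prints 0->(0,0), 1->(1,0), 2->(0,1), 3->(1,1), 4->(0,2), ... which is exactly the
    # group-major ordering drawn in the docstring above.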
214 | 215 | # Let's create pointer vectors for the first group of blocks of the input matrices 216 | offsets_M = PID_M * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 217 | offsets_N = PID_N * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 218 | offsets_K = tl.arange(0, BLOCK_SIZE_K) 219 | # in previous lessons the blocks we loaded into SRAM were vectors; here they are matrices 220 | a_offsets = offsets_M[:, None] * stride_a_M + offsets_K[None, :] * stride_a_K 221 | b_offsets = offsets_K[:, None] * stride_b_K + offsets_N[None, :] * stride_b_N 222 | """ 223 | [:, None] turns [m1,m2,m3] into [[m1],[m2],[m3]] 224 | [None, :] turns [n1,n2,n3] into [[n1,n2,n3]] 225 | combining them gives the matrix 226 | [[m1n1, m1n2, m1n3], 227 | [m2n1, m2n2, m2n3], 228 | [m3n1, m3n2, m3n3]] 229 | """ 230 | 231 | # input tensors are fp16 but we accumulate into a block of fp32 values for higher accuracy (we'll revert later) 232 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # the full C is shape (M, N) 233 | # for a demonstration of why accumulation works, check out `./block_wise_matmul.jpeg` 234 | 235 | # we'll iterate along the K dimension of both A and B to compute a single block of the C matrix 236 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 237 | # out-of-bounds entries (along K) need to be masked out 238 | mask = offsets_K < K - k * BLOCK_SIZE_K 239 | # k * BLOCK_SIZE_K is the current starting index of offsets_K. 240 | # so this only really activates when k is within BLOCK_SIZE_K entries from K 241 | # meaning this gets triggered on the last iteration of the loop, and only if K is not a multiple of BLOCK_SIZE_K 242 | 243 | # Now we load blocks of A and B matrices. If multiple blocks in a group are on the same SM, 244 | # they can share these loaded values, which reduces the number of expensive loads from DRAM 245 | a = tl.load(a_ptr + a_offsets, mask=mask[None, :], other=0.0) # shape (BLOCK_SIZE_M, BLOCK_SIZE_K) 246 | b = tl.load(b_ptr + b_offsets, mask=mask[:, None], other=0.0) # shape (BLOCK_SIZE_K, BLOCK_SIZE_N) 247 | # fill in any masked-out parts with 0.0's since they don't have any effect on the summation in the next step 248 | 249 | # we accumulate along the K dimension 250 | accumulator = tl.dot(a, b, acc=accumulator) 251 | # triton is weird with operation notation; this is actually a tiny matmul not a dot product 252 | # shape (BLOCK_SIZE_M, BLOCK_SIZE_K) @ (BLOCK_SIZE_K, BLOCK_SIZE_N) = (BLOCK_SIZE_M, BLOCK_SIZE_N) 253 | # `acc` tells Triton to write the output of the matmul directly to accumulator, which is more efficient than 254 | # accumulator += tl.dot(a, b) 255 | 256 | # advance the pointers to the next block along K 257 | a_offsets += BLOCK_SIZE_K * stride_a_K 258 | b_offsets += BLOCK_SIZE_K * stride_b_K 259 | """ 260 | A visual representation of the accumulation movement for PID=0 261 | A @ B 262 | [--------->] [ | , _, _, _] 263 | [_, _, _, _] [ | , _, _, _] 264 | [_, _, _, _] [ | , _, _, _] 265 | [_, _, _, _] [\|/, _, _, _] 266 | """ 267 | 268 | # and now we reset the data type to the expected output 269 | accumulator = accumulator.to(tl.float16) 270 | 271 | # write back the block of the output matrix C with masks 272 | c_offsets = stride_c_M * offsets_M[:, None] + stride_c_N * offsets_N[None, :] 273 | c_mask = (offsets_M[:, None] < M) & (offsets_N[None, :] < N) # notice the 2D mask 274 | tl.store(c_ptr + c_offsets, accumulator, mask=c_mask) # shape (BLOCK_SIZE_M, BLOCK_SIZE_N) 275 | 276 | 277 | ######### Step 2 ######### 278 | def matmul(a, b): 279 | # check
constraints 280 | assert a.ndim == b.ndim == 2, "only supports matrices, not vectors or tensors" 281 | assert a.shape[1] == b.shape[0], "incompatible dimensions" 282 | #assert a.is_contiguous() and b.is_contiguous(), "input matrices must be contiguous" 283 | a, b = a.to(torch.float16), b.to(torch.float16) 284 | 285 | # get dimension lengths 286 | (M, K), (_, N) = a.shape, b.shape 287 | 288 | # allocates output 289 | c = torch.empty((M, N), device=a.device, dtype=torch.float16) 290 | 291 | # cdiv(x, y) = (x + (y - 1)) // y 292 | # A naive (slow) launch grid might try to separate our axes of parallelization into 2 dimensions, one 293 | # for cdiv(M, BLOCK_SIZE_M) and the other for cdiv(N, BLOCK_SIZE_N) 294 | # Here instead we use a 1D launch grid defined by cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) 295 | # The reasoning behind this is explained inside the kernel 296 | grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) 297 | _matmul_kernel[grid]( 298 | a, b, c, 299 | M, N, K, 300 | a.stride(0), a.stride(1), 301 | b.stride(0), b.stride(1), 302 | c.stride(0), c.stride(1), 303 | ) 304 | return c 305 | 306 | ######### Step 1 ######### 307 | def test_matmul_kernel(size: tuple, atol=1e-2, rtol=1e-1, device=DEVICE): # TODO does rtol=0 mean we don't use rtol? 308 | """ 309 | Here is where we test the wrapper function and kernel that we wrote 310 | above to ensure all our values are correct, using pytorch as the 311 | correct answer to compare against 312 | 313 | We use higher tolerance values than previous tests because all the flop 314 | accumulation can really compound when it comes to a matmul; even slight 315 | differences in the block size and launch grid ordering from what PyTorch 316 | does can result in pretty sizeable discrepancies 317 | """ 318 | # create input data 319 | torch.manual_seed(0) 320 | assert type(size) == tuple and len(size) == 2 321 | a = torch.randn(size, device=DEVICE, dtype=torch.float16) 322 | b = torch.randn(size, device=DEVICE, dtype=torch.float16) # a & b both use `size`, so pass a square shape 323 | # run kernel & pytorch reference implementation 324 | c_tri = matmul(a, b) 325 | c_ref = torch.matmul(a, b) 326 | # compare 327 | torch.testing.assert_close(c_tri, c_ref, atol=atol, rtol=rtol) 328 | print("PASSED") 329 | 330 | ######### Step 4 ######### 331 | configs = [ 332 | triton.testing.Benchmark( 333 | x_names = ["M", "N", "K"], # we can increase multiple dimensions simultaneously while benchmarking 334 | x_vals = [128 * i for i in range(2, 33)], 335 | line_arg = "provider", 336 | line_vals = ["torch", "triton"], 337 | line_names = ["PyTorch", "Triton"], 338 | styles = [("green", "-"), ("blue", "-")], 339 | ylabel = "TFLOPS", 340 | plot_name = "matmul-performance", 341 | args={}, 342 | ) 343 | ] 344 | @triton.testing.perf_report(configs) 345 | def benchmark(M, N, K, provider): 346 | a = torch.randn((M, K), device=DEVICE, dtype=torch.float16) 347 | b = torch.randn((K, N), device=DEVICE, dtype=torch.float16) 348 | quantiles = [0.5, 0.05, 0.95] 349 | if provider == 'torch': 350 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) 351 | if provider == 'triton': 352 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) 353 | perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) 354 | # 2 = two FLOPs (one multiply + one add) per accumulation step 355 | # M * N * K = number of multiply-accumulate steps in the matmul 356 | # 1e-12 converts FLOPs to TFLOPs 357 | # 1e-3 converts milliseconds to seconds 358 | 
return perf(ms), perf(max_ms), perf(min_ms) 359 | 360 | if __name__ == "__main__": 361 | # always run unit-tests 362 | test_matmul_kernel(size=(1024, 1024)) 363 | 364 | # Only run benchmark if explicitly requested 365 | import sys 366 | if len(sys.argv) > 1 and sys.argv[1] == "--benchmark": 367 | benchmark.run(save_path='.', print_data=False) -------------------------------------------------------------------------------- /07_dropout/README.md: -------------------------------------------------------------------------------- 1 | check out the accompanying video: 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/WWdpyNbrtBI/0.jpg)](https://www.youtube.com/watch?v=WWdpyNbrtBI) 4 | -------------------------------------------------------------------------------- /07_dropout/dropout.py: -------------------------------------------------------------------------------- 1 | """ 2 | This tutorial on low-memory dropout required the least editing of all the original Triton documentation tutorials 3 | 4 | What you'll learn: 5 | - Parallel pseudo-random number generation 6 | """ 7 | import torch 8 | import triton 9 | import triton.language as tl 10 | 11 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 12 | 13 | @triton.jit 14 | def _seeded_dropout( 15 | x_ptr, 16 | output_ptr, 17 | n_elements, 18 | p, # a float32 probability, so range [0,1] 19 | seed, # a single int32 20 | BLOCK_SIZE: tl.constexpr, 21 | ): 22 | # compute memory offsets of elements handled by this program 23 | pid = tl.program_id(axis=0) 24 | offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 25 | # load data from x 26 | mask = offsets < n_elements 27 | x = tl.load(x_ptr + offsets, mask=mask) # shape (BLOCK_SIZE) 28 | # the key insight is that we generate and use a mask entirely in SRAM without ever having to store it in DRAM. 
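    # (for contrast, a naive PyTorch-style dropout sketch would materialize that mask as a full tensor in DRAM, e.g.
    #   keep = torch.rand_like(x) > p
    #   out = torch.where(keep, x / (1 - p), torch.zeros_like(x))
    # which costs extra DRAM traffic for the random values; here they never leave SRAM)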
29 | # this line generates uniformly distributed float32 values in [0, 1), given a seed and a block of int32 offsets 30 | random = tl.rand(seed, offsets) # shape (BLOCK_SIZE) 31 | # prune based on our desired probability threshold 32 | x_keep = random > p # values are either true or false 33 | output = tl.where(x_keep, x / (1 - p), 0.0) 34 | # where x_keep is True, the value is x/(1-p), and where False it's 0.0 35 | # write-back to DRAM 36 | tl.store(output_ptr + offsets, output, mask=mask) 37 | 38 | def seeded_dropout(x, p, seed): 39 | output = torch.empty_like(x) 40 | assert x.is_contiguous() 41 | n_elements = x.numel() 42 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) 43 | _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) 44 | return output 45 | 46 | x = torch.randn(size=(8, ), device=DEVICE) 47 | output1 = seeded_dropout(x, p=0.5, seed=123) 48 | output2 = seeded_dropout(x, p=0.5, seed=123) 49 | output3 = seeded_dropout(x, p=0.5, seed=512) 50 | print(x, output1, output2, output3, sep="\n") 51 | -------------------------------------------------------------------------------- /08_layernorm/README.md: -------------------------------------------------------------------------------- 1 | check out the accompanying video: 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/lM9afcTmGDs/0.jpg)](https://www.youtube.com/watch?v=lM9afcTmGDs) 4 | -------------------------------------------------------------------------------- /08_layernorm/layer-norm-backward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/08_layernorm/layer-norm-backward.png -------------------------------------------------------------------------------- /08_layernorm/layernorm.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this lesson on LayerNorm we'll finally connect our kernels to PyTorch's backpropagation graph. 
Keep in mind 3 | this kernel is fast but it only works for normalizing vectors that fit within SRAM, so we've made a trade-off 4 | of better speed for worse generality 5 | 6 | What you'll learn: 7 | - Writing a backward pass kernel 8 | - re-using intermediate values from a forward pass in the backward pass 9 | - Using torch.nn.functional to connect to PyTorch's backpropagation graph 10 | - Locks and atomic operations 11 | - How to use sequential kernels with intermediate tensors to complete a calculation 12 | more efficiently than one kernel alone could 13 | 14 | Recommended order to read the code in: 15 | Step 1 - unit test 16 | Step 2 - wrapper 17 | Step 3 - forward pass kernel 18 | Step 4 - backward pass kernels 19 | Step 5 - benchmark 20 | 21 | see original 22 | https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py 23 | """ 24 | import torch 25 | import triton 26 | import triton.language as tl 27 | 28 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 29 | 30 | ######### Step 3 ######### 31 | @triton.jit 32 | def _layernorm_forward( 33 | x_ptr, y_ptr, # points to first entry of tensors of shape (M, N) 34 | w_ptr, b_ptr, # points to first entry of tensors of shape (N) 35 | mean_ptr, rstd_ptr, # points to first entry of tensors of shape (M) 36 | stride_M, # how much to increase the X pointer when moving through memory to the next row along x 37 | N, # number of columns in x, aka the tensor's embedding dimension 38 | eps, # small value used to avoid division by zero 39 | BLOCK_SIZE: tl.constexpr, 40 | ): 41 | # use the program id to move x_ptr and y_ptr to the row of X and Y they should compute 42 | row = tl.program_id(0) 43 | x_ptr += row * stride_M 44 | y_ptr += row * stride_M 45 | 46 | # Compute mean 47 | sum_accumulator = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 48 | for offset in range(0, N, BLOCK_SIZE): 49 | cols = offset + tl.arange(0, BLOCK_SIZE) 50 | # we're assuming in this over-simplified example that x is contiguous along N dimension, 51 | # so no need to multiply cols by a stride (which should just be equal to 1) 52 | x_ptrs = tl.load(x_ptr + cols, mask=cols < N, other=0.).to(tl.float32) # shape (BLOCK_SIZE) 53 | # x is fp16 but we want to accumulate in fp32 for increased accuracy 54 | # other=0.0 since zeros don't affect summation 55 | sum_accumulator += x_ptrs 56 | mean = tl.sum(sum_accumulator, axis=0) / N 57 | # shape goes from (BLOCK_SIZE) to (1) 58 | 59 | # Compute variance & reciprocal standard deviation 60 | acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 61 | for offset in range(0, N, BLOCK_SIZE): 62 | cols = offset + tl.arange(0, BLOCK_SIZE) 63 | x_ptrs = tl.load(x_ptr + cols, mask=cols < N, other=0.).to(tl.float32) 64 | diff = tl.where(cols < N, x_ptrs - mean, 0.)
65 | # mask here to prevent (0.0 - mean) at the masked out-of-bounds values 66 | acc += diff * diff 67 | # no need to mask operations here since 0 * 0 = 0 68 | var = tl.sum(acc, axis=0) / N # shape goes from (BLOCK_SIZE) to (1) 69 | rstd = 1 / tl.sqrt(var + eps) 70 | # eps is a small number (eg 1e-6) there to prevent division by 0 71 | 72 | # we save mean and rstd for the backward pass later 73 | tl.store(mean_ptr + row, mean) 74 | tl.store(rstd_ptr + row, rstd) 75 | 76 | # Normalize and apply linear transformation 77 | for offset in range(0, N, BLOCK_SIZE): 78 | # load input and parameters 79 | cols = offset + tl.arange(0, BLOCK_SIZE) 80 | mask = cols < N 81 | w_ptrs = tl.load(w_ptr + cols, mask=mask) 82 | b_ptrs = tl.load(b_ptr + cols, mask=mask) 83 | x_ptrs = tl.load(x_ptr + cols, mask=mask) 84 | 85 | # Normalize and apply linear transformation 86 | x_hat = (x_ptrs - mean) * rstd 87 | y = x_hat * w_ptrs + b_ptrs 88 | 89 | # Write output 90 | tl.store(y_ptr + cols, y, mask=mask) 91 | 92 | 93 | ######### Step 4 ######### 94 | @triton.jit 95 | def _layernorm_backward_dLdx( 96 | x_ptr, dLdx_ptr, dLdy_ptr, # pointers to first entries of tensors of shape (M, N) 97 | w_ptr, # pointers to first entries of tensors of shape (N) 98 | dLdw_intermediate_ptr, dLdb_intermediate_ptr, # pointers to first entries of tensors of shape (GROUP_SIZE, N) 99 | mean_ptr, rstd_ptr, # pointers to first entries of tensors of shape (M) 100 | locks_ptr, # pointers to first entry of tensor of shape (2 * GROUP_SIZE) 101 | stride, N, # dynamic variables determined at run-time 102 | GROUP_SIZE: tl.constexpr, BLOCK_SIZE_N: tl.constexpr # static variables determined at compile-time 103 | ): 104 | """ 105 | there's a weird grouping strategy being used here for the _dLdw and _dLdb that has visuals on the website 106 | https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py 107 | the idea is that each pid is assigned some subset of rows (which are interleaved rather than next to each other) 108 | and it's that pid's job to accumulate the gradients over all of the rows it has been assigned 109 | then once each pid is done, in the next kernel we'll accumulate all of those individiual partial sums 110 | """ 111 | # Map the program id to the elements of x, dLdx, and dLdy it should compute. 112 | PID = tl.program_id(0) 113 | cols = tl.arange(0, BLOCK_SIZE_N) 114 | mask = cols < N # since we're holding an entire row within a single block 115 | x_ptr += PID * stride 116 | dLdx_ptr += PID * stride 117 | dLdy_ptr += PID * stride 118 | 119 | # Load data to SRAM 120 | # it's generally faster to do a bunch of loads before a bunch of flops rather than alternating back & forth 121 | x = tl.load(x_ptr + cols, mask=mask, other=0).to(tl.float32) # shape (BLOCK_SIZE_N) 122 | dLdy = tl.load(dLdy_ptr + cols, mask=mask, other=0).to(tl.float32) # shape (BLOCK_SIZE_N) 123 | w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) # shape (BLOCK_SIZE_N) 124 | mean = tl.load(mean_ptr + PID) # shape (1) 125 | rstd = tl.load(rstd_ptr + PID) # shape (1) 126 | 127 | # Compute dLdx 128 | x_normalized = tl.where(mask, (x - mean) * rstd, 0.) # shape (BLOCK_SIZE_N) 129 | dydx_normed = tl.where(mask, w * dLdy, 0.) 
# shape (BLOCK_SIZE_N) 130 | # c1 and c2 are just intermediary labels; the names don't have any real meaning 131 | c1 = tl.sum(x_normalized * dydx_normed, axis=0) / N # shape (1) 132 | c2 = tl.sum(dydx_normed, axis=0) / N # shape (1) 133 | dLdx = (dydx_normed - (x_normalized * c1 + c2)) * rstd # shape (BLOCK_SIZE_N) 134 | 135 | # Write dLdx back to DRAM 136 | tl.store(dLdx_ptr + cols, dLdx, mask=mask) 137 | 138 | # Here we'll accumulate partial sums for dLdw and dLdb, meaning these are only the single rows of 139 | # the dLdw and dLdb gradients that this PID had the job of calculating 140 | dLdw_contribution = (dLdy * x_normalized).to(w.dtype) 141 | dLdb_contribution = (dLdy).to(w.dtype) 142 | 143 | """ 144 | Now we'd like to take our single contributions to dLdw and dLdb and somehow aggregate them with 145 | the portions that all of the other PIDs have calculated. 146 | The reason this aggregation has to happen is because the input x is of shape (M, N) while 147 | the weights and biases are of shape (N), meaning they receive gradients from all M rows of x, 148 | and this PID holds the gradient of one of those rows, but it's not easy to communicate 149 | that information between PIDs. 150 | The specific operation to do between all these rows is to sum them up, but we can't just 151 | naively tl.load(), then add our row, then tl.store() because all of the PIDs would do so 152 | at slightly different and completely unpredictable times, meaning all the tl.store() calls 153 | would overwrite each other. 154 | What we need first is a way to ensure that only one PID at a time does the read, flop, and 155 | write while all the other PIDs wait their turn. 156 | For this we can use what's called a lock, which is a way for us to ensure that only one 157 | PID can work on a given part of a tensor in DRAM at a time, AKA "locking" it. 158 | 159 | However, even that's not great because if only one PID can do work at a time and we have a 160 | lot of PIDs, then that's a whole lot of time leaving a large majority of the GPU sitting 161 | idle while they wait in line. 162 | What we need then is a way for GROUPS of PIDs to work sequentially with a lock while each 163 | group works in parallel to the others. 164 | This is why we created dLdw_intermediate and dLdb_intermediate, each of which has shape 165 | (GROUP_SIZE, N). 166 | We're going to assign every PID to a group, and then use our locking mechanism to ensure 167 | that only (M // GROUP_SIZE) PIDs attempt to wait around for their turn to work on a row 168 | of dLdw_intermediate and dLdb_intermediate at a time. 169 | In this way we've now gone from a sequential process with M steps to one with 170 | (M // GROUP_SIZE) steps. 171 | Then in the next kernel we'll take these (GROUP_SIZE, N) matrices and reduce them further 172 | down to the desired shape (N) matrices of dLdw and dLdb. 173 | 174 | But how do locks actually work? 175 | In this case we've got a tensor of shape (2 * GROUP_SIZE) and datatype int32 that's 176 | initialized to all zeroes. 177 | The first GROUP_SIZE entries are for holding an indicator of the state of that lock; 178 | 0 means unlocked and 1 means locked for the row of dLdw_intermediate and dLdb_intermediate 179 | that it corresponds to. 180 | The latter GROUP_SIZE entries are for holding an indicator of whether this lock has 181 | ever been used before, which is useful because we'll want to run different code if 182 | this PID happens to be the first one to add its values to dLdw_intermediate and dLdb_intermediate. 
To use the lock, we check if the entry corresponding to the group that our PID is 184 | in is locked or unlocked: 185 | - if it's locked, then we wait and check again in a moment until it's unlocked 186 | - if it's unlocked then we'll lock it, load the current value of our group's row of 187 | dLdw_intermediate and dLdb_intermediate, add our dLdw_contribution and dLdb_contribution 188 | respectively, write those new values back to DRAM, and finally unlock it 189 | """ 190 | # To start we figure out which lock ID corresponds to our PID and move our pointers accordingly 191 | lock_id = PID % GROUP_SIZE # so there are GROUP_SIZE number of locks 192 | # the first GROUP_SIZE entries in the locks tensor hold the state of each lock, so this moves locks_ptr to ours 193 | locks_ptr += lock_id 194 | # the next GROUP_SIZE entries hold the count of how many accumulations have already happened on that lock 195 | count_ptr = locks_ptr + GROUP_SIZE 196 | # then we figure out which row of dLdw_intermediate and dLdb_intermediate we're meant to point to 197 | dLdw_intermediate_ptrs = dLdw_intermediate_ptr + lock_id * N + cols 198 | dLdb_intermediate_ptrs = dLdb_intermediate_ptr + lock_id * N + cols 199 | # we can use N in place of a .stride() here since these tensors are generated specifically for 200 | # this purpose and therefore guaranteed to be contiguous in memory 201 | 202 | # atomic_cas() compares the contents of a memory location with a given value and, 203 | # only if they are the same, modifies the contents of that memory location to a new given value. 204 | while tl.atomic_cas(locks_ptr, 0, 1) == 1: 205 | pass 206 | # so here, we're looking at the location locks_ptr and: 207 | # - If it's 0 (unlocked), change it to 1 (locked) and return 0 (False) to exit the while loop 208 | # - If it's 1 (already locked), leave it as 1 and return 1 (True) so that we stay in the while loop 209 | 210 | # then here we grab the number of times this lock position has already been accumulated into 211 | count = tl.load(count_ptr) # shape (1) 212 | if count == 0: # if this PID is the first one to access the lock 213 | # then no need to do the memory reads & flops; we can just set the row of dLdw_intermediate & 214 | # dLdb_intermediate equal to dLdw_contribution and dLdb_contribution (done below, outside the if/else) 215 | # atomic_xchg() sets the value at count_ptr equal to 1 so the next PID knows we've been here 216 | tl.atomic_xchg(count_ptr, 1) 217 | else: # but if this is not the first pid in the accumulation process, 218 | # then we've actually gotta accumulate by grabbing the values already there in 219 | # DRAM and adding them to the rows of dLdw_contribution and dLdb_contribution that our PID generated 220 | dLdw_contribution += tl.load(dLdw_intermediate_ptrs, mask=mask) # we load and add in one step (+= operator) 221 | dLdb_contribution += tl.load(dLdb_intermediate_ptrs, mask=mask) # so as not to consume unnecessary SRAM 222 | 223 | # now we get to store our accumulated values back to DRAM 224 | tl.store(dLdw_intermediate_ptrs, dLdw_contribution, mask=mask) 225 | tl.store(dLdb_intermediate_ptrs, dLdb_contribution, mask=mask) 226 | 227 | # and finally release the lock so that any pids waiting in their while loop can take their turn 228 | tl.atomic_xchg(locks_ptr, 0) # we set the value at our lock equal to 0 229 | # whichever pid gets to the 0 value first with its .atomic_cas() will get to go next 230 | 231 | @triton.jit 232 | def _layernorm_backward_dLdw_dLdb( 233 | dLdw_intermediate_ptr, 
dLdb_intermediate_ptr, # pointers to first entries of tensors of shape (GROUP_SIZE, N) 234 | dLdw_ptr, dLdb_ptr, # pointers to first entries of tensors of shape (N) 235 | GROUP_SIZE, N, 236 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, 237 | ): 238 | # our PIDs are split up within the N dimension 239 | PID = tl.program_id(0) 240 | col_ptrs = PID * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 241 | 242 | # here is where we'll accumulate the stored group values into as we read them 243 | dLdw_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 244 | dLdb_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 245 | 246 | # Iterate through the rows of _dLdw and _dLdb to sum them up 247 | for i in range(0, GROUP_SIZE, BLOCK_SIZE_M): 248 | row_ptrs = i + tl.arange(0, BLOCK_SIZE_M) 249 | mask = (row_ptrs[:, None] < GROUP_SIZE) & (col_ptrs[None, :] < N) 250 | offsets = row_ptrs[:, None] * N + col_ptrs[None, :] 251 | 252 | # load the partial sums from all that group locking nonsense earlier and add them to our final output 253 | dLdw_acc += tl.load(dLdw_intermediate_ptr + offsets, mask=mask, other=0.) 254 | dLdb_acc += tl.load(dLdb_intermediate_ptr + offsets, mask=mask, other=0.) 255 | # masked-out values get set to 0 so they don't affect sum 256 | 257 | # sum along our BLOCK_SIZE_M dimension to get the final BLOCK_SIZE_N chunk of dLdw & dLdb that this 258 | # PID was assigned to 259 | sum_dLdw = tl.sum(dLdw_acc, axis=0) # shape (BLOCK_SIZE_N) 260 | sum_dLdb = tl.sum(dLdb_acc, axis=0) 261 | 262 | # Write the final sum to the output. 263 | tl.store(dLdw_ptr + col_ptrs, sum_dLdw, mask=col_ptrs < N) 264 | tl.store(dLdb_ptr + col_ptrs, sum_dLdb, mask=col_ptrs < N) 265 | 266 | 267 | ######### Step 2 ######### 268 | class LayerNorm(torch.autograd.Function): 269 | """ 270 | We can implement our own custom functions that play nice with PyTorch's autograd graph 271 | by subclassing torch.autograd.Function and implementing the forward and backward passes 272 | with static methods forward() and backward(). 
273 | """ 274 | 275 | @staticmethod 276 | def forward( 277 | ctx, # ctx is an object we use to store info that'll be used later in the backward pass 278 | # it doesn't actually get inputted when using .forward(), rather it's handled by the parent class 279 | x, # the input; however many dimensions will be turned into a matrix of shape (M, N) 280 | normalized_shape, # this never gets used, but putting it here keeps arguments consistent with pytorch which does use it 281 | weight, # so this LayerNorm class is in fact acting as a function rather than a module since w&b are stored elsewhere 282 | bias, # weight and bias both of shape (x.shape[-1]) 283 | eps # very small value (eg 1e-6) to prevent division by zero in the reciprocal standard deviation calculation 284 | ): 285 | # reshape to 2D tensor and grab said shapes 286 | M, N = x.reshape(-1, x.shape[-1]).shape 287 | # allocate intermediary tensors and final output 288 | mean = torch.empty((M, ), dtype=torch.float32, device=x.device) 289 | rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) 290 | y = torch.empty_like(x) 291 | 292 | # if there's less than 64KB per feature then we can use our fused kernel 293 | MAX_FUSED_SIZE = 65536 // x.element_size() 294 | # .element_size() returns number of bytes per a single entry 295 | # fp32 element_size = 4, fp16 element_size = 2, fp8 element_size = 1 296 | # so this is used to calculate how many elements can fit within a 64KB block of memory 297 | # 64KB is a heuristic for the smallest possible SRAM size our GPU is likely to have; it'd be beter 298 | # if we got our GPU's actual SRAM size and used that (look back at lesson 5 for how to do this) 299 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 300 | # we'll either define block_size by 301 | # - the maximum amount of entries that a 64kb block of memory can hold or 302 | # - the smallest size that can hold the dimension N 303 | if N > BLOCK_SIZE: # so if we used MAX_FUSED_SIZE 304 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 305 | # in order to support feature_dim bigger than SRAM size we'd have to parallelize within feature_dim 306 | 307 | # heuristics for number of warps 308 | num_warps = min(max(BLOCK_SIZE // 256, 1), 8) 309 | 310 | _layernorm_forward[(M, )]( # grid parallelizes using a separate program for each non-embedding dimension entry 311 | x, y, weight, bias, 312 | mean, rstd, # pre-allocated intermediary useful tensors 313 | x.stride(0), # number of memory items needed to move forward to hit the next row of x (should be = N if x is contiguous) 314 | N, # model embedding dimension will be used for hardware mask 315 | eps, # small number to prevent division by 0 in reciprocal standard deviation calculation 316 | # meta-paramaters 317 | BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, 318 | ) 319 | 320 | # ctx is an object that can be used to stash information that's useful for the backward pass computation 321 | # You can cache arbitrary objects using the ctx.save_for_backward method 322 | ctx.save_for_backward(x, weight, bias, mean, rstd) 323 | # save_for_backward is mostly for tensors, whereas meta-parameters get saved as individual entries in the object 324 | ctx.BLOCK_SIZE = BLOCK_SIZE 325 | ctx.num_warps = num_warps 326 | ctx.eps = eps 327 | 328 | # and finally return our output 329 | return y 330 | 331 | @staticmethod 332 | def backward( 333 | ctx, # when calling .backward() we don't actually input ctx; rather it is handled by torch.autograd.Function 334 | dLdy # partial derivative 
of the loss with respect to y 335 | ): 336 | """ 337 | In the backward pass we receive a Tensor containing the gradient of the loss with respect to the output, and 338 | we need to compute the gradient of the loss with respect to the input(s). 339 | """ 340 | # fetching the original inputs, intermediary tensors, and meta-parameters 341 | x, w, b, mean, rstd = ctx.saved_tensors 342 | M, N = x.reshape(-1, x.shape[-1]).shape 343 | 344 | # allocate gradients of original inputs 345 | dLdw = torch.empty((N, ), dtype=w.dtype, device=w.device) 346 | dLdb = torch.empty((N, ), dtype=w.dtype, device=w.device) 347 | dLdx = torch.empty_like(dLdy) 348 | 349 | # heuristics for the number of parallel reduction streams for dLdw & dLdb; explained a bit below but mostly in the kernel 350 | GROUP_SIZE = 64 351 | if N <= 8192: GROUP_SIZE = 96 352 | if N <= 4096: GROUP_SIZE = 128 353 | if N <= 1024: GROUP_SIZE = 256 354 | 355 | # Rather than computing all three gradients immediately in one kernel, we're actually going to call two kernels. 356 | # The first will compute dLdx and intermediary steps on the way to dLdw and dLdb; we'll call these _dLdw and _dLdb 357 | dLdw_intermediate = torch.zeros((GROUP_SIZE, N), dtype=x.dtype, device=w.device) 358 | dLdb_intermediate = torch.zeros((GROUP_SIZE, N), dtype=x.dtype, device=w.device) 359 | 360 | # When multiple programs want to edit the same entries in a tensor stored in DRAM, we need a way to prevent them from 361 | # doing so out of order and from overwriting each other's work. For that we can use a lock, which is another tensor 362 | # with the job of keeping track of which entries are currently being worked on by a different program and which are 363 | # free to be edited 364 | locks = torch.zeros(2 * GROUP_SIZE, dtype=torch.int32, device=w.device) 365 | # the first GROUP_SIZE entries in our locks tensor will be used to determine whether a lock is on or off 366 | # (AKA whether the important tensor is occupied or available) 367 | # the second will keep track of whether the lock has been used before, since in the kernel we will need to 368 | # treat the first use differently from all successive uses 369 | 370 | # enqueue kernel that uses forward pass heuristics to calculate both dLdx and the partial contributions to dLdw and dLdb 371 | _layernorm_backward_dLdx[(M, )]( # parallelize across rows 372 | x, dLdx, dLdy, 373 | w, dLdw_intermediate, dLdb_intermediate, 374 | mean, rstd, 375 | locks, 376 | x.stride(0), N, # dynamic run-time variables 377 | GROUP_SIZE = GROUP_SIZE, BLOCK_SIZE_N = ctx.BLOCK_SIZE, num_warps = ctx.num_warps) # static compile-time variables 378 | 379 | # Now we'll do a separate call to the second kernel, whose job is to accumulate dLdw_intermediate into dLdw and 380 | # dLdb_intermediate into dLdb. 
381 | # We do this in a separate kernel since this final set of operations requires 382 | # 1) fewer pids as opposed to the previous kernel which called M pids 383 | # and 2) dLdw_intermediate and dLdb_intermediate to be completed before it can begin 384 | grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] # parallelize within rows 385 | _layernorm_backward_dLdw_dLdb[grid]( 386 | dLdw_intermediate, dLdb_intermediate, dLdw, dLdb, 387 | min(GROUP_SIZE, M), N, # run-time integer values 388 | BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, # heuristically chosen compile-time values 389 | ) 390 | 391 | # pytorch expects .backward() to return a value for every single input into .forward() in order (except for ctx) 392 | # so that it can keep track of the backpropagation graph 393 | return dLdx, None, dLdw, dLdb, None 394 | # the None values correspond to the inputs of .forward() that don't need gradients (order matters!) 395 | 396 | # this line just creates a reference to the apply function of LayerNorm rather than having it act like an object 397 | layernorm = LayerNorm.apply 398 | 399 | 400 | ######### Step 1 ######### 401 | def test_layernorm_kernel(M, N, dtype, eps=1e-5, device=DEVICE): 402 | # create data 403 | x = -2.3 + 0.5 * torch.randn((M, N), dtype=dtype, device=device) 404 | weight = torch.rand((N, ), dtype=dtype, device=device, requires_grad=True) 405 | bias = torch.rand((N, ), dtype=dtype, device=device, requires_grad=True) 406 | dLdy = .1 * torch.randn_like(x) 407 | # setting requires_grad to True here instead of x's initial definition means the graph doesn't have to move through 408 | # the -2.3 and 0.5 operations. That's not a big deal here for testing but if we didn't do it in the benchmark then 409 | # those results would be confounded by the kernels pytorch implements for entry-wise multiplication and addition 410 | x.requires_grad_(True) 411 | # forward pass 412 | y_tri = layernorm(x, (N,), weight, bias, eps) 413 | y_ref = torch.nn.functional.layer_norm(x, (N,), weight, bias, eps).to(dtype) 414 | torch.testing.assert_close(y_tri, y_ref, atol=1e-2, rtol=0) 415 | print("Passed fwd") 416 | # backward pass (triton) 417 | y_tri.backward(dLdy, retain_graph=True) # this writes directly to x.grad, weight.grad and bias.grad 418 | # retain_graph is used to control whether the computation graph should be kept in memory after the backward pass. 
419 | # Setting retain_graph=True allows you to perform multiple backward passes on the same graph, but it can increase 420 | # memory usage, so it's generally recommended to use it only when necessary for a scenario like this 421 | # This detaches our gradients so that we can run pytorch on the same input tensors and test against each other later 422 | dLdx_tri, dLdw_tri, dLdb_tri = [_.grad.clone() for _ in [x, weight, bias]] 423 | # when denoting derivatives, it's always with respect to the loss function L and we use "d" instead of "partial" 424 | # because it's more concise albiet bad practice from a mathematician's perspective 425 | x.grad, weight.grad, bias.grad = None, None, None 426 | # backward pass (torch) 427 | y_ref.backward(dLdy, retain_graph=True) 428 | dLdx_ref, dLdw_ref, dLdb_ref = [_.grad.clone() for _ in [x, weight, bias]] 429 | # compare 430 | torch.testing.assert_close(dLdx_tri, dLdx_ref, atol=1e-2, rtol=0) 431 | torch.testing.assert_close(dLdb_tri, dLdb_ref, atol=1e-2, rtol=0) 432 | torch.testing.assert_close(dLdw_tri, dLdw_ref, atol=1e-2, rtol=0) 433 | # rtol=0 means we don't use relative tolerance 434 | print("Passed bwd") 435 | 436 | 437 | ######### Step 5 ######### 438 | @triton.testing.perf_report( 439 | triton.testing.Benchmark( 440 | x_names=['N'], 441 | x_vals=[512 * i for i in range(2, 32)], # if you increase past 32 the kernel will break since features become larger than 64kb 442 | line_arg='provider', 443 | line_vals=['triton', 'torch'], 444 | line_names=['Triton', 'Torch'], 445 | styles=[('blue', '-'), ('green', '-')], 446 | ylabel='GB/s', 447 | plot_name='layer-norm-backward', 448 | args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, # so we're actually only benchmarking the backward pass 449 | )) 450 | def benchmark(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE): 451 | # create data 452 | x_shape = (M, N) 453 | w_shape = (N, ) 454 | weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) 455 | bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) 456 | x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)#, requires_grad=True) 457 | dLdy = .1 * torch.randn_like(x) 458 | x.requires_grad_(True) 459 | # setting this here instead of x's initial definition means the graph doesn't have to move through the -2.3 and 0.5 operations 460 | quantiles = [0.5, 0.05, 0.95] 461 | 462 | def y_fwd(): 463 | if provider == "triton": 464 | return layernorm(x, w_shape, weight, bias, eps) 465 | if provider == "torch": 466 | return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) 467 | 468 | # forward pass 469 | if mode == 'forward': 470 | gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) 471 | ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) 472 | # backward pass 473 | if mode == 'backward': 474 | y = y_fwd() 475 | gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: F811, E704 476 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dLdy, retain_graph=True), quantiles=quantiles, 477 | grad_to_none=[x], rep=500) 478 | return gbps(ms), gbps(max_ms), gbps(min_ms) 479 | 480 | 481 | if __name__ == "__main__": 482 | # always run unit-tests 483 | test_layernorm_kernel(1151, 8192, torch.float16) 484 | 485 | # Only run benchmark if explicitly requested 486 | import sys 487 | if len(sys.argv) > 1 and sys.argv[1] == "--benchmark": 488 | benchmark.run(save_path='.', print_data=False) 
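# For reference, a minimal usage sketch of the layernorm defined above, commented out so it doesn't run on import
# (the shapes and values here are arbitrary assumptions, not part of the unit test or benchmark):
#   M, N = 32, 1024
#   x = torch.randn((M, N), dtype=torch.float16, device=DEVICE, requires_grad=True)
#   w = torch.rand((N,), dtype=torch.float16, device=DEVICE, requires_grad=True)
#   b = torch.rand((N,), dtype=torch.float16, device=DEVICE, requires_grad=True)
#   y = layernorm(x, (N,), w, b, 1e-5)   # forward pass runs _layernorm_forward
#   y.sum().backward()                   # backward pass runs both backward kernels
#   print(x.grad.shape, w.grad.shape, b.grad.shape)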
-------------------------------------------------------------------------------- /09_flash_attention/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/09_flash_attention/.DS_Store -------------------------------------------------------------------------------- /09_flash_attention/Note Jan 20, 2026.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/09_flash_attention/Note Jan 20, 2026.pdf -------------------------------------------------------------------------------- /09_flash_attention/README.md: -------------------------------------------------------------------------------- 1 | check out the accompanying videos: 2 | 3 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/6ap2QVWKFH0/0.jpg)](https://www.youtube.com/watch?v=6ap2QVWKFH0) 4 | 5 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/cygYBmB5ow8/0.jpg)](https://www.youtube.com/watch?v=cygYBmB5ow8) 6 | 7 | -------------------------------------------------------------------------------- /09_flash_attention/attention-performance-bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/09_flash_attention/attention-performance-bwd.png -------------------------------------------------------------------------------- /09_flash_attention/attention-performance-fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/09_flash_attention/attention-performance-fwd.png -------------------------------------------------------------------------------- /09_flash_attention/flash_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | this implementation of flash-attention only supports a causal mask, no other masks or lack of a mask 3 | the forward pass is based primarily on the pseudocode from the two original papers 4 | https://arxiv.org/abs/2205.14135 5 | https://arxiv.org/abs/2307.08691 6 | and the backward pass is based primarily on the triton documentation implementation since it's 7 | significantly faster than the pseudocode from the original papers 8 | https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py 9 | 10 | What you'll learn: 11 | - calling sub-kernels 12 | - using the faster tl.exp2() instead of tl.exp() 13 | - Flash attention's specific parallelization strategy 14 | - tl.static_assert() 15 | - multi-axis launch grids & the importance of launch grid axis ordering 16 | - when to pre-compute certain values in a separate kernel 17 | - using approximate constant values rather than calculating 18 | 19 | Notable features this kernel does NOT include: 20 | - support for datatypes other than fp32 and mixed precision 21 | - dropout 22 | - likely more I'm forgetting 23 | 24 | Also note, the autotuning is set up but effectively disabled (the config lists have single entries) 25 | So if you wanted even better performance you could re-enable autotuning 26 | """ 27 | 28 | import torch 29 | import triton 30 | import
triton.language as tl 31 | import math 32 | 33 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 34 | 35 | #import os 36 | #os.environ["TRITON_INTERPRET"] = "1" 37 | 38 | @triton.jit 39 | def _attn_fwd_inner( 40 | Q, O, L, M, 41 | K_ptr, V_ptr, 42 | K_T_offsets, V_offsets, 43 | block_index_QO, 44 | softmax_scale, 45 | stride_K_N, stride_V_N, 46 | BLOCK_SIZE_QO: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr, 47 | DIAGONAL: tl.constexpr, 48 | offsets_QO_N: tl.constexpr, offsets_KV_N: tl.constexpr, 49 | N: tl.constexpr, Dh: tl.constexpr, 50 | ): 51 | """ 52 | arrows indicate direction of this pid's for loop; each arrow is a different PID 53 | N of K & V 54 | ------------> 55 | ------------> 56 | N of Q ------------> 57 | ------------> 58 | ------------> 59 | but if we actually take into account the causal mask then really it's more like 60 | N of K & V 61 | > 62 | ---> 63 | N of Q ------> 64 | ---------> 65 | ------------> 66 | and to get even more accurate, we do the diagonal in our second call of this inner kernel 67 | N of K & V 68 | x 69 | x 70 | N of Q x 71 | x 72 | x 73 | and then the first call gets all the parts below the diagonal 74 | N of K & V 75 | 76 | --> 77 | N of Q -----> 78 | --------> 79 | -----------> 80 | """ 81 | if DIAGONAL: 82 | # Used only for the blocks along the diagonal in which there is transition between non-masked and masked keys 83 | lo = block_index_QO * BLOCK_SIZE_QO 84 | hi = (block_index_QO + 1) * BLOCK_SIZE_QO 85 | # let the compiler know lo is a muliple of BLOCK_SIZE_QO to speed things up 86 | lo = tl.multiple_of(lo, BLOCK_SIZE_QO) # TODO not sure why this doesn't also help with hi; it prolly does 87 | else: 88 | # this part is for any blocks in the causal mask below the diagonal 89 | lo, hi = 0, block_index_QO * BLOCK_SIZE_QO 90 | 91 | K_T_offsets += lo * stride_K_N 92 | V_offsets += lo * stride_V_N 93 | offsets_KV_N += lo 94 | 95 | # loop over blocks along the N dimension of K & V and update the O accumulator while doing so 96 | for start_KV in range(lo, hi, BLOCK_SIZE_KV): 97 | # Just let the compiler know that start_KV is a multiple of BLOCK_SIZE_KV, so the compiler can do optimizations 98 | start_KV = tl.multiple_of(start_KV, BLOCK_SIZE_KV) 99 | # when in doubt, i guess use tl.multiple_of() for any dynamic variable (as opposed to static variables) 100 | 101 | # compute (Q @ K^T) / sqrt(Dh) 102 | mask_KV_N = offsets_KV_N < N 103 | K_T = tl.load(K_ptr + K_T_offsets, mask=mask_KV_N[None, :], other=0.) 
# shape (Dh, BLOCK_SIZE_KV) 104 | # sequence mask sets non-existent tokens in the block past N to zero vectors 105 | S = tl.dot(Q, K_T) * softmax_scale # shape (BLOCK_SIZE_QO, BLOCK_SIZE_KV) 106 | # the masked tokens create columns & rows of zeros hugging the bottom and right edges of S 107 | 108 | if DIAGONAL: # if we're currently on a block containing the diagonal 109 | # the causal mask is True on the lower-triangular including the diagonal 110 | causal_mask = offsets_QO_N[:, None] >= (offsets_KV_N[None, :]) 111 | # causal mask addition sets upper-triangular values (excluding diagonal) to -inf 112 | S += tl.where(causal_mask, 0, -1.0e6) # shape (BLOCK_SIZE_QO, BLOCK_SIZE_KV) 113 | # notice that the masked out tokens previously hugging the right edge of S have mostly been replaced with -inf 114 | # and the masked out tokens hugging the bottom edge are still mostly 0's but with some -infs towards the 115 | # right edge of each of them, except for the last one which is only 0's 116 | 117 | # find the max values of the new block and compare them to those of all previous blocks to get an update 118 | M_new = tl.maximum(M, tl.max(S, axis=1)) # shape is (BLOCK_SIZE_QO) 119 | # masked token rows at the bottom will return a maximum value of 0 since their only values are 0 and -inf 120 | # adjust S block for safe softmax 121 | S -= M_new[:, None] # shape (BLOCK_SIZE_QO, BLOCK_SIZE_KV) 122 | # in the case of masked non-existent tokens that means subtracting by 0 so no difference 123 | 124 | # Compute the exponential of each safe dot product, which will be the numerator of our softmax 125 | P = tl.exp2(S) # shape (BLOCK_SIZE_QO, BLOCK_SIZE_KV) 126 | # we're using base 2 instead of base e because it's faster and softmax is invariant to the change, 127 | # however it does make the derivative in the backward pass a bit more complicated. 128 | # for the masked non-existent tokens as the bottom that will be 2^0=1 for all those entries 129 | 130 | # Compute the sum by rows of the attention scores 131 | L_new = tl.sum(P, axis=1) # shape (BLOCK_SIZE_QO) 132 | # for the masked non-existent tokens we're summing a bunch of 1's with some -infs, except for the 133 | # very bottom one which is just 1's and therefore its sum is the largest being equal to BLOCK_SIZE_QO 134 | # This alpha is the correction factor we'll use on the previous L 135 | alpha = tl.exp2(M - M_new) # shape (BLOCK_SIZE_Q) 136 | # for the masked non-existent tokens that's just 2^(1-1)=2^0=1=alpha_i so no correction 137 | # Apply the correction factor to the previous L and add the new L 138 | L = L * alpha + L_new # shape (BLOCK_SIZE_QO) 139 | # for each of the masked non-existent tokens they approach N for their entry L_i 140 | 141 | # This computes O = P @ V + O * alpha 142 | V = tl.load(V_ptr + V_offsets, mask=mask_KV_N[:, None], other=0.) # shape (BLOCK_SIZE_KV, Dh) 143 | # adjusts previous values based on potential new max 144 | O = O * alpha[:, None] # shape (BLOCK_SIZE_QO, Dh) 145 | # accumulated P and V block dot product into O 146 | O = tl.dot(P, V, acc=O) # shape (BLOCK_SIZE_QO, Dh) 147 | # notice we're doing this V projection before we've actually divided by our softmax denominator l_i 148 | # which is possible because in this context the two operations are associative 149 | # acc tells triton to accumulate the values into O_block 150 | # the masked non-existent tokens are a bunch of 1's in the bottom rows of P and 0's in the bottom 151 | # rows of V. 
This matmul leaves O with a bunch of incorrect values in its bottom rows, but they 152 | # will get ignored later when we store O with a proper mask 153 | 154 | # sets old max equal to new max, ready to be used by next iteration of for loop 155 | M = M_new 156 | 157 | # iterate pointers 158 | K_T_offsets += BLOCK_SIZE_KV * stride_K_N 159 | V_offsets += BLOCK_SIZE_KV * stride_V_N 160 | offsets_KV_N += BLOCK_SIZE_KV 161 | 162 | return O, L, M # we save these three specifically for use later in the backward pass 163 | 164 | 165 | @triton.autotune( # decorator figures out what meta-parameters will be most efficient 166 | [ 167 | triton.Config( 168 | {"BLOCK_SIZE_QO": BLOCK_SIZE_QO, "BLOCK_SIZE_KV": BLOCK_SIZE_KV}, 169 | num_stages=num_stages, num_warps=num_warps, 170 | ) 171 | for BLOCK_SIZE_QO in [16]#, 32, 64, 128] 172 | for BLOCK_SIZE_KV in [16]#, 32, 64, 128] 173 | for num_stages in [3]#, 5, 7] 174 | for num_warps in [4]#, 8, 16] 175 | ], 176 | key=["Dh"], 177 | ) 178 | @triton.jit 179 | def attn_fwd( 180 | Q_ptr, K_ptr, V_ptr, # each shape (B, H, N, Dh) 181 | O_ptr, # shape (B, H, N, Dh). where we store the final output 182 | LSE_ptr, # shape (B, H, N). here we first store the max values of each row & later the logsumexp trick 183 | softmax_scale, 184 | stride_Q_B, stride_Q_H, stride_Q_N, stride_Q_Dh, 185 | stride_K_B, stride_K_H, stride_K_N, stride_K_Dh, 186 | stride_V_B, stride_V_H, stride_V_N, stride_V_Dh, 187 | stride_O_B, stride_O_H, stride_O_N, stride_O_Dh, 188 | stride_LSE_B, stride_LSE_H, stride_LSE_N, 189 | B, # unlike other tensor dimensions, batch size can be more flexible for runtime differences 190 | # meta-parameters (decided at compile-time) 191 | H: tl.constexpr, N: tl.constexpr, 192 | Dh: tl.constexpr, # should always be a power of 2 193 | BLOCK_SIZE_QO: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr, 194 | ): 195 | # in order to use tl.exp2 later isntead of tl.exp (the former is faster) we need to scale our softmax scale by ln2 196 | rln2: tl.constexpr = 1.4426950408889634 197 | softmax_scale *= rln2 198 | """ 199 | let's show that e^x = 2^(x * rln2) 200 | e^x = (2^(log_2(e)))^x since a = 2^log_2(a) 201 | then using the power rule 202 | (2^(log_2(e)))^x = 2^(x * log_2(e)) 203 | fundamental property of logarithm is log_2(e) = 1/log_e(2) 204 | therefore e^x = 2^(x * 1/log_e(2)) 205 | AKA e^x = 2^(x * rln2) 206 | then later in the backward pass we'll have to remember to account for this in the gradient 207 | """ 208 | 209 | # as opposed to regular assert, static_assert occurs at compile-time 210 | tl.static_assert(BLOCK_SIZE_KV <= Dh) 211 | # I'm not sure why the original triton docs tutorial had this assertion, but it doesn't hurt anything 212 | 213 | # This indicates which block in the sequence length to process 214 | block_index_QO = tl.program_id(0) 215 | # This indicates which head and batch to process. 
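# (a concrete example with hypothetical sizes, just for illustration: if B=2 and H=4 then axis 1 of the launch grid holds B*H=8 programs, and the arithmetic a few lines below maps index_BH=5 to index_B = 5 // 4 = 1 and index_H = 5 % 4 = 1)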
Each program is associated with a single head of a single batch 216 | index_BH = tl.program_id(1) 217 | # This indicates which batch this program is associated with (each batch has H heads) 218 | index_B = index_BH // H 219 | # This indicates the position of the head in the batch 220 | index_H = index_BH % H 221 | 222 | # This allows to get the shape (N, Dh) block in the Q, K, V, and O by indexing it by batch and head 223 | Q_ptr += index_B * stride_Q_B + index_H * stride_Q_H 224 | K_ptr += index_B * stride_K_B + index_H * stride_K_H 225 | V_ptr += index_B * stride_V_B + index_H * stride_V_H 226 | O_ptr += index_B * stride_O_B + index_H * stride_O_H 227 | 228 | # Offsets for N are split by pids but for Dh we keep the whole thing in SRAM. 229 | offsets_QO_N = block_index_QO * BLOCK_SIZE_QO + tl.arange(0, BLOCK_SIZE_QO) 230 | offsets_KV_N = tl.arange(0, BLOCK_SIZE_KV) 231 | offsets_Dh = tl.arange(0, Dh) 232 | 233 | # create offsets specific to each tensor 234 | Q_offsets = (offsets_QO_N[:, None] * stride_Q_N + offsets_Dh[None, :] * stride_Q_Dh) 235 | # shape (BLOCK_SIZE_QO, Dh) 236 | # we transpose K while loading it (as opposed to writing a whole separate kernel for transpose) 237 | K_T_offsets = (offsets_Dh[:, None] * stride_K_Dh + offsets_KV_N[None, :] * stride_K_N) 238 | # shape (Dh, BLOCK_SIZE_KV) 239 | V_offsets = (offsets_KV_N[:, None] * stride_V_N + offsets_Dh[None, :] * stride_V_Dh) 240 | # shape (BLOCK_SIZE_KV, Dh) 241 | 242 | # load the block of Q that this PID will use; it will stay in SRAM throughout the inner loop 243 | mask_QO_N = offsets_QO_N < N 244 | Q = tl.load(Q_ptr + Q_offsets, mask=mask_QO_N[:, None], other=0.) # shape (BLOCK_SIZE_QO, Dh) 245 | # sequence mask sets non-existent tokens in the block past N to zero vectors 246 | 247 | ## pre-allocate tensors for storing intermediate & output values 248 | # the running maximum. We have one entry for each query in the block we're currently working on 249 | M = tl.full(shape=[BLOCK_SIZE_QO], value=-1e6, dtype=tl.float32) # large negative number will get ignored by tl.max() 250 | # the running sum. We have one entry for each query (since we sum the attention scores by rows) 251 | L = tl.full(shape=[BLOCK_SIZE_QO], value=1.0, dtype=tl.float32) # 1 is because we'll be using exponentials and e^0=1 252 | # the accumulator for the output, which is a group of rows of the O matrix 253 | O = tl.zeros([BLOCK_SIZE_QO, Dh], dtype=tl.float32) 254 | 255 | # calculate attention for dense blocks (those where the mask if full of 1's). 256 | # This step runs for the blocks below the diagonal in causal attention 257 | O, L, M = _attn_fwd_inner( 258 | Q, O, L, M, 259 | K_ptr, V_ptr, 260 | K_T_offsets, V_offsets, 261 | block_index_QO, 262 | softmax_scale, 263 | stride_K_N, stride_V_N, 264 | BLOCK_SIZE_QO, BLOCK_SIZE_KV, 265 | False, # blocks on the DIAGONAL get special treatment if this is set to true; we use it below 266 | offsets_QO_N, offsets_KV_N, 267 | N, Dh, 268 | ) 269 | 270 | # This step runs for the blocks on the diagonal in the causal attention mask 271 | O, L, M = _attn_fwd_inner( 272 | Q, O, L, M, 273 | K_ptr, V_ptr, 274 | K_T_offsets, V_offsets, 275 | block_index_QO, 276 | softmax_scale, 277 | stride_K_N, stride_V_N, 278 | BLOCK_SIZE_QO, BLOCK_SIZE_KV, 279 | True, # blocks on the diagonal get special masking treatment 280 | offsets_QO_N, offsets_KV_N, 281 | N, Dh, 282 | ) 283 | 284 | # finally dividing by the denominator of our softmax. 
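# to spell out the math of what follows (a restatement of the running values built up above, not an extra computation): at this point O_i = sum_j exp2(S_ij - M_i) * V_j and L_i = sum_j exp2(S_ij - M_i), so the division below turns O_i into softmax(S_i) @ V for this block of queries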
285 | # notice we've already multiplied by V to get O, so this was done out-of-order from naive softmax implementations 286 | O = O / L[:, None] # shapes (BLOCK_SIZE_QO, Dh) / (BLOCK_SIZE_QO, 1) = (BLOCK_SIZE_QO, Dh) 287 | # we can do this out-of-order since the matmul (the tl.dot in _attn_fwd_inner) and this entry-wise division 288 | # are associative. matmul and entry-wise-ops are not normally, but at this level of granularity it's no longer 289 | # actually a matmul but instead individual dot-products 290 | # the masked non-existent tokens are a bunch of meaningless values in the bottom rows of O and generally 291 | # roughly equal to N in the bottom entries of L. Dividing the former by the latter isn't going to break 292 | # anything and we'll mask them out later when storing 293 | 294 | # This is needed to compute the logsumexp (LSE) for the backwards pass. basically instead of saving the maxes 295 | # and the sums separately, we save them together which still works thanks to exponential arithmetic 296 | LSE = M + tl.math.log2(L) # shape (BLOCK_SIZE_QO) 297 | # L was composed using the sum & exp operations in _attn_fwd_inner() 298 | # this will work because softmax(x_i) = exp(x_i - m_i) / l_i 299 | # = exp(x_i - m_i) / exp(log(l_i)) 300 | # = exp(x_i - m_i - log(l_i)) 301 | # the masked non-existent tokens are a bunch of 0's in the bottom entries of M and a bunch of values roughly 302 | # equal to N in the bottom entries of L. So in LSE they'll be a bunch of log_2(N) entries at the bottom 303 | # that we of course don't plan to use 304 | 305 | ## storing it all back to DRAM 306 | LSE_offsets = index_BH * stride_LSE_H + offsets_QO_N 307 | LSE_mask = block_index_QO * BLOCK_SIZE_QO + tl.arange(0, BLOCK_SIZE_QO) < N 308 | tl.store(LSE_ptr + LSE_offsets, LSE, mask=LSE_mask) # shape (BLOCK_SIZE_QO) 309 | # the mask prevents us from saving the useless log_2(n) values at the bottom of LSE 310 | O_offsets = (offsets_QO_N[:, None] * stride_O_N + offsets_Dh[None, :] * stride_O_Dh) 311 | tl.store(O_ptr + O_offsets, O, mask=mask_QO_N[:, None]) # shape (BLOCK_SIZE_Q, Dh) 312 | # the mask prevents us from saving the useless values at the bottom of O corresponding to non-existent tokens 313 | 314 | 315 | @triton.autotune( 316 | [ 317 | triton.Config({"PRE_BLOCK_SIZE_ROW": PRE_BLOCK_SIZE_ROW}, 318 | num_stages=num_stages, num_warps=num_warps,) 319 | for PRE_BLOCK_SIZE_ROW in [32]#, 64, 128, 256] 320 | for num_stages in [3]#, 5, 7] 321 | for num_warps in [4]#, 8, 16] 322 | ], 323 | key=["Dh"], 324 | ) 325 | @triton.jit 326 | def attn_backward_preprocess( 327 | O_ptr, dLdO_ptr, Delta_ptr, 328 | stride_O_B, stride_O_H, stride_O_N, stride_O_Dh, 329 | stride_dLdO_B, stride_dLdO_H, stride_dLdO_N, stride_dLdO_Dh, 330 | stride_Delta_B, stride_Delta_H, stride_Delta_N, 331 | N, Dh: tl.constexpr, 332 | PRE_BLOCK_SIZE_ROW: tl.constexpr, 333 | ): 334 | """the job of this kernel is to pre-compute Delta since Delta is used by both of the following two kernels""" 335 | index_BH = tl.program_id(1) # B * H number of pids 336 | row = tl.program_id(0) # N / BLOCK_SIZE_ROW number of pids 337 | 338 | row_offsets = row * PRE_BLOCK_SIZE_ROW + tl.arange(0, PRE_BLOCK_SIZE_ROW) 339 | col_offsets = tl.arange(0, Dh) 340 | mask = row_offsets < N 341 | 342 | # Load PRE_BLOCK_SIZE_ROW rows of O 343 | O_ptr += index_BH * stride_O_H # moves O_ptr to the correct batch & head for this pid. 
344 | O_offsets = row_offsets[:, None] * stride_O_N + col_offsets[None, :] * stride_O_Dh 345 | O = tl.load(O_ptr + O_offsets, mask = mask[:, None], other=0.) # shape (PRE_BLOCK_SIZE_ROW, D) 346 | 347 | # Load PRE_BLOCK_SIZE_ROW rows of dLdO 348 | dLdO_ptr += index_BH * stride_dLdO_H 349 | dLdO_offsets = row_offsets[:, None] * stride_dLdO_N + col_offsets[None, :] * stride_dLdO_Dh 350 | dLdO = tl.load(dLdO_ptr + dLdO_offsets, mask = mask[:, None], other=0.) # shape (PRE_BLOCK_SIZE_ROW, D) 351 | 352 | # Delta is the dot product of O and dLdO along Dh, giving us a single scalar Delta_i per token in N 353 | # it will be useful in later parts of the backward pass 354 | Delta = tl.sum(dLdO.to(tl.float32) * O.to(tl.float32), axis=1) # shape (PRE_BLOCK_SIZE_ROW) 355 | Delta_ptr += index_BH * stride_Delta_H 356 | tl.store(Delta_ptr + row_offsets, Delta, mask = mask) 357 | 358 | 359 | @triton.jit 360 | def _attn_backward_KV( 361 | K, V, dLdK, dLdV, # shape (BLOCK_SIZE_COL, D) 362 | Q_ptr, dLdO_ptr, 363 | LSE_ptr, Delta_ptr, 364 | stride_N, stride_Dh, 365 | H, N, Dh: tl.constexpr, 366 | BLOCK_SIZE_ROW: tl.constexpr, # no more _1 because this sub-kernel is the _1 367 | BLOCK_SIZE_COL: tl.constexpr, 368 | start_ROW, start_COL, num_steps, 369 | scale, ln2: tl.constexpr, rln2: tl.constexpr, 370 | MASK: tl.constexpr 371 | ): 372 | """ 373 | this sub-kernel will be looking at a specific chunk of K & V , where we call 374 | the sequence length of K & V the columns of our NxN attention matrix, 375 | and iterating through rows of Q's sequence length to calculate that 376 | chunk of dLdK and dLdV 377 | N of K & V 378 | | | | | | 379 | | | | | | 380 | N of Q | | | | | 381 | | | | | | 382 | \|/ \|/ \|/ \|/ \|/ 383 | arrows indicate direction of this pid's for loop; each arrow is a different PID 384 | """ 385 | offsets_ROW = start_ROW + tl.arange(0, BLOCK_SIZE_ROW) 386 | offsets_COL = start_COL + tl.arange(0, BLOCK_SIZE_COL) 387 | offsets_Dh = tl.arange(0, Dh) 388 | 389 | # we transpose Q while loading it rather than in a separate kernel 390 | Q_T_offsets = offsets_Dh[:, None] * stride_Dh + offsets_ROW[None, :] * stride_N 391 | dLdO_offsets = offsets_ROW[:, None] * stride_N + offsets_Dh[None, :] * stride_Dh 392 | 393 | for block_idx in range(num_steps): 394 | # we load M before computing S to reduce pipeline stall (and dLdO before computing dLdV) 395 | # meaning the Triton compiler can have an easier time doing the loading of M 396 | # and the dot product of K and QT simultaneously. in general you should load a bunch 397 | # of stuff then calc a bunch of stuff rather than flipping b/w loads and calcs 398 | mask_N = offsets_ROW < N 399 | Q_T = tl.load(Q_ptr + Q_T_offsets, mask=mask_N[None, :], other=0.) # shape (Dh, BLOCK_SIZE_ROW) 400 | LSE = tl.load(LSE_ptr + offsets_ROW, mask=mask_N, other=0.) # shape (BLOCK_SIZE_ROW) 401 | dLdO = tl.load(dLdO_ptr + dLdO_offsets, mask=mask_N[:, None], other=0.) # shape (BLOCK_SIZE_ROW, Dh) 402 | Delta = tl.load(Delta_ptr + offsets_ROW, mask=mask_N, other=0.) 
# shape (BLOCK_SIZE_ROW) 403 | # ^notice the order we load these in is based on the order we use them below 404 | 405 | # we'll re-calculate the transpose of the S and P matrices since doing that here is faster & more importantly 406 | # cheaper on memory consumption than if we were to have saved them in our forward pass & read them here 407 | S_T = tl.dot(K, Q_T) # shape (BLOCK_SIZE_COL, BLOCK_SIZE_ROW) 408 | # no scale here because the operation is associative so we did it earlier on K 409 | # thanks to masking of K & Q_T, the non-existent out-of-bounds tokens look like a bunch 410 | # of zeros hugged up against the bottom and right edges of S_T 411 | # subtract the logsumexp from S_T then exponentiate to get P_T 412 | P_T = tl.exp2(S_T - LSE[None, :]) # shape (BLOCK_SIZE_COL, BLOCK_SIZE_ROW) 413 | # this derivative actually requires an extra *ln(2) which we do below at dLdS_T 414 | # the non-existent tokens that were a bunch of 0's are now a bunch of 1's 415 | 416 | if MASK: # if we're on the block-diagonal 417 | # implement a lower-triangular mask. it looks like upper-triangular because we've 418 | # transposed, which is also the reason why our columns & rows are reversed 419 | mask = (offsets_COL[:, None] <= offsets_ROW[None, :]) # (BLOCK_SIZE_COL, BLOCK_SIZE_ROW) 420 | P_T = tl.where(mask, P_T, 0.) 421 | 422 | # compute dLdV 423 | dLdV = tl.dot(P_T, dLdO, acc=dLdV) # shape (BLOCK_SIZE_COL, Dh) 424 | 425 | # compute dLdP_T and dLdS_T to get dLdK 426 | dLdP_T = tl.dot(V, tl.trans(dLdO)) # shape (BLOCK_SIZE_COL, BLOCK_SIZE_ROW) 427 | dLdS_T = (P_T * (dLdP_T - Delta[None, :]) * ln2) # shape (BLOCK_SIZE_COL, BLOCK_SIZE_ROW) 428 | dLdK = tl.dot(dLdS_T, tl.trans(Q_T), acc=dLdK) # shape (BLOCK_SIZE_COL, Dh) 429 | # acc tells the tl.dot to accumulate into dLdK 430 | 431 | # increment pointers 432 | offsets_ROW += BLOCK_SIZE_ROW 433 | Q_ptr += BLOCK_SIZE_ROW * stride_N 434 | dLdO_ptr += BLOCK_SIZE_ROW * stride_N 435 | 436 | return dLdK, dLdV 437 | 438 | 439 | @triton.jit 440 | def _attn_backward_Q( 441 | dLdQ, Q, dLdO, LSE, 442 | K_ptr, V_ptr, Delta_ptr, 443 | stride_N, stride_Dh, 444 | H, N, Dh: tl.constexpr, 445 | BLOCK_SIZE_ROW: tl.constexpr, 446 | BLOCK_SIZE_COL: tl.constexpr, 447 | start_ROW, start_COL, num_steps, 448 | scale, ln2: tl.constexpr, rln2: tl.constexpr, 449 | MASK: tl.constexpr 450 | ): 451 | """ 452 | this sub-kernel will be looking at a specific chunk of Q and iterating through 453 | rows of K & V to calculate that chunk of dLdQ 454 | I say "rows" of K and V but really we refer to them as columns since we're thinking 455 | not in terms of the (B, H, N, D) shaped matrices but rather the (B, H, N, N) shaped 456 | attention logits, where the first N are split up by "BLOCK_SIZE_ROW" and the second N 457 | is split up by "BLOCK_SIZE_COL" 458 | N of K & V 459 | -------------------> 460 | -------------------> 461 | N of Q -------------------> 462 | -------------------> 463 | -------------------> 464 | arrows indicate direction of this pid's for loop; each arrow is a different PID 465 | """ 466 | offsets_ROW = start_ROW + tl.arange(0, BLOCK_SIZE_ROW) 467 | offsets_COL = start_COL + tl.arange(0, BLOCK_SIZE_COL) 468 | offsets_Dh = tl.arange(0, Dh) 469 | 470 | # we transpose K and V while loading them 471 | K_and_V_T_offsets = offsets_Dh[:, None] * stride_Dh + offsets_COL[None, :] * stride_N 472 | 473 | Delta = tl.load(Delta_ptr + offsets_ROW, mask=offsets_ROW < N, other=0.) # shape (BLOCK_SIZE_ROW) 474 | 475 | for block_idx in range(num_steps): 476 | # the sequence mask sets non-existent out-of-bounds tokens to zero vectors 477 | mask_COL = offsets_COL < N 478 | K_T = tl.load(K_ptr + K_and_V_T_offsets, mask=mask_COL[None, :], other=0.) # shape (Dh, BLOCK_SIZE_COL) 479 | V_T = tl.load(V_ptr + K_and_V_T_offsets, mask=mask_COL[None, :], other=0.) # shape (Dh, BLOCK_SIZE_COL) 480 | 481 | # re-compute S and P rather than reading them from memory, same idea as in _attn_backward_KV but not transposed 482 | S = tl.dot(Q, K_T) # shape (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) 483 | P = tl.exp2(S - LSE) # shape (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) 484 | 485 | if MASK: # if we're on the block-diagonal 486 | mask = (offsets_ROW[:, None] >= offsets_COL[None, :]) # (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) 487 | # zero out the values above the diagonal since the causal mask keeps only the lower-triangular part (we recomputed S here without the -inf masking used in the forward pass)
488 | P = tl.where(mask, P, 0.) # shape (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) 489 | 490 | # calc dLdP and dLdS to get dLdQ 491 | dLdP = tl.dot(dLdO, V_T) # shape (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) 492 | dLdS = (P * (dLdP - Delta[:, None]) * ln2) # shape (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) 493 | # ^this line is equivalent to: 494 | #weighted_dLdP = tl.sum(dLdP * P, axis=1) # row-sum over keys 495 | #dLdS = P * (dLdP - weighted_dLdP[:, None]) 496 | # but trades-off a memory access for a binary & then a reduction op 497 | dLdQ += tl.dot(dLdS, tl.trans(K_T)) # shape (BLOCK_SIZE_ROW, Dh) 498 | # we'll need to de-sdcale dLdQ in the end because K_T was pre-scaled 499 | # we do it later instead of now bc now would mean num_steps * flops versus just flops 500 | 501 | # increment pointers 502 | offsets_COL += BLOCK_SIZE_COL 503 | K_ptr += BLOCK_SIZE_COL * stride_N 504 | V_ptr += BLOCK_SIZE_COL * stride_N 505 | 506 | return dLdQ 507 | 508 | 509 | @triton.autotune( 510 | [ 511 | triton.Config({"BLOCK_SIZE_MACRO": BLOCK_SIZE_MACRO, "BLOCK_SIZE_MICRO": BLOCK_SIZE_MICRO}, 512 | num_stages=num_stages, num_warps=num_warps,) 513 | for BLOCK_SIZE_MICRO in [16]#, 32, 64] 514 | for BLOCK_SIZE_MACRO in [32]#, 64, 128] 515 | for num_stages in [3]#, 5, 7] 516 | for num_warps in [4]#, 8, 16] 517 | if BLOCK_SIZE_MACRO > BLOCK_SIZE_MICRO # could do >= but i wanna get mileage out of the loop code we wrote 518 | ], 519 | key=["Dh"], 520 | ) 521 | @triton.jit 522 | def attn_backward( 523 | Q_ptr, K_ptr, V_ptr, 524 | dLdO_ptr, dLdQ_ptr, dLdK_ptr, dLdV_ptr, 525 | LSE_ptr, Delta_ptr, 526 | scale, 527 | stride_B, stride_H, stride_N, stride_Dh, 528 | H, N, Dh: tl.constexpr, 529 | BLOCK_SIZE_MICRO: tl.constexpr, # 530 | BLOCK_SIZE_MACRO: tl.constexpr, # 531 | ): 532 | # we'll use these constants later on Q 533 | ln2: tl.constexpr = 0.6931471824645996 # = ln(2), natural logarithm of 2 534 | rln2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2), the reciprocal of the natural logarithm of 2 535 | # generally defining a known constant as an approximation of itself to some number of digits 536 | # is more efficient than calculating the actual value every time 537 | 538 | # move pointers of (B, H, N, D) matrices to get to the correct batch and head 539 | idx_batch_head = tl.program_id(1) 540 | idx_batch = idx_batch_head // H 541 | idx_head = idx_batch_head % H 542 | batch_head_jump = idx_batch * stride_B + idx_head * stride_H 543 | Q_ptr += batch_head_jump 544 | K_ptr += batch_head_jump 545 | V_ptr += batch_head_jump 546 | dLdO_ptr += batch_head_jump 547 | dLdQ_ptr += batch_head_jump 548 | dLdK_ptr += batch_head_jump 549 | dLdV_ptr += batch_head_jump 550 | 551 | # move pointers of (B, H, N) matrices to get to the correct batch and head 552 | batch_head_jump = idx_batch_head * N 553 | LSE_ptr += batch_head_jump 554 | Delta_ptr += batch_head_jump 555 | 556 | # BLOCK_SIZE_MACRO must be a multiple of BLOCK_SIZE_MICRO 557 | # because we use them to determine num_steps and don't want a remainder 558 | tl.static_assert(BLOCK_SIZE_MACRO % BLOCK_SIZE_MICRO == 0) 559 | 560 | ### STAGE 1: First we'll do dLdK and dLdV 561 | # in the fwd loop we held a block of Q in SRAM and iterated through K & V; here we'll do the opposite 562 | 563 | # ROW and COL refer to the N dimension of (Q & O) and (K & V) respectively 564 | # for stage 1 each PID will look at BLOCK_SIZE_MACRO tokens of K & V and there will 565 | # be an inner for-loop iterating over BLOCK_SIZE_MICRO tokens of Q at a time 566 | BLOCK_SIZE_ROW_1: tl.constexpr = BLOCK_SIZE_MICRO 567 | 
BLOCK_SIZE_COL_1: tl.constexpr = BLOCK_SIZE_MACRO 568 | 569 | # first we'll do the gradients along the block diagonal since they get treated differently 570 | # in that they have a triangular causal mask 571 | pid = tl.program_id(0) 572 | start_COL = pid * BLOCK_SIZE_COL_1 573 | start_ROW = start_COL 574 | num_steps = BLOCK_SIZE_COL_1 // BLOCK_SIZE_ROW_1 575 | 576 | # load K & V 577 | offsets_COL_1 = start_COL + tl.arange(0, BLOCK_SIZE_COL_1) 578 | offsets_Dh = tl.arange(0, Dh) 579 | KV_offsets = offsets_COL_1[:, None] * stride_N + offsets_Dh[None, :] * stride_Dh 580 | KV_mask = (offsets_COL_1[:, None] < N) # to avoid out-of-bounds non-existent tokens 581 | K = tl.load(K_ptr + KV_offsets, mask=KV_mask, other=0.) # shape (BLOCK_SIZE_COL_1, Dh) 582 | V = tl.load(V_ptr + KV_offsets, mask=KV_mask, other=0.) # shape (BLOCK_SIZE_COL_1, Dh) 583 | 584 | # pre-scaling K allows us to do the multiplication once here as opposed to 585 | # num_steps times inside _attn_backward_KV 586 | # we also scale by rln2 to account for the derivative of tl.exp2() and do it 587 | # here instead of inside _attn_backward_KV for the same reason 588 | K *= scale * rln2 589 | 590 | # we'll accumulate the gradients into these 591 | dLdK = tl.zeros([BLOCK_SIZE_COL_1, Dh], dtype=tl.float32) 592 | dLdV = tl.zeros([BLOCK_SIZE_COL_1, Dh], dtype=tl.float32) 593 | 594 | # compute dLdK and dLdV portions along the blocked diagonal 595 | dLdK, dLdV = _attn_backward_KV( 596 | K, V, dLdK, dLdV, 597 | Q_ptr, dLdO_ptr, LSE_ptr, Delta_ptr, 598 | stride_N, stride_Dh, 599 | H, N, Dh, 600 | BLOCK_SIZE_ROW_1, BLOCK_SIZE_COL_1, 601 | start_ROW, start_COL, num_steps, 602 | scale, ln2, rln2, 603 | MASK=True 604 | ) 605 | 606 | # next we'll do all the blocks that are not on the block-diagonal and therefore don't need the triangular mask. 607 | # this moves us forward to get off of the block-diagonal 608 | start_ROW += BLOCK_SIZE_COL_1 609 | # start_COL doesn't change since that's determined by our PID 610 | # then we calculate how many blocks need to be done.
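# for a concrete (hypothetical) example: with N=69, BLOCK_SIZE_COL_1=32 and BLOCK_SIZE_ROW_1=16, the pid 0 program has start_ROW=32 at this point, so the lines below give N_adj = ceil(69/32)*32 = 96 and num_steps = (96 - 32) // 16 = 4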
611 | # this adjustment to N accounts for sequence lengths that are not clean multiples of BLOCK_SIZE_COL_1 612 | N_adj = tl.cdiv(N, BLOCK_SIZE_COL_1) * BLOCK_SIZE_COL_1 613 | num_steps = (N_adj - start_ROW) // BLOCK_SIZE_ROW_1 614 | 615 | # compute dLdK and dLdV for non-masked blocks 616 | dLdK, dLdV = _attn_backward_KV( 617 | K, V, dLdK, dLdV, 618 | Q_ptr, dLdO_ptr, LSE_ptr, Delta_ptr, 619 | stride_N, stride_Dh, 620 | H, N, Dh, 621 | BLOCK_SIZE_ROW_1, BLOCK_SIZE_COL_1, 622 | start_ROW, start_COL, num_steps, 623 | scale, ln2, rln2, 624 | MASK=False 625 | ) 626 | 627 | # scale dLdK now since we skipped doing it inside _attn_backward_KV to save flops 628 | dLdK *= scale * rln2 629 | # write back dLdK and dLdV 630 | tl.store(dLdK_ptr + KV_offsets, dLdK, mask=KV_mask) 631 | tl.store(dLdV_ptr + KV_offsets, dLdV, mask=KV_mask) 632 | 633 | ### STAGE 2: Now we do dLdQ 634 | # in this part, like the forward pass we look at a specific block of Q & iterate through K & V 635 | 636 | # ROW and COL refer to the N dimension of Q and K/V respectively 637 | # for stage 2 each PID will look at BLOCK_SIZE_MACRO tokens of Q 638 | # and there will be an inner for loop iterating over BLOCK_SIZE_MICRO tokens of K & V 639 | BLOCK_SIZE_ROW_2: tl.constexpr = BLOCK_SIZE_MACRO 640 | BLOCK_SIZE_COL_2: tl.constexpr = BLOCK_SIZE_MICRO 641 | 642 | # we again start off doing the block-diagonal 643 | start_ROW = pid * BLOCK_SIZE_ROW_2 644 | start_COL = start_ROW 645 | num_steps = BLOCK_SIZE_ROW_2 // BLOCK_SIZE_COL_2 646 | # ^this is the number of steps for a single block, aka the blocks on the diagonal 647 | 648 | offsets_ROW = start_ROW + tl.arange(0, BLOCK_SIZE_ROW_2) 649 | QO_offsets = offsets_ROW[:, None] * stride_N + offsets_Dh[None, :] * stride_Dh 650 | mask_ROW = offsets_ROW < N 651 | Q = tl.load(Q_ptr + QO_offsets, mask=mask_ROW[:, None], other=0.) # shape (BLOCK_SIZE_ROW_2, Dh) 652 | Q *= scale * rln2 653 | dLdO = tl.load(dLdO_ptr + QO_offsets, mask=mask_ROW[:, None], other=0.)
# shape (BLOCK_SIZE_ROW_2, Dh) 654 | LSE = tl.load(LSE_ptr + offsets_ROW, mask=mask_ROW, other=0.)[:, None] # shape (BLOCK_SIZE_ROW_2, 1) 655 | 656 | # accumulate the gradients into here 657 | dLdQ = tl.zeros([BLOCK_SIZE_ROW_2, Dh], dtype=tl.float32) 658 | 659 | # compute dQ for blocks on the diagonal 660 | dLdQ = _attn_backward_Q( 661 | dLdQ, Q, dLdO, LSE, 662 | K_ptr, V_ptr, Delta_ptr, 663 | stride_N, stride_Dh, 664 | H, N, Dh, 665 | BLOCK_SIZE_ROW_2, BLOCK_SIZE_COL_2, 666 | start_ROW, start_COL, num_steps, 667 | scale, ln2, rln2, 668 | MASK=True 669 | ) 670 | 671 | # now we'll do the parts that are not on the block-diagonal 672 | end_COL = start_COL 673 | start_COL = 0 #end_COL - num_steps * BLOCK_SIZE_COL_2 # could just call it 0 lmao 674 | num_steps = end_COL // BLOCK_SIZE_COL_2 675 | dLdQ = _attn_backward_Q( 676 | dLdQ, Q, dLdO, LSE, 677 | K_ptr, V_ptr, Delta_ptr, 678 | stride_N, stride_Dh, 679 | H, N, Dh, 680 | BLOCK_SIZE_ROW_2, BLOCK_SIZE_COL_2, 681 | start_ROW, start_COL, num_steps, 682 | scale, ln2, rln2, 683 | MASK=False 684 | ) 685 | dLdQ *= scale * rln2 686 | tl.store(dLdQ_ptr + QO_offsets, dLdQ, mask=mask_ROW[:, None]) 687 | 688 | 689 | class _flashattention(torch.autograd.Function): 690 | 691 | @staticmethod 692 | def forward(ctx, q, k, v, scale): 693 | assert q.shape == k.shape == v.shape 694 | assert q.shape[-1] <= 128, \ 695 | f'flash attention only supports head dimension of 128 less but got {q.shape[-1]}' 696 | # the kernel actually isn't this limited but too much larger and i think it might overwhelm SRAM 697 | B, H, N, Dh = q.shape 698 | assert q.device == k.device and q.device == v.device 699 | assert q.dtype == k.dtype == v.dtype == torch.float32 700 | 701 | # pre-allocate output tensor 702 | O = torch.empty_like(q) # output tensor will be pre head concatenation and mixing 703 | # and pre-allocate the tensor where we hold the logsumexp 704 | LSE = torch.empty((B, H, N), device=q.device, dtype=torch.float32) 705 | 706 | grid = lambda args: ( 707 | triton.cdiv(N, args["BLOCK_SIZE_QO"]), # primary parallelizatoin is across sequence length 708 | B * H, # further parallelize across the dimensions that don't matter 709 | ) 710 | # notice the sequence dimension axis is first, and BH parallelization axis is second 711 | # this is because we want the former to have PIDs on the same SM 712 | 713 | """ 714 | imagine for a launch grid of (3, 2) wiwth 3 SMs that can each hold 2 PIDs 715 | we'd have PIDs: 716 | [0, 0] \ SM0 717 | [1, 0] / 718 | [2, 0] \ SM1 719 | [0, 1] / 720 | [1, 1] \ SM2 721 | [2, 1] / 722 | """ 723 | 724 | attn_fwd[grid]( 725 | q, k, v, O, LSE, 726 | scale, 727 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 728 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 729 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 730 | O.stride(0), O.stride(1), O.stride(2), O.stride(3), 731 | LSE.stride(0), LSE.stride(1), LSE.stride(2), 732 | B, H, N, Dh, 733 | ) 734 | 735 | ctx.save_for_backward(q, k, v, O, LSE) 736 | ctx.grid = grid 737 | ctx.B, ctx.H, ctx.N, ctx.Dh = B, H, N, Dh 738 | ctx.scale = scale 739 | return O 740 | 741 | @staticmethod 742 | def backward(ctx, dLdO): 743 | q, k, v, O, LSE = ctx.saved_tensors 744 | grid = ctx.grid 745 | scale = ctx.scale 746 | B, H, N, Dh = ctx.B, ctx.H, ctx.N, ctx.Dh 747 | 748 | dLdq = torch.empty_like(q) # shape (B, H, N, Dh) 749 | dLdk = torch.empty_like(k) 750 | dLdv = torch.empty_like(v) 751 | 752 | dLdO = dLdO.contiguous() 753 | assert q.stride() == k.stride() == v.stride() == O.stride() == dLdO.stride() 754 | 755 
| Delta = torch.empty_like(LSE) # shape (B, H, N) 756 | # the ordering of your grid matters because it determines which programs end up sharing the same SRAM 757 | pre_grid = lambda meta: (triton.cdiv(N, meta["PRE_BLOCK_SIZE_ROW"]), B * H) 758 | # in this case, we want the parallelizations along the N dimension to be near each other so they can 759 | # share data, while parallelization across batches & heads don't necessitate any sharing 760 | attn_backward_preprocess[pre_grid]( 761 | O, dLdO, Delta, 762 | O.stride(0), O.stride(1), O.stride(2), O.stride(3), 763 | dLdO.stride(0), dLdO.stride(1), dLdO.stride(2), dLdO.stride(3), 764 | Delta.stride(0), Delta.stride(1), Delta.stride(2), 765 | N, Dh, 766 | ) 767 | 768 | grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_MACRO"]), B * H) 769 | attn_backward[grid]( 770 | q, k, v, 771 | dLdO, dLdq, dLdk, dLdv, 772 | LSE, Delta, 773 | scale, 774 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), # all tensors should share same stride 775 | H, N, Dh, 776 | ) 777 | 778 | return dLdq, dLdk, dLdv, None 779 | 780 | triton_attention = _flashattention.apply 781 | 782 | 783 | ######### Step 1 ######### 784 | def test_flashattention_kernel(B, H, N, Dh, device=DEVICE, atol=5e-3): 785 | # create data 786 | q = torch.randn((B, H, N, Dh), dtype=torch.float32, device=device, requires_grad=True) 787 | k = torch.randn((B, H, N, Dh), dtype=torch.float32, device=device, requires_grad=True) 788 | v = torch.randn((B, H, N, Dh), dtype=torch.float32, device=device, requires_grad=True) 789 | sm_scale = 1/math.sqrt(Dh) # idk why I made scale a parameter to be passed in, whatever too late now 790 | # forward pass 791 | tri_out = triton_attention(q, k, v, sm_scale) 792 | ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) 793 | 794 | """ 795 | # you could un-comment this if you want to visually analyze patterns in any errors 796 | import os 797 | import numpy as np 798 | import matplotlib.pyplot as plt 799 | # Convert to numpy arrays 800 | actual = tri_out.detach().cpu().numpy() 801 | expected = ref_out.detach().cpu().numpy() 802 | # Compute differences and masks 803 | abs_diff = np.abs(expected - actual) 804 | abs_fail_mask = (abs_diff > 1e-2).astype(np.int32) 805 | plt.figure(figsize=(8, 6)) 806 | plt.imshow(abs_fail_mask[0][0], cmap="hot", aspect="auto") 807 | plt.xlabel("Model/Head Dimension") 808 | plt.ylabel("Sequence Position") 809 | plt.colorbar() 810 | plt.savefig('./out_heatmap.png') 811 | plt.close() 812 | """ 813 | 814 | # compare 815 | torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) 816 | print("passed fwd") 817 | 818 | # backward pass (triton) 819 | dLdout = 0.1 * torch.randn_like(q) 820 | tri_out.backward(dLdout, retain_graph=True) 821 | dLdq_tri, dLdk_tri, dLdv_tri = [_.grad.clone() for _ in [q, k, v]] 822 | q.grad, k.grad, v.grad = None, None, None 823 | # backward pass (torch) 824 | ref_out.backward(dLdout, retain_graph=True) 825 | dLdq_ref, dLdk_ref, dLdv_ref = [_.grad.clone() for _ in [q, k, v]] 826 | q.grad, k.grad, v.grad = None, None, None 827 | 828 | """ 829 | # you could un-comment this if you want to visually analyze patterns in any errors 830 | import os 831 | import numpy as np 832 | import matplotlib.pyplot as plt 833 | # dLdq Convert to numpy arrays 834 | actual = dLdq_ref.detach().cpu().numpy() 835 | expected = dLdq_tri.detach().cpu().numpy() 836 | # Compute differences and masks 837 | abs_diff = np.abs(expected - actual) 838 | abs_fail_mask = (abs_diff > atol).astype(np.int32) 839 | 
plt.figure(figsize=(8, 6)) 840 | plt.imshow(abs_fail_mask[0][0], cmap="hot", aspect="auto") 841 | plt.xlabel("Model/Head Dimension") 842 | plt.ylabel("Sequence Position") 843 | plt.colorbar() 844 | plt.savefig('./dLdq_out_heatmap.png') 845 | plt.close() 846 | # dLdk Convert to numpy arrays 847 | actual = dLdk_ref.detach().cpu().numpy() 848 | expected = dLdk_tri.detach().cpu().numpy() 849 | # Compute differences and masks 850 | abs_diff = np.abs(expected - actual) 851 | abs_fail_mask = (abs_diff > atol).astype(np.int32) 852 | plt.figure(figsize=(8, 6)) 853 | plt.imshow(abs_fail_mask[0][0], cmap="hot", aspect="auto") 854 | plt.xlabel("Model/Head Dimension") 855 | plt.ylabel("Sequence Position") 856 | plt.colorbar() 857 | plt.savefig('./dLdk_out_heatmap.png') 858 | plt.close() 859 | # dLdv Convert to numpy arrays 860 | actual = dLdv_ref.detach().cpu().numpy() 861 | expected = dLdv_tri.detach().cpu().numpy() 862 | # Compute differences and masks 863 | abs_diff = np.abs(expected - actual) 864 | abs_fail_mask = (abs_diff > atol).astype(np.int32) 865 | plt.figure(figsize=(8, 6)) 866 | plt.imshow(abs_fail_mask[0][0], cmap="hot", aspect="auto") 867 | plt.xlabel("Model/Head Dimension") 868 | plt.ylabel("Sequence Position") 869 | plt.colorbar() 870 | plt.savefig('./dLdv_out_heatmap.png') 871 | plt.close() 872 | """ 873 | 874 | # compare 875 | torch.testing.assert_close(dLdq_tri, dLdq_ref, atol=atol, rtol=0) 876 | torch.testing.assert_close(dLdk_tri, dLdk_ref, atol=atol, rtol=0) 877 | torch.testing.assert_close(dLdv_tri, dLdv_ref, atol=atol, rtol=0) 878 | print("Passed bwd") 879 | 880 | 881 | 882 | # vary seq length for fixed head and batch=4 883 | configs = [] 884 | for mode in ["fwd", "bwd"]: 885 | configs.append( 886 | triton.testing.Benchmark( 887 | x_names=["SEQ_LEN"], 888 | x_vals=[512 * i for i in range(1, 17)], # LOWER IF YOU DON'T HAVE ENOUGH RAM 889 | line_arg="provider", 890 | line_vals=["torch", 'this_tutorial'], 891 | line_names=[ 892 | "torch.nn.functional.scaled_dot_product_attention", 893 | "This tutorial's implementation" 894 | ], 895 | styles=[("red", "-"), ("blue", "-")], 896 | ylabel="TFLOPS", 897 | plot_name=f"attention-performance-{mode}", 898 | args={"mode": mode}, 899 | )) 900 | 901 | @triton.testing.perf_report(configs) 902 | def bench_flash_attention(SEQ_LEN, mode, provider, device=DEVICE): 903 | assert mode in ["fwd", "bwd"] 904 | dtype = torch.float32 905 | BATCH, N_HEADS = 32, 4 # LOWER THESE IF YOU DON'T HAVE ENOUGH RAM 906 | HEAD_DIM = 128 # AND THIS IF YOU DON"T HAVE ENOUGH SRAM 907 | q = torch.randn((BATCH, N_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 908 | k = torch.randn((BATCH, N_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 909 | v = torch.randn((BATCH, N_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 910 | sm_scale = 1 / math.sqrt(HEAD_DIM) 911 | if provider == 'torch': 912 | fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) 913 | if provider == 'this_tutorial': 914 | fn = lambda: triton_attention(q, k, v, sm_scale) 915 | if mode == "bwd": 916 | O = fn() 917 | dLdO = torch.randn_like(O) 918 | fn = lambda: O.backward(dLdO, retain_graph=True) 919 | ms = triton.testing.do_bench(fn) 920 | flops_per_matmul = 2.0 * BATCH * N_HEADS * SEQ_LEN * SEQ_LEN * HEAD_DIM 921 | total_flops = 2 * flops_per_matmul * 0.5 922 | if mode == "bwd": 923 | total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) 924 | return total_flops * 1e-12 / (ms * 1e-3) 925 | 926 | if 
__name__ == "__main__": 927 | # always run unit-tests 928 | test_flashattention_kernel(1, 1, 128, 32) # without block masking 929 | test_flashattention_kernel(1, 1, 128, 64) # without block masking 930 | test_flashattention_kernel(1, 1, 128, 128) # without block masking 931 | test_flashattention_kernel(32, 8, 69, 128) # with block masking 932 | 933 | # Only run benchmark if explicitly requested 934 | import sys 935 | if len(sys.argv) > 1 and sys.argv[1] == "--benchmark": 936 | bench_flash_attention.run(save_path='.', print_data=True) -------------------------------------------------------------------------------- /10_CEloss_project/README.md: -------------------------------------------------------------------------------- 1 | # live-deriving Apple's cut cross-entropy loss 2 | As I've said multiple times throughout these lessons, watching my videos and copying my code line by line does not count as learning, or as having completed this course. Before you can claim that you know how to write GPU kernels, you need to go actually design and write one (or many, preferably many) from first principles. 3 | 4 | For that reason, this is not a "lesson" in the way that the prior 9 were. Rather, it's a demonstration of an example project that you should do to test your knowledge. One alternative way to think about them while you watch is that you, the viewer, are a prospective employer who just asked me a very hard job interview question and I'm now attempting to answer it to demonstrate my skills. 5 | 6 | In the first video I start with a goal derived from a vague memory of when I skimmed through Apple's cut cross-entropy loss [paper](https://arxiv.org/abs/2411.09009) months before having ever even written my first GPU kernel. Using that loose starting intuition I work out a plan for what will hopefully be a relatively efficient fused CE Loss kernel from first principles. Then for the second video I attempt to take that plan and put it into action. 7 | 8 | Will it even run? Was there something I wasn't accounting for? Will it be as fast (or hopefully faster than) PyTorch? Does it even resemble Apple's algorithm or does it take a different route entirely? Idk, this was all live off-the-cuff so we'll see. 9 | 10 | *the answer is that I could not in the length of the two videos get it working. if i feel like putting more effort into making it actually run (and even more to making it actually fast) then i'll update this readme accordingly and maybe even make a third explanatory video. but for now it doesn't work* 11 | 12 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/7ox3HhC-j3E/0.jpg)](https://www.youtube.com/watch?v=7ox3HhC-j3E) 13 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/JPsQP8TFTDM/0.jpg)](https://www.youtube.com/watch?v=JPsQP8TFTDM) 14 | 15 | BTW, if you're looking for an idea for a project to do, maybe try to not only fix this kernel but fuse the forward & backward pass into one single kernel. What I mean by that is instead of outputting the final loss value, you can have your kernel just skip that step and go straight to outputting $\frac{\partial L}{\partial x}$ and $\frac{\partial L}{\partial E}$. Then, if someone were to use your new kernel, instead of doing the traditional .backward() on the final loss value, they'd actually do it on x and manually accumulate to E using the gradients you gave them. 
I've not actually done this myself, but I'm vaguely under the impression that this is what [Liger Kernel](https://github.com/linkedin/Liger-Kernel) does from one time when I skimmed a few lines of the wrapper function around their kernel. -------------------------------------------------------------------------------- /10_CEloss_project/celoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | import math 5 | 6 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}') 7 | 8 | #import os 9 | #os.environ["TRITON_INTERPRET"] = "1" 10 | 11 | def naive_CELoss(x, E, targets): 12 | # Compute logits: (B, N, D) @ (D, V) -> (B, N, V) 13 | logits = x @ E 14 | 15 | # Reshape logits to (B*N, V) for cross entropy 16 | B, N, _ = x.shape 17 | V = E.shape[1] 18 | logits_2d = logits.reshape(-1, V) 19 | 20 | # Reshape targets to (B*N) for cross entropy 21 | targets_1d = targets.reshape(-1) 22 | 23 | # Compute cross entropy loss using log-softmax directly 24 | # 1. Apply log-softmax to logits (with numerical stability) 25 | max_logits, _ = torch.max(logits_2d, dim=1, keepdim=True) # (B*N) 26 | logits_shifted = logits_2d - max_logits # (B*N, V) 27 | log_sum_exp = torch.log(torch.sum(torch.exp(logits_shifted), dim=1, keepdim=True)) + max_logits # (B*N, V) 28 | log_softmax = logits_2d - log_sum_exp # (B*N, V) 29 | 30 | # 2. Get negative log probabilities of target classes (NLL loss) 31 | nll = -log_softmax[torch.arange(log_softmax.size(0), device=targets_1d.device), targets_1d] 32 | #nll = -log_softmax[:, targets_1d] 33 | 34 | # 3. Average the loss 35 | loss = torch.mean(nll) 36 | 37 | return loss 38 | 39 | 40 | @triton.autotune( # decorator figures out what meta-parameters will be most efficient 41 | [ 42 | triton.Config( 43 | {"bsN": bsN, "bsD": bsD, "bsV": bsV}, 44 | num_stages=num_stages, num_warps=num_warps, 45 | ) 46 | for bsN in [16]#, 32, 64, 128] 47 | for bsD in [16]#, 32, 64, 128] 48 | for bsV in [16]#, 32, 64, 128] 49 | for num_stages in [3]#, 5, 7] 50 | for num_warps in [4]#, 8, 16] 51 | ], 52 | key=["N", "D", "V"], 53 | ) 54 | @triton.jit 55 | def fused_CELoss_kernel( 56 | x_ptr, E_ptr, targets_ptr, out_ptr, 57 | stride_x_B, stride_x_N, stride_x_D, 58 | stride_E_D, stride_E_V, 59 | stride_tar_B, stride_tar_N, 60 | stride_out_B, stride_out_N, 61 | B, N, D: tl.constexpr, V: tl.constexpr, 62 | bsN: tl.constexpr, bsD: tl.constexpr, bsV: tl.constexpr, 63 | ): 64 | pid = tl.program_id(0) 65 | offsets_N = tl.arange(0, bsN) 66 | offsets_V = tl.arange(0, bsV) 67 | 68 | x_ptr += pid * bsN * stride_x_B 69 | targets_ptr += pid * bsN * stride_tar_N 70 | out_ptr += pid * bsN * stride_out_N 71 | 72 | M = tl.full((bsN,), value=-1e6, dtype=tl.float32) 73 | denominator = tl.full((bsN,), value=1.0, dtype=tl.float32) 74 | numerator_selected = tl.zeros((bsN,), dtype=tl.float32) 75 | 76 | targets = tl.load(targets_ptr + offsets_N * stride_tar_N).to(tl.int32) # (bsN) 77 | 78 | # moves along V dimension of (B, N, V) logits computing live softmax 79 | for block_start_outer in range(0, V, bsV): 80 | 81 | logits = tl.zeros((bsN, bsV), dtype=tl.float32) 82 | offsets_D = tl.arange(0, bsD) 83 | 84 | # moves along D dimension of (B*N, D) @ (D, V) computing matmul 85 | for block_start_inner in range(0, D, bsD): 86 | # load blocks of x and E shape (bsN, bsD) and (bsD, bsV) respectively 87 | x_offsets = offsets_N[:, None] * stride_x_N + offsets_D[None, :] * stride_x_D 88 | E_offsets = offsets_D[:, None] * stride_E_D + 
offsets_V[None, :] * stride_E_V 89 | x = tl.load(x_ptr + x_offsets) 90 | E = tl.load(E_ptr + E_offsets) 91 | logits = tl.dot(x, E, acc=logits) # shape (bsN, bsV) 92 | 93 | offsets_D += bsD 94 | offsets_V += bsV 95 | 96 | # find max of logits 97 | M_new = tl.maximum(M, tl.max(logits, axis=1)) # (bsN) 98 | # use logits & its max to do live softmax 99 | logits_shifted = logits - M_new[:, None] # (bsN, bsV) 100 | numerator = tl.exp(logits_shifted) # (bsN, bsV) 101 | alpha = tl.exp(M - M_new) # (bsN) 102 | denominator_new = tl.sum(numerator, axis=1) 103 | denominator = denominator * alpha + denominator_new # (bsN) 104 | 105 | # need to use targets to select values from numerator when applicable 106 | #targets_mask = (targets >= block_start_outer) & (targets < block_start_outer + bsV) # (bsN) 107 | targets_adj = targets - block_start_outer # (bsN) 108 | # Only select the numerator for the target class in this block 109 | mask = tl.arange(0, bsV)[None, :] == targets_adj[:, None] # (bsN, bsV) 110 | numerator_selected += tl.sum(tl.where(mask, numerator, 0.), axis=1) 111 | 112 | M = M_new 113 | 114 | P = numerator_selected / denominator 115 | nll = - tl.log(P) 116 | 117 | tl.store(out_ptr + offsets_N * stride_out_N, nll) 118 | 119 | 120 | def fused_CELoss(x, E, targets): 121 | assert x.shape[-1] == E.shape[0] 122 | B, N, D = x.shape 123 | _, V = E.shape 124 | 125 | # pre-allocate output 126 | out = torch.empty((B, N), dtype=torch.float32, device=x.device) 127 | 128 | grid = lambda meta: (triton.cdiv(B*N, meta['bsN']),) 129 | 130 | fused_CELoss_kernel[grid]( 131 | x, E, targets, out, 132 | x.stride(0), x.stride(1), x.stride(2), 133 | E.stride(0), E.stride(1), 134 | targets.stride(0), targets.stride(1), 135 | out.stride(0), out.stride(1), 136 | B, N, D, V, 137 | ) 138 | 139 | return torch.mean(out) 140 | 141 | def test_naiveCELoss(B, N, D, V, device=DEVICE, atol=1e-3): 142 | torch.cuda.empty_cache() 143 | assert V <= 32_768 144 | # create data 145 | x = torch.randn((B, N, D), dtype=torch.float32, device=device, requires_grad=False) 146 | E = torch.randn((D, V), dtype=torch.float32, device=device, requires_grad=False) 147 | targets = torch.randint(0, V, (B, N), device=device, requires_grad=False) 148 | # forward passes 149 | naive_loss = naive_CELoss(x, E, targets) 150 | logits = (x @ E).reshape(-1, V) 151 | targets_1d = targets.reshape(-1) 152 | ref_loss = torch.nn.functional.cross_entropy(logits, targets_1d) 153 | # compare 154 | torch.testing.assert_close(naive_loss, ref_loss, atol=atol, rtol=0) 155 | print(f"naive passed {V}") 156 | 157 | 158 | def test_fusedCELoss(B, N, D, V, device=DEVICE, atol=1e-3): 159 | torch.cuda.empty_cache() 160 | # create data 161 | x = torch.randn((B, N, D), dtype=torch.float32, device=device, requires_grad=False) 162 | E = torch.randn((D, V), dtype=torch.float32, device=device, requires_grad=False) 163 | targets = torch.randint(0, V, (B, N), device=device, requires_grad=False) 164 | # forward passes 165 | logits = (x @ E).reshape(-1, V) 166 | targets_1d = targets.reshape(-1) 167 | ref_loss = torch.nn.functional.cross_entropy(logits, targets_1d) 168 | tri_loss = fused_CELoss(x, E, targets) 169 | # compare 170 | torch.testing.assert_close(tri_loss, ref_loss, atol=atol, rtol=0) 171 | print(f"triton passed {V}") 172 | 173 | 174 | # vary seq length for fixed head and batch=4 175 | configs = [ 176 | triton.testing.Benchmark( 177 | x_names=["V"], 178 | x_vals=[2 ** i for i in range(10, 14)], # LOWER IF YOU DON'T HAVE ENOUGH RAM 179 | line_arg="provider", 180 | line_vals=[ 181 
| "torch", 182 | 'triton' 183 | ], 184 | line_names=[ 185 | "torch.nn.functional.cross_entropy()", 186 | "Fused & sparse Triton implementation" 187 | ], 188 | styles=[ 189 | ("red", "-"), 190 | ("blue", "-") 191 | ], 192 | ylabel="TFLOPS", 193 | plot_name=f"CELoss-performance", 194 | args={}, 195 | ) 196 | ] 197 | @triton.testing.perf_report(configs) 198 | def bench_CELoss(V, provider, device=DEVICE): 199 | dtype = torch.float32 200 | B, N, D = 32, 1024, 384 # LOWER THESE IF YOU DON'T HAVE ENOUGH RAM 201 | x = torch.randn((B, N, D), dtype=dtype, device=device, requires_grad=False) 202 | E = torch.randn((D, V), dtype=dtype, device=device, requires_grad=False) 203 | targets = torch.randint(0, V, (B, N), device=device, requires_grad=False) 204 | if provider == 'torch': 205 | logits = (x @ E).reshape(-1, V) 206 | targets_1d = targets.reshape(-1) 207 | fn = lambda: torch.nn.functional.cross_entropy(logits, targets_1d) 208 | if provider == 'triton': 209 | fn = lambda: fused_CELoss(x, E, targets) 210 | 211 | # Calculate FLOPS: 212 | ms = triton.testing.do_bench(fn) 213 | # Matrix multiplication: 2*B*N*D*V (each element requires D multiplications and D-1 additions) 214 | # Softmax and CE loss operations: approximately 6*B*N*V 215 | total_flops = 2 * B * N * D * V + 6 * B * N * V 216 | return total_flops * 1e-12 / (ms * 1e-3) 217 | 218 | if __name__ == "__main__": 219 | # always run unit-tests 220 | #test_naiveCELoss(32, 1024, 384, 8192) 221 | #test_naiveCELoss(32, 1024, 384, 16_384) 222 | #test_naiveCELoss(32, 1024, 384, 32_768) 223 | 224 | test_fusedCELoss(32, 1024, 384, 32_768) 225 | #test_fusedCELoss(32, 1024, 384, 65_536) 226 | #test_fusedCELoss(32, 1024, 384, 131_072) 227 | #test_fusedCELoss(32, 1024, 384, 262_144) 228 | 229 | # Only run benchmark if explicitly requested 230 | import sys 231 | if len(sys.argv) > 1 and sys.argv[1] == "--benchmark": 232 | bench_CELoss.run(save_path='.', print_data=True) 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Evin Tunador 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Triton Tutorials 2 | making the [official triton documentation tutorials](https://triton-lang.org/main/getting-started/tutorials/index.html) actually comprehensible by *heavily* commenting in-detail about every little thing that's happening. Follow them in order of filename and check out the accompanying videos: 3 | 4 | [![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/TUQAyCNxFe4/0.jpg)](https://youtube.com/playlist?list=PLPefVKO3tDxOJLAmCA75uShbe1z_RNqkQ&si=C5VF9fNW8CYZzh9x) 5 | 6 | *Note:* these tutorials were all tested and benchmarked on an Nvidia RTX 4060 Ti. On different GPUs your mileage may vary, and on GPUs with less VRAM or SRAM you may even receive errors. I've also found older GPUs running the exact same Triton code to get incorrect values (eg. RTX 3090) so I recommend using at least a 40 series 7 | 8 | ## learning resources I used 9 | - of course the [official Triton documentation](https://triton-lang.org/main/getting-started/tutorials/index.html) 10 | - [here](https://github.com/hkproj/triton-flash-attention)'s a flash-attention implementation by one of my fav youtubers that comes with an [8 hour video](https://www.youtube.com/watch?v=zy8ChVd_oTM&t=1s) 11 | - and the original flash-attention papers [v1](https://arxiv.org/abs/2205.14135) & [v2](https://arxiv.org/abs/2307.08691) (you only really need v2) 12 | - [here](https://github.com/gpu-mode/lectures/tree/main 13 | )'s a wider set of GPU kernel guides that includes an intro to Triton in lesson 14 14 | 15 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evintunador/triton_docs_tutorials/6e465b838f0b5abc2d26abeb2a8782c023e84b86/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.6.0 2 | triton==3.2.0 3 | numpy 4 | matplotlib 5 | pandas 6 | pytest --------------------------------------------------------------------------------